@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/{core/task → task}/task.d.ts +7 -7
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +24 -15
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/default_tasks/cifar10.d.ts +0 -2
- package/dist/core/default_tasks/geotags.d.ts +0 -2
- package/dist/core/default_tasks/geotags.js +0 -69
- package/dist/core/default_tasks/index.d.ts +0 -6
- package/dist/core/default_tasks/index.js +0 -15
- package/dist/core/default_tasks/lus_covid.d.ts +0 -2
- package/dist/core/default_tasks/lus_covid.js +0 -96
- package/dist/core/default_tasks/mnist.d.ts +0 -2
- package/dist/core/default_tasks/simple_face.d.ts +0 -2
- package/dist/core/default_tasks/titanic.d.ts +0 -2
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -39
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/digest.js +0 -18
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -6
- package/dist/core/task/index.js +0 -15
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.js +0 -35
- package/dist/core/task/task_handler.d.ts +0 -5
- package/dist/core/task/task_handler.js +0 -53
- package/dist/core/task/task_provider.d.ts +0 -6
- package/dist/core/task/task_provider.js +0 -13
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -7
- package/dist/core/weights/aggregation.js +0 -72
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/digest.d.ts +0 -0
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
import { Map, Set } from 'immutable';
|
|
2
|
+
import type { client, Model, AsyncInformant } from '../index.js';
|
|
3
|
+
export declare enum AggregationStep {
    /** A node contribution is being added (logged by `Base.log` as "Adding contribution"). */
    ADD = 0,
    /** A node contribution is being updated (logged by `Base.log` as "Updating contribution"). */
    UPDATE = 1,
    /** Buffered contributions are being aggregated (logged by `Base.log` as "Buffer is full"). */
    AGGREGATE = 2
}
|
|
8
|
+
/**
 * Main, abstract, aggregator class whose role is to buffer contributions and to produce
 * a result based off their aggregation, whenever some defined condition is met.
 *
 * @typeParam T - Type of a single node contribution, which is also the type of the
 * aggregation result delivered through {@link Base.receiveResult}.
 */
export declare abstract class Base<T> {
    /**
     * The Model whose weights are updated on aggregation.
     */
    protected _model?: Model | undefined;
    /**
     * The round cut-off for contributions: a contribution made for `round` is considered
     * recent enough when `this.round - round <= roundCutoff` (see {@link Base.isWithinRoundCutoff}).
     */
    protected readonly roundCutoff: number;
    /**
     * The number of communication rounds occurring during any given aggregation round.
     */
    readonly communicationRounds: number;
    /**
     * Contains the ids of all active nodes, i.e. members of the aggregation group at
     * a given round. It is a subset of all the nodes available in the network.
     */
    protected _nodes: Set<client.NodeID>;
    /**
     * Contains the contributions received from active nodes, accessible by node id.
     * It defines the effective aggregation group, which is possibly a subset
     * of all active nodes, depending on the aggregation scheme.
     */
    protected contributions: Map<number, Map<client.NodeID, T>>;
    /**
     * Emits the aggregation event whenever an aggregation step is performed.
     * Triggers the resolve of the result promise and the preparation for the
     * next aggregation round.
     */
    private readonly eventEmitter;
    /**
     * Observer stored by {@link Base.registerObserver}, if any.
     */
    protected informant?: AsyncInformant<T>;
    /**
     * The result promise which, on resolve, will contain the current aggregation result.
     * This promise should be fetched by any object making use of an aggregator, in order
     * to await upon aggregation.
     */
    protected result: Promise<T>;
    /**
     * The current aggregation round, used for assessing whether a node contribution is recent enough
     * or not.
     */
    protected _round: number;
    /**
     * The current communication round. A single aggregation round is made of possibly multiple
     * communication rounds. This makes the aggregator free to perform intermediate aggregation
     * steps based off communication with its nodes. Overall, this allows for more complex
     * aggregation schemes requiring an exchange of information between nodes before aggregating.
     */
    protected _communicationRound: number;
    constructor(
    /**
     * The Model whose weights are updated on aggregation.
     */
    _model?: Model | undefined,
    /**
     * The round cut-off for contributions. Defaults to 0.
     */
    roundCutoff?: number,
    /**
     * The number of communication rounds occurring during any given aggregation round.
     * Defaults to 1.
     */
    communicationRounds?: number);
    /**
     * Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
     * The contribution will be aggregated during the next aggregation step.
     * @param nodeId The node's id
     * @param contribution The node's contribution
     * @param round For which aggregation round the contribution was made
     * @param communicationRound For which communication round the contribution was made
     * @returns A flag whose exact meaning is defined by the concrete aggregator
     * (NOTE(review): presumably whether the contribution was accepted — confirm per subclass)
     */
    abstract add(nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean;
    /**
     * Performs an aggregation step over the received node contributions.
     * Must store the aggregation's result in the aggregator's result promise.
     */
    abstract aggregate(): void;
    /**
     * Stores the given informant as this aggregator's observer. A later call replaces
     * the previously registered informant.
     * @param informant The observer to register
     */
    registerObserver(informant: AsyncInformant<T>): void;
    /**
     * Returns whether the given round is recent enough, i.e. whether
     * `this.round - round <= this.roundCutoff`. Rounds ahead of the current one
     * therefore pass whenever the cutoff is non-negative.
     * @param round The round
     * @returns True if the round is recent enough, false otherwise
     */
    isWithinRoundCutoff(round: number): boolean;
    /**
     * Logs useful messages to the console during the various aggregation steps.
     * An UPDATE step without a `from` node logs nothing.
     * @param step The aggregation step
     * @param from The node which triggered the logging message
     * @throws Error if `step` is not a known {@link AggregationStep}
     */
    log(step: AggregationStep, from?: client.NodeID): void;
    /**
     * Sets the aggregator's TF.js model.
     * @param model The new TF.js model
     */
    setModel(model: Model): void;
    /**
     * Adds a node's id to the set of active nodes. A node represents an active neighbor
     * peer/client within the network, whom we are communicating with during this aggregation
     * round.
     * @param nodeId The node to be added
     * @returns `true` if the node was newly added, `false` if it was already registered
     */
    registerNode(nodeId: client.NodeID): boolean;
    /**
     * Overwrites the current set of active nodes with the given one. A node represents
     * an active neighbor peer/client within the network, whom we are communicating with
     * during this aggregation round.
     * @param nodeIds The new set of nodes
     */
    setNodes(nodeIds: Set<client.NodeID>): void;
    /**
     * Empties the current set of "nodes". Usually called at the end of an aggregation round,
     * if the set of nodes is meant to change or to be actualized.
     */
    resetNodes(): void;
    /**
     * Advances the aggregator's round number. To be used whenever the aggregator is behind
     * the network's round. The round only ever moves forward: a value that is not strictly
     * greater than the current round is ignored.
     * @param round The new round
     */
    setRound(round: number): void;
    /**
     * Emits the event containing the aggregation result, which allows the result
     * promise to resolve and for the next aggregation round to take place.
     * @param aggregated The aggregation result
     */
    protected emit(aggregated: T): void;
    /**
     * Updates the aggregator's state to proceed to the next communication round.
     * If all communication rounds were performed, proceeds to the next aggregation round
     * and empties the collection of stored contributions.
     */
    nextRound(): void;
    private makeResult;
    /**
     * Aggregation steps are performed asynchronously, yet can be awaited upon when required.
     * This function gives access to the current aggregation result's promise, which will
     * eventually resolve and contain the result of the very next aggregation step, at the
     * time of the function call.
     * @returns The promise containing the aggregation result
     */
    receiveResult(): Promise<T>;
    /**
     * Constructs the payloads sent to other nodes as contribution.
     * @param base Object from which the payload is computed
     * @returns The payload to send, keyed by recipient node id
     */
    abstract makePayloads(base: T): Map<client.NodeID, T>;
    /**
     * Fullness check implemented by concrete aggregators.
     * NOTE(review): presumably reports whether enough contributions were received to aggregate — confirm.
     */
    abstract isFull(): boolean;
    /**
     * The set of node ids, representing our neighbors within the network.
     */
    get nodes(): Set<client.NodeID>;
    /**
     * The aggregation round.
     */
    get round(): number;
    /**
     * The aggregator's current size, defined by its number of contributions. The size is bounded by
     * the amount of all active nodes times the number of communication rounds.
     */
    get size(): number;
    /**
     * The aggregator's current model.
     */
    get model(): Model | undefined;
    /**
     * The current communication round.
     */
    get communicationRound(): number;
}
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
import { Map, Set } from 'immutable';
|
|
2
|
+
import { EventEmitter } from '../utils/event_emitter.js';
|
|
3
|
+
export var AggregationStep;
(function (steps) {
    // Mirror a compiled TypeScript numeric enum: every member gets a
    // name -> value entry and a value -> name reverse entry.
    ['ADD', 'UPDATE', 'AGGREGATE'].forEach((name, value) => {
        steps[name] = value;
        steps[value] = name;
    });
})(AggregationStep || (AggregationStep = {}));
|
|
9
|
+
/**
|
|
10
|
+
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
|
|
11
|
+
* a result based off their aggregation, whenever some defined condition is met.
|
|
12
|
+
*/
|
|
13
|
+
export class Base {
|
|
14
|
+
_model;
|
|
15
|
+
roundCutoff;
|
|
16
|
+
communicationRounds;
|
|
17
|
+
/**
|
|
18
|
+
* Contains the ids of all active nodes, i.e. members of the aggregation group at
|
|
19
|
+
* a given round. It is a subset of all the nodes available in the network.
|
|
20
|
+
*/
|
|
21
|
+
_nodes;
|
|
22
|
+
/**
|
|
23
|
+
* Contains the contributions received from active nodes, accessible by node id.
|
|
24
|
+
* It defines the effective aggregation group, which is possibly a subset
|
|
25
|
+
* of all active nodes, depending on the aggregation scheme.
|
|
26
|
+
*/
|
|
27
|
+
contributions;
|
|
28
|
+
/**
|
|
29
|
+
* Emits the aggregation event whenever an aggregation step is performed.
|
|
30
|
+
* Triggers the resolve of the result promise and the preparation for the
|
|
31
|
+
* next aggregation round.
|
|
32
|
+
*/
|
|
33
|
+
eventEmitter = new EventEmitter();
|
|
34
|
+
informant;
|
|
35
|
+
/**
|
|
36
|
+
* The result promise which, on resolve, will contain the current aggregation result.
|
|
37
|
+
* This promise should be fetched by any object making use of an aggregator, in order
|
|
38
|
+
* to await upon aggregation.
|
|
39
|
+
*/
|
|
40
|
+
result;
|
|
41
|
+
/**
|
|
42
|
+
* The current aggregation round, used for assessing whether a node contribution is recent enough
|
|
43
|
+
* or not.
|
|
44
|
+
*/
|
|
45
|
+
_round = 0;
|
|
46
|
+
/**
|
|
47
|
+
* The current communication round. A single aggregation round is made of possibly multiple
|
|
48
|
+
* communication rounds. This makes the aggregator free to perform intermediate aggregation
|
|
49
|
+
* steps based off communication with its nodes. Overall, this allows for more complex
|
|
50
|
+
* aggregation schemes requiring an exchange of information between nodes before aggregating.
|
|
51
|
+
*/
|
|
52
|
+
_communicationRound = 0;
|
|
53
|
+
constructor(
|
|
54
|
+
/**
|
|
55
|
+
* The Model whose weights are updated on aggregation.
|
|
56
|
+
*/
|
|
57
|
+
_model,
|
|
58
|
+
/**
|
|
59
|
+
* The round cut-off for contributions.
|
|
60
|
+
*/
|
|
61
|
+
roundCutoff = 0,
|
|
62
|
+
/**
|
|
63
|
+
* The number of communication rounds occurring during any given aggregation round.
|
|
64
|
+
*/
|
|
65
|
+
communicationRounds = 1) {
|
|
66
|
+
this._model = _model;
|
|
67
|
+
this.roundCutoff = roundCutoff;
|
|
68
|
+
this.communicationRounds = communicationRounds;
|
|
69
|
+
this.contributions = Map();
|
|
70
|
+
this._nodes = Set();
|
|
71
|
+
// Make the initial result promise
|
|
72
|
+
this.result = this.makeResult();
|
|
73
|
+
// On every aggregation, update the object's state to match the current aggregation
|
|
74
|
+
// and communication rounds.
|
|
75
|
+
this.eventEmitter.on('aggregation', () => {
|
|
76
|
+
this.nextRound();
|
|
77
|
+
});
|
|
78
|
+
}
|
|
79
|
+
registerObserver(informant) {
|
|
80
|
+
this.informant = informant;
|
|
81
|
+
}
|
|
82
|
+
/**
|
|
83
|
+
* Returns whether the given round is recent enough, dependent on the
|
|
84
|
+
* aggregator's round cutoff.
|
|
85
|
+
* @param round The round
|
|
86
|
+
* @returns True if the round is recent enough, false otherwise
|
|
87
|
+
*/
|
|
88
|
+
isWithinRoundCutoff(round) {
|
|
89
|
+
return this.round - round <= this.roundCutoff;
|
|
90
|
+
}
|
|
91
|
+
/**
|
|
92
|
+
* Logs useful messages during the various aggregation steps.
|
|
93
|
+
* @param step The aggregation step
|
|
94
|
+
* @param from The node which triggered the logging message
|
|
95
|
+
*/
|
|
96
|
+
log(step, from) {
|
|
97
|
+
switch (step) {
|
|
98
|
+
case AggregationStep.ADD:
|
|
99
|
+
console.log(`> Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
|
|
100
|
+
break;
|
|
101
|
+
case AggregationStep.UPDATE:
|
|
102
|
+
if (from === undefined) {
|
|
103
|
+
return;
|
|
104
|
+
}
|
|
105
|
+
console.log(`> Updating contribution from node ${from} for round (${this.communicationRound}, ${this.round})`);
|
|
106
|
+
break;
|
|
107
|
+
case AggregationStep.AGGREGATE:
|
|
108
|
+
console.log('*'.repeat(80));
|
|
109
|
+
console.log(`Buffer is full. Aggregating weights for round (${this.communicationRound}, ${this.round})\n`);
|
|
110
|
+
break;
|
|
111
|
+
default: {
|
|
112
|
+
const _ = step;
|
|
113
|
+
throw new Error('should never happen');
|
|
114
|
+
}
|
|
115
|
+
}
|
|
116
|
+
}
|
|
117
|
+
/**
|
|
118
|
+
* Sets the aggregator's TF.js model.
|
|
119
|
+
* @param model The new TF.js model
|
|
120
|
+
*/
|
|
121
|
+
setModel(model) {
|
|
122
|
+
this._model = model;
|
|
123
|
+
}
|
|
124
|
+
/**
|
|
125
|
+
* Adds a node's id to the set of active nodes. A node represents an active neighbor
|
|
126
|
+
* peer/client within the network, whom we are communicating with during this aggregation
|
|
127
|
+
* round.
|
|
128
|
+
* @param nodeId The node to be added
|
|
129
|
+
*/
|
|
130
|
+
registerNode(nodeId) {
|
|
131
|
+
if (!this.nodes.has(nodeId)) {
|
|
132
|
+
this._nodes = this._nodes.add(nodeId);
|
|
133
|
+
return true;
|
|
134
|
+
}
|
|
135
|
+
return false;
|
|
136
|
+
}
|
|
137
|
+
/**
|
|
138
|
+
* Overwrites the current set of active nodes with the given one. A node represents
|
|
139
|
+
* an active neighbor peer/client within the network, whom we are communicating with
|
|
140
|
+
* during this aggregation round.
|
|
141
|
+
* @param nodeIds The new set of nodes
|
|
142
|
+
*/
|
|
143
|
+
setNodes(nodeIds) {
|
|
144
|
+
this._nodes = nodeIds;
|
|
145
|
+
}
|
|
146
|
+
/**
|
|
147
|
+
* Empties the current set of "nodes". Usually called at the end of an aggregation round,
|
|
148
|
+
* if the set of nodes is meant to change or to be actualized.
|
|
149
|
+
*/
|
|
150
|
+
resetNodes() {
|
|
151
|
+
this._nodes = Set();
|
|
152
|
+
}
|
|
153
|
+
/**
|
|
154
|
+
* Sets the aggregator's round number. To be used whenever the aggregator is out of sync
|
|
155
|
+
* with the network's round.
|
|
156
|
+
* @param round The new round
|
|
157
|
+
*/
|
|
158
|
+
setRound(round) {
|
|
159
|
+
if (round > this.round) {
|
|
160
|
+
this._round = round;
|
|
161
|
+
}
|
|
162
|
+
}
|
|
163
|
+
/**
|
|
164
|
+
* Emits the event containing the aggregation result, which allows the result
|
|
165
|
+
* promise to resolve and for the next aggregation round to take place.
|
|
166
|
+
* @param aggregated The aggregation result
|
|
167
|
+
*/
|
|
168
|
+
emit(aggregated) {
|
|
169
|
+
this.eventEmitter.emit('aggregation', aggregated);
|
|
170
|
+
}
|
|
171
|
+
/**
|
|
172
|
+
* Updates the aggregator's state to proceed to the next communication round.
|
|
173
|
+
* If all communication rounds were performed, proceeds to the next aggregation round
|
|
174
|
+
* and empties the collection of stored contributions.
|
|
175
|
+
*/
|
|
176
|
+
nextRound() {
|
|
177
|
+
if (++this._communicationRound === this.communicationRounds) {
|
|
178
|
+
this._communicationRound = 0;
|
|
179
|
+
this._round++;
|
|
180
|
+
this.contributions = Map();
|
|
181
|
+
}
|
|
182
|
+
this.result = this.makeResult();
|
|
183
|
+
this.informant?.update();
|
|
184
|
+
}
|
|
185
|
+
async makeResult() {
|
|
186
|
+
return await new Promise((resolve) => {
|
|
187
|
+
this.eventEmitter.once('aggregation', (w) => {
|
|
188
|
+
resolve(w);
|
|
189
|
+
});
|
|
190
|
+
});
|
|
191
|
+
}
|
|
192
|
+
/**
|
|
193
|
+
* Aggregation steps are performed asynchronously, yet can be awaited upon when required.
|
|
194
|
+
* This function gives access to the current aggregation result's promise, which will
|
|
195
|
+
* eventually resolve and contain the result of the very next aggregation step, at the
|
|
196
|
+
* time of the function call.
|
|
197
|
+
* @returns The promise containing the aggregation result
|
|
198
|
+
*/
|
|
199
|
+
async receiveResult() {
|
|
200
|
+
return await this.result;
|
|
201
|
+
}
|
|
202
|
+
/**
|
|
203
|
+
* The set of node ids, representing our neighbors within the network.
|
|
204
|
+
*/
|
|
205
|
+
get nodes() {
|
|
206
|
+
return this._nodes;
|
|
207
|
+
}
|
|
208
|
+
/**
|
|
209
|
+
* The aggregation round.
|
|
210
|
+
*/
|
|
211
|
+
get round() {
|
|
212
|
+
return this._round;
|
|
213
|
+
}
|
|
214
|
+
/**
|
|
215
|
+
* The aggregator's current size, defined by its number of contributions. The size is bounded by
|
|
216
|
+
* the amount of all active nodes times the number of communication rounds.
|
|
217
|
+
*/
|
|
218
|
+
get size() {
|
|
219
|
+
return this.contributions
|
|
220
|
+
.valueSeq()
|
|
221
|
+
.map((m) => m.size)
|
|
222
|
+
.reduce((totalSize, size) => totalSize + size) ?? 0;
|
|
223
|
+
}
|
|
224
|
+
/**
|
|
225
|
+
* The aggregator's current model.
|
|
226
|
+
*/
|
|
227
|
+
get model() {
|
|
228
|
+
return this._model;
|
|
229
|
+
}
|
|
230
|
+
/**
|
|
231
|
+
* The current communication round.
|
|
232
|
+
*/
|
|
233
|
+
get communicationRound() {
|
|
234
|
+
return this._communicationRound;
|
|
235
|
+
}
|
|
236
|
+
}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import type { Task } from '../index.js';
|
|
2
|
+
import { aggregator } from '../index.js';
|
|
3
|
+
/**
 * The kinds of aggregator a task can request through
 * `trainingInformation.aggregator`.
 */
export declare enum AggregatorChoice {
    /** Averages the received weights. */
    MEAN = 0,
    /** Secure multi-party aggregation; only available for the decentralized scheme. */
    SECURE = 1,
    /** Not implemented: requesting it makes `getAggregator` throw. */
    BANDIT = 2
}
|
|
11
|
+
/**
 * Instantiates the aggregator requested by a task's training information.
 * Missing or unrecognised choices fall back to the mean aggregator.
 * @param task The task whose `trainingInformation.aggregator` selects the aggregator
 * @returns A fresh aggregator instance
 * @throws Error if the bandit aggregator is requested, or if secure aggregation is
 *   requested for a scheme other than 'decentralized'
 */
export declare function getAggregator(task: Task): aggregator.Aggregator;
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import { aggregator } from '../index.js';
|
|
2
|
+
/**
 * Enumeration of the available types of aggregator, laid out like a compiled
 * TypeScript numeric enum: each name maps to its value and each value maps back
 * to its name.
 */
export var AggregatorChoice;
(function (choices) {
    // The array position of each name is its numeric value
    ['MEAN', 'SECURE', 'BANDIT'].forEach((name, value) => {
        choices[name] = value;
        choices[value] = name;
    });
})(AggregatorChoice || (AggregatorChoice = {}));
|
|
11
|
+
/**
 * Instantiates the aggregator selected by `task.trainingInformation.aggregator`.
 * Missing or unrecognised choices fall back to the mean aggregator.
 * @param task The task whose training information selects the aggregator
 * @returns A fresh aggregator instance
 * @throws Error if the bandit aggregator is requested (not implemented), or if secure
 *   aggregation is requested for a scheme other than 'decentralized'
 */
export function getAggregator(task) {
    const choice = task.trainingInformation.aggregator;
    if (choice === AggregatorChoice.BANDIT) {
        throw new Error('not implemented');
    }
    if (choice === AggregatorChoice.SECURE) {
        if (task.trainingInformation.scheme !== 'decentralized') {
            throw new Error('secure aggregation is currently supported for decentralized only');
        }
        return new aggregator.SecureAggregator();
    }
    // MEAN, and any other value, yields the mean aggregator
    return new aggregator.MeanAggregator();
}
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import type { WeightsContainer } from '../weights/index.js';
|
|
2
|
+
import type { Base } from './base.js';
|
|
3
|
+
export { Base as AggregatorBase, AggregationStep } from './base.js';
|
|
4
|
+
export { MeanAggregator } from './mean.js';
|
|
5
|
+
export { SecureAggregator } from './secure.js';
|
|
6
|
+
export { getAggregator, AggregatorChoice } from './get.js';
|
|
7
|
+
/** An aggregator whose contributions are model weights (`WeightsContainer`). */
export type Aggregator = Base<WeightsContainer>;
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import type { Map } from 'immutable';
|
|
2
|
+
import { Base as Aggregator } from './base.js';
|
|
3
|
+
import type { Model, WeightsContainer, client } from '../index.js';
|
|
4
|
+
/**
 * Aggregator that averages the weights received from its nodes.
 */
export declare class MeanAggregator extends Aggregator<WeightsContainer> {
    /**
     * Contribution threshold triggering an aggregation step, read either as
     * - a fraction of the known nodes, when 0 < threshold <= 1, or
     * - an absolute number of contributions, when threshold > 1.
     */
    readonly threshold: number;
    /**
     * @param model Optional model whose weights are overwritten by each aggregation
     * @param roundCutoff How many rounds a contribution may lag behind and still count
     * @param threshold See `threshold`; must be positive, and an integer when above 1
     */
    constructor(model?: Model, roundCutoff?: number, threshold?: number);
    /**
     * Tells whether enough contributions were received to aggregate, per `threshold`.
     * @returns Whether the contributions buffer is full
     */
    isFull(): boolean;
    /**
     * Stores a node's contribution for the given round, aggregating once the buffer is full.
     * @returns Whether the contribution was accepted
     */
    add(nodeId: client.NodeID, contribution: WeightsContainer, round: number): boolean;
    /** Averages the stored contributions and emits the result. */
    aggregate(): void;
    /** Maps every known node to the given weights. */
    makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
}
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import { AggregationStep, Base as Aggregator } from './base.js';
|
|
2
|
+
import { aggregation } from '../index.js';
|
|
3
|
+
/**
 * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
 * It uses a single communication round per aggregation round, so every contribution of the
 * current round is stored under communication-round key 0 of `contributions`.
 */
export class MeanAggregator extends Aggregator {
    /**
     * The threshold t to fulfill to trigger an aggregation step. It can either be:
     * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
     * - absolute: integer t > 1, thus requiring t contributions
     */
    threshold;
    /**
     * @param model Optional model whose weights get overwritten by each aggregation result
     * @param roundCutoff How many rounds a contribution may lag behind the current round and
     *   still be accepted (default 0: only the current round or later ones)
     * @param threshold Contribution threshold, relative or absolute as described on
     *   `threshold` (default 1, i.e. 100% of node participation)
     * @throws Error if the threshold is not a positive number, or is absolute (> 1) but not
     *   an integer
     */
    constructor(model, roundCutoff = 0, threshold = 1) {
        // A single communication round per aggregation round
        super(model, roundCutoff, 1);
        // The negated comparison also rejects NaN, which would prevent the buffer from ever filling
        if (!(threshold > 0)) {
            throw new Error('threshold must be positive');
        }
        // Thresholds greater than 1 are absolute contribution counts, so they must be whole
        // numbers (this also rejects Infinity, which could never be reached)
        if (threshold > 1 && !Number.isInteger(threshold)) {
            throw new Error('absolute thresholds must be integers');
        }
        this.threshold = threshold;
    }
    /**
     * Checks whether the contributions buffer is full, according to the set threshold.
     * Only the contributions of the current round (stored under key 0) are counted.
     * @returns Whether the contributions buffer is full; false when nothing was received
     */
    isFull() {
        const contribs = this.contributions.get(0);
        if (contribs === undefined) {
            return false;
        }
        // `contributions` maps a communication round to its per-node contributions, so the
        // count to compare is the inner map's size, never the outer map's (at most 1 entry).
        const required = this.threshold <= 1
            ? this.threshold * this.nodes.size
            : this.threshold;
        return contribs.size >= required;
    }
    /**
     * Stores (or replaces) a node's contribution, then aggregates if the buffer became full.
     * Contributions from unknown nodes, or from rounds older than the round cutoff, are ignored.
     * @param nodeId The contributing node
     * @param contribution The node's weights
     * @param round The round the contribution was made for
     * @returns Whether the contribution was accepted
     */
    add(nodeId, contribution, round) {
        if (this.nodes.has(nodeId) && this.isWithinRoundCutoff(round)) {
            this.log(this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, nodeId);
            this.contributions = this.contributions.setIn([0, nodeId], contribution);
            this.informant?.update();
            if (this.isFull()) {
                this.aggregate();
            }
            return true;
        }
        return false;
    }
    /**
     * Averages the stored contributions, installs the result as the model's weights when a
     * model is set, and emits it. Emitting resolves the pending result promise and advances
     * the aggregator to its next round.
     */
    aggregate() {
        this.log(AggregationStep.AGGREGATE);
        const result = aggregation.avg(this.contributions.get(0)?.values());
        if (this.model !== undefined) {
            this.model.weights = result;
        }
        this.emit(result);
    }
    /**
     * Builds the outgoing messages: our local weights go to every other node, be it a peer
     * or a server.
     * @param weights The local weights to share
     * @returns A map from each known node id to `weights`
     */
    makePayloads(weights) {
        return this.nodes.toMap().map(() => weights);
    }
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import { Map, List } from 'immutable';
|
|
2
|
+
import { Base as Aggregator } from './base.js';
|
|
3
|
+
import type { Model, WeightsContainer, client } from '../index.js';
|
|
4
|
+
/**
 * Secure multi-party aggregation for decentralized learning. One aggregation spans
 * two communication rounds:
 * - nodes first exchange additive secret shares of their weights;
 * - each node then sums the shares it received and sends out that partial sum.
 * Averaging the received partial sums finally yields the aggregation result.
 */
export declare class SecureAggregator extends Aggregator<WeightsContainer> {
    /** Bound on the magnitude of randomly generated share values. */
    private readonly maxShareValue;
    constructor(model?: Model, maxShareValue?: number);
    aggregate(): void;
    add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound: number): boolean;
    isFull(): boolean;
    makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
    /**
     * Splits the secret weights into N additive shares (N being the number of peers)
     * whose sum equals the secret.
     */
    generateAllShares(secret: WeightsContainer): List<WeightsContainer>;
    /**
     * Draws one share shaped like the secret, filled with values sampled uniformly
     * from (-maxShareValue, maxShareValue).
     */
    generateRandomShare(secret: WeightsContainer): WeightsContainer;
}
|