@epfml/discojs 3.0.1-p20241007204240.0 → 3.0.1-p20241024094708.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
- package/dist/aggregator/{base.js → aggregator.js} +48 -36
- package/dist/aggregator/get.d.ts +2 -2
- package/dist/aggregator/get.js +4 -4
- package/dist/aggregator/index.d.ts +1 -4
- package/dist/aggregator/index.js +1 -1
- package/dist/aggregator/mean.d.ts +4 -4
- package/dist/aggregator/mean.js +5 -15
- package/dist/aggregator/secure.d.ts +4 -4
- package/dist/aggregator/secure.js +7 -17
- package/dist/client/client.d.ts +71 -17
- package/dist/client/client.js +118 -17
- package/dist/client/decentralized/decentralized_client.d.ts +11 -13
- package/dist/client/decentralized/decentralized_client.js +121 -84
- package/dist/client/decentralized/messages.d.ts +12 -6
- package/dist/client/decentralized/messages.js +9 -8
- package/dist/client/event_connection.js +2 -2
- package/dist/client/federated/federated_client.d.ts +1 -13
- package/dist/client/federated/federated_client.js +15 -94
- package/dist/client/federated/messages.d.ts +6 -11
- package/dist/client/local_client.d.ts +1 -0
- package/dist/client/local_client.js +3 -0
- package/dist/client/messages.d.ts +14 -7
- package/dist/client/messages.js +13 -11
- package/dist/default_tasks/cifar10.js +1 -1
- package/dist/default_tasks/lus_covid.js +1 -0
- package/dist/default_tasks/mnist.js +1 -1
- package/dist/default_tasks/simple_face.js +1 -0
- package/dist/default_tasks/titanic.js +1 -0
- package/dist/default_tasks/wikitext.js +1 -0
- package/dist/index.d.ts +0 -2
- package/dist/serialization/coder.d.ts +4 -0
- package/dist/serialization/coder.js +51 -0
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +1 -0
- package/dist/serialization/model.d.ts +1 -2
- package/dist/serialization/model.js +9 -24
- package/dist/serialization/weights.d.ts +2 -3
- package/dist/serialization/weights.js +15 -26
- package/dist/task/task_handler.d.ts +5 -5
- package/dist/task/task_handler.js +21 -15
- package/dist/task/training_information.d.ts +1 -2
- package/dist/task/training_information.js +6 -8
- package/dist/training/disco.d.ts +4 -1
- package/dist/training/trainer.js +1 -1
- package/dist/utils/event_emitter.d.ts +3 -3
- package/dist/utils/event_emitter.js +10 -9
- package/package.json +2 -3
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { Map, Set } from 'immutable';
|
|
2
|
-
import type { client } from '../index.js';
|
|
2
|
+
import type { client, WeightsContainer } from '../index.js';
|
|
3
3
|
import { EventEmitter } from '../utils/event_emitter.js';
|
|
4
4
|
export declare enum AggregationStep {
|
|
5
5
|
ADD = 0,
|
|
@@ -10,11 +10,11 @@ export declare enum AggregationStep {
|
|
|
10
10
|
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
|
|
11
11
|
* a result based off their aggregation, whenever some defined condition is met.
|
|
12
12
|
*
|
|
13
|
-
* Emits an event whenever an aggregation step is performed.
|
|
14
|
-
* Users
|
|
13
|
+
* Emits an event whenever an aggregation step is performed with the counrd's aggregated weights.
|
|
14
|
+
* Users subscribes to this event to get the aggregation result.
|
|
15
15
|
*/
|
|
16
|
-
export declare abstract class
|
|
17
|
-
'aggregation':
|
|
16
|
+
export declare abstract class Aggregator extends EventEmitter<{
|
|
17
|
+
'aggregation': WeightsContainer;
|
|
18
18
|
}> {
|
|
19
19
|
/**
|
|
20
20
|
* The round cut-off for contributions.
|
|
@@ -34,7 +34,7 @@ export declare abstract class Base<T> extends EventEmitter<{
|
|
|
34
34
|
* It defines the effective aggregation group, which is possibly a subset
|
|
35
35
|
* of all active nodes, depending on the aggregation scheme.
|
|
36
36
|
*/
|
|
37
|
-
protected contributions: Map<number, Map<client.NodeID,
|
|
37
|
+
protected contributions: Map<number, Map<client.NodeID, WeightsContainer>>;
|
|
38
38
|
/**
|
|
39
39
|
* The current aggregation round, used for assessing whether a node contribution is recent enough
|
|
40
40
|
* or not.
|
|
@@ -56,36 +56,45 @@ export declare abstract class Base<T> extends EventEmitter<{
|
|
|
56
56
|
* The number of communication rounds occurring during any given aggregation round.
|
|
57
57
|
*/
|
|
58
58
|
communicationRounds?: number);
|
|
59
|
+
/**
|
|
60
|
+
* Convenience method to subscribe to the 'aggregation' event.
|
|
61
|
+
* Await this promise returns the aggregated weights for the current round.
|
|
62
|
+
*
|
|
63
|
+
* @returns a promise for the aggregated weights
|
|
64
|
+
*/
|
|
65
|
+
getPromiseForAggregation(): Promise<WeightsContainer>;
|
|
59
66
|
/**
|
|
60
67
|
* Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
|
|
61
68
|
* The aggregation round is increased whenever a new global model is obtained and local models are updated.
|
|
62
69
|
* Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation
|
|
63
|
-
*
|
|
64
|
-
* The contribution
|
|
70
|
+
* which requires multiple steps to obtain a global model)
|
|
71
|
+
* The contribution is aggregated during the next aggregation step.
|
|
72
|
+
*
|
|
65
73
|
* @param nodeId The node's id
|
|
66
74
|
* @param contribution The node's contribution
|
|
67
|
-
* @param round aggregation round of the contribution was made
|
|
68
|
-
* @param communicationRound communication round the contribution was made within the aggregation round
|
|
69
|
-
* @returns boolean, true if the contribution has been successfully taken into account or False if it has been rejected
|
|
70
75
|
*/
|
|
71
|
-
|
|
76
|
+
add(nodeId: client.NodeID, contribution: WeightsContainer, aggregationRound: number, communicationRound?: number): void;
|
|
77
|
+
protected abstract _add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound?: number): void;
|
|
72
78
|
/**
|
|
73
79
|
* Evaluates whether a given participant contribution can be used in the current aggregation round
|
|
74
80
|
* the boolean returned by `this.add` is obtained via `this.isValidContribution`
|
|
81
|
+
*
|
|
82
|
+
* @param nodeId the node id of the contribution to be added
|
|
83
|
+
* @param round the aggregation round of the contribution to be added
|
|
75
84
|
*/
|
|
76
85
|
isValidContribution(nodeId: client.NodeID, round: number): boolean;
|
|
77
86
|
/**
|
|
78
87
|
* Performs an aggregation step over the received node contributions.
|
|
79
88
|
* Must store the aggregation's result in the aggregator's result promise.
|
|
80
89
|
*/
|
|
81
|
-
abstract aggregate():
|
|
90
|
+
protected abstract aggregate(): WeightsContainer;
|
|
82
91
|
/**
|
|
83
92
|
* Returns whether the given round is recent enough, dependent on the
|
|
84
93
|
* aggregator's round cutoff.
|
|
85
94
|
* @param round The round
|
|
86
95
|
* @returns True if the round is recent enough, false otherwise
|
|
87
96
|
*/
|
|
88
|
-
isWithinRoundCutoff
|
|
97
|
+
private isWithinRoundCutoff;
|
|
89
98
|
/**
|
|
90
99
|
* Logs useful messages during the various aggregation steps.
|
|
91
100
|
* @param step The aggregation step
|
|
@@ -112,28 +121,17 @@ export declare abstract class Base<T> extends EventEmitter<{
|
|
|
112
121
|
* @param nodeIds The new set of nodes
|
|
113
122
|
*/
|
|
114
123
|
setNodes(nodeIds: Set<client.NodeID>): void;
|
|
115
|
-
/**
|
|
116
|
-
* Empties the current set of "nodes". Usually called at the end of an aggregation round,
|
|
117
|
-
* if the set of nodes is meant to change or to be actualized.
|
|
118
|
-
*/
|
|
119
|
-
resetNodes(): void;
|
|
120
124
|
/**
|
|
121
125
|
* Sets the aggregator's round number. To be used whenever the aggregator is out of sync
|
|
122
126
|
* with the network's round.
|
|
123
127
|
* @param round The new round
|
|
124
128
|
*/
|
|
125
129
|
setRound(round: number): void;
|
|
126
|
-
/**
|
|
127
|
-
* Updates the aggregator's state to proceed to the next communication round.
|
|
128
|
-
* If all communication rounds were performed, proceeds to the next aggregation round
|
|
129
|
-
* and empties the collection of stored contributions.
|
|
130
|
-
*/
|
|
131
|
-
nextRound(): void;
|
|
132
130
|
/**
|
|
133
131
|
* Constructs the payloads sent to other nodes as contribution.
|
|
134
132
|
* @param base Object from which the payload is computed
|
|
135
133
|
*/
|
|
136
|
-
abstract makePayloads(base:
|
|
134
|
+
abstract makePayloads(base: WeightsContainer): Map<client.NodeID, WeightsContainer>;
|
|
137
135
|
abstract isFull(): boolean;
|
|
138
136
|
/**
|
|
139
137
|
* The set of node ids, representing our neighbors within the network.
|
|
@@ -143,11 +141,6 @@ export declare abstract class Base<T> extends EventEmitter<{
|
|
|
143
141
|
* The aggregation round.
|
|
144
142
|
*/
|
|
145
143
|
get round(): number;
|
|
146
|
-
/**
|
|
147
|
-
* The aggregator's current size, defined by its number of contributions. The size is bounded by
|
|
148
|
-
* the amount of all active nodes times the number of communication rounds.
|
|
149
|
-
*/
|
|
150
|
-
get size(): number;
|
|
151
144
|
/**
|
|
152
145
|
* The current communication round.
|
|
153
146
|
*/
|
|
@@ -12,10 +12,10 @@ export var AggregationStep;
|
|
|
12
12
|
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
|
|
13
13
|
* a result based off their aggregation, whenever some defined condition is met.
|
|
14
14
|
*
|
|
15
|
-
* Emits an event whenever an aggregation step is performed.
|
|
16
|
-
* Users
|
|
15
|
+
* Emits an event whenever an aggregation step is performed with the counrd's aggregated weights.
|
|
16
|
+
* Users subscribes to this event to get the aggregation result.
|
|
17
17
|
*/
|
|
18
|
-
export class
|
|
18
|
+
export class Aggregator extends EventEmitter {
|
|
19
19
|
roundCutoff;
|
|
20
20
|
communicationRounds;
|
|
21
21
|
/**
|
|
@@ -28,7 +28,7 @@ export class Base extends EventEmitter {
|
|
|
28
28
|
* It defines the effective aggregation group, which is possibly a subset
|
|
29
29
|
* of all active nodes, depending on the aggregation scheme.
|
|
30
30
|
*/
|
|
31
|
-
// communication round -> NodeID ->
|
|
31
|
+
// communication round -> NodeID -> WeightsContainer
|
|
32
32
|
contributions;
|
|
33
33
|
/**
|
|
34
34
|
* The current aggregation round, used for assessing whether a node contribution is recent enough
|
|
@@ -56,13 +56,54 @@ export class Base extends EventEmitter {
|
|
|
56
56
|
this.communicationRounds = communicationRounds;
|
|
57
57
|
this.contributions = Map();
|
|
58
58
|
this._nodes = Set();
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
59
|
+
}
|
|
60
|
+
/**
|
|
61
|
+
* Convenience method to subscribe to the 'aggregation' event.
|
|
62
|
+
* Await this promise returns the aggregated weights for the current round.
|
|
63
|
+
*
|
|
64
|
+
* @returns a promise for the aggregated weights
|
|
65
|
+
*/
|
|
66
|
+
getPromiseForAggregation() {
|
|
67
|
+
return new Promise((resolve) => this.once('aggregation', resolve));
|
|
68
|
+
}
|
|
69
|
+
/**
|
|
70
|
+
* Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
|
|
71
|
+
* The aggregation round is increased whenever a new global model is obtained and local models are updated.
|
|
72
|
+
* Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation
|
|
73
|
+
* which requires multiple steps to obtain a global model)
|
|
74
|
+
* The contribution is aggregated during the next aggregation step.
|
|
75
|
+
*
|
|
76
|
+
* @param nodeId The node's id
|
|
77
|
+
* @param contribution The node's contribution
|
|
78
|
+
*/
|
|
79
|
+
add(nodeId, contribution, aggregationRound, communicationRound) {
|
|
80
|
+
if (!this.isValidContribution(nodeId, aggregationRound))
|
|
81
|
+
throw new Error("Tried adding an invalid contribution. Handle this case before calling add.");
|
|
82
|
+
// call the abstract method _add, implemented by subclasses
|
|
83
|
+
this._add(nodeId, contribution, communicationRound);
|
|
84
|
+
// If the aggregator has enough contributions then aggregate the weights
|
|
85
|
+
// and emit the 'aggregation' event
|
|
86
|
+
if (this.isFull()) {
|
|
87
|
+
const aggregatedWeights = this.aggregate();
|
|
88
|
+
// On each aggregation, increment the communication round
|
|
89
|
+
// If all communication rounds were performed, proceed to the next aggregation round
|
|
90
|
+
// and empty the past contributions.
|
|
91
|
+
this._communicationRound++;
|
|
92
|
+
if (this.communicationRound === this.communicationRounds) {
|
|
93
|
+
this._communicationRound = 0;
|
|
94
|
+
this._round++;
|
|
95
|
+
this.contributions = Map();
|
|
96
|
+
}
|
|
97
|
+
// Emitting the 'aggregation' communicates the weights to subscribers
|
|
98
|
+
this.emit('aggregation', aggregatedWeights);
|
|
99
|
+
}
|
|
62
100
|
}
|
|
63
101
|
/**
|
|
64
102
|
* Evaluates whether a given participant contribution can be used in the current aggregation round
|
|
65
103
|
* the boolean returned by `this.add` is obtained via `this.isValidContribution`
|
|
104
|
+
*
|
|
105
|
+
* @param nodeId the node id of the contribution to be added
|
|
106
|
+
* @param round the aggregation round of the contribution to be added
|
|
66
107
|
*/
|
|
67
108
|
isValidContribution(nodeId, round) {
|
|
68
109
|
if (!this.nodes.has(nodeId)) {
|
|
@@ -139,13 +180,6 @@ export class Base extends EventEmitter {
|
|
|
139
180
|
setNodes(nodeIds) {
|
|
140
181
|
this._nodes = nodeIds;
|
|
141
182
|
}
|
|
142
|
-
/**
|
|
143
|
-
* Empties the current set of "nodes". Usually called at the end of an aggregation round,
|
|
144
|
-
* if the set of nodes is meant to change or to be actualized.
|
|
145
|
-
*/
|
|
146
|
-
resetNodes() {
|
|
147
|
-
this._nodes = Set();
|
|
148
|
-
}
|
|
149
183
|
/**
|
|
150
184
|
* Sets the aggregator's round number. To be used whenever the aggregator is out of sync
|
|
151
185
|
* with the network's round.
|
|
@@ -156,18 +190,6 @@ export class Base extends EventEmitter {
|
|
|
156
190
|
this._round = round;
|
|
157
191
|
}
|
|
158
192
|
}
|
|
159
|
-
/**
|
|
160
|
-
* Updates the aggregator's state to proceed to the next communication round.
|
|
161
|
-
* If all communication rounds were performed, proceeds to the next aggregation round
|
|
162
|
-
* and empties the collection of stored contributions.
|
|
163
|
-
*/
|
|
164
|
-
nextRound() {
|
|
165
|
-
if (++this._communicationRound === this.communicationRounds) {
|
|
166
|
-
this._communicationRound = 0;
|
|
167
|
-
this._round++;
|
|
168
|
-
this.contributions = Map();
|
|
169
|
-
}
|
|
170
|
-
}
|
|
171
193
|
/**
|
|
172
194
|
* The set of node ids, representing our neighbors within the network.
|
|
173
195
|
*/
|
|
@@ -180,16 +202,6 @@ export class Base extends EventEmitter {
|
|
|
180
202
|
get round() {
|
|
181
203
|
return this._round;
|
|
182
204
|
}
|
|
183
|
-
/**
|
|
184
|
-
* The aggregator's current size, defined by its number of contributions. The size is bounded by
|
|
185
|
-
* the amount of all active nodes times the number of communication rounds.
|
|
186
|
-
*/
|
|
187
|
-
get size() {
|
|
188
|
-
return this.contributions
|
|
189
|
-
.valueSeq()
|
|
190
|
-
.map((m) => m.size)
|
|
191
|
-
.reduce((totalSize, size) => totalSize + size) ?? 0;
|
|
192
|
-
}
|
|
193
205
|
/**
|
|
194
206
|
* The current communication round.
|
|
195
207
|
*/
|
package/dist/aggregator/get.d.ts
CHANGED
|
@@ -9,9 +9,9 @@ type AggregatorOptions = Partial<{
|
|
|
9
9
|
/**
|
|
10
10
|
* Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
|
|
11
11
|
* Here is the ordered list of parameters used to define the aggregator and its default behavior:
|
|
12
|
-
* task.trainingInformation.
|
|
12
|
+
* task.trainingInformation.aggregationStrategy > options.scheme > task.trainingInformation.scheme
|
|
13
13
|
*
|
|
14
|
-
* If `task.trainingInformation.
|
|
14
|
+
* If `task.trainingInformation.aggregationStrategy` is defined, we initialize the chosen aggregator with `options` parameter values.
|
|
15
15
|
* Otherwise, we default to a MeanAggregator for both training schemes.
|
|
16
16
|
*
|
|
17
17
|
* For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
|
package/dist/aggregator/get.js
CHANGED
|
@@ -2,9 +2,9 @@ import { aggregator } from '../index.js';
|
|
|
2
2
|
/**
|
|
3
3
|
* Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
|
|
4
4
|
* Here is the ordered list of parameters used to define the aggregator and its default behavior:
|
|
5
|
-
* task.trainingInformation.
|
|
5
|
+
* task.trainingInformation.aggregationStrategy > options.scheme > task.trainingInformation.scheme
|
|
6
6
|
*
|
|
7
|
-
* If `task.trainingInformation.
|
|
7
|
+
* If `task.trainingInformation.aggregationStrategy` is defined, we initialize the chosen aggregator with `options` parameter values.
|
|
8
8
|
* Otherwise, we default to a MeanAggregator for both training schemes.
|
|
9
9
|
*
|
|
10
10
|
* For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
|
|
@@ -18,9 +18,9 @@ import { aggregator } from '../index.js';
|
|
|
18
18
|
* @returns The aggregator
|
|
19
19
|
*/
|
|
20
20
|
export function getAggregator(task, options = {}) {
|
|
21
|
-
const
|
|
21
|
+
const aggregationStrategy = task.trainingInformation.aggregationStrategy ?? 'mean';
|
|
22
22
|
const scheme = options.scheme ?? task.trainingInformation.scheme;
|
|
23
|
-
switch (
|
|
23
|
+
switch (aggregationStrategy) {
|
|
24
24
|
case 'mean':
|
|
25
25
|
if (scheme === 'decentralized') {
|
|
26
26
|
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
|
|
@@ -1,7 +1,4 @@
|
|
|
1
|
-
|
|
2
|
-
import type { Base } from './base.js';
|
|
3
|
-
export { Base as AggregatorBase, AggregationStep } from './base.js';
|
|
1
|
+
export { Aggregator, AggregationStep } from './aggregator.js';
|
|
4
2
|
export { MeanAggregator } from './mean.js';
|
|
5
3
|
export { SecureAggregator } from './secure.js';
|
|
6
4
|
export { getAggregator } from './get.js';
|
|
7
|
-
export type Aggregator = Base<WeightsContainer>;
|
package/dist/aggregator/index.js
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
export {
|
|
1
|
+
export { Aggregator, AggregationStep } from './aggregator.js';
|
|
2
2
|
export { MeanAggregator } from './mean.js';
|
|
3
3
|
export { SecureAggregator } from './secure.js';
|
|
4
4
|
export { getAggregator } from './get.js';
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
import type { Map } from "immutable";
|
|
2
|
-
import {
|
|
2
|
+
import { Aggregator } from "./aggregator.js";
|
|
3
3
|
import type { WeightsContainer, client } from "../index.js";
|
|
4
4
|
type ThresholdType = 'relative' | 'absolute';
|
|
5
5
|
/**
|
|
6
6
|
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
|
|
7
7
|
*
|
|
8
8
|
*/
|
|
9
|
-
export declare class MeanAggregator extends Aggregator
|
|
9
|
+
export declare class MeanAggregator extends Aggregator {
|
|
10
10
|
#private;
|
|
11
11
|
/**
|
|
12
12
|
* Create a mean aggregator that averages all weight updates received when a specified threshold is met.
|
|
@@ -31,8 +31,8 @@ export declare class MeanAggregator extends Aggregator<WeightsContainer> {
|
|
|
31
31
|
/** Checks whether the contributions buffer is full. */
|
|
32
32
|
isFull(): boolean;
|
|
33
33
|
set minNbOfParticipants(minNbOfParticipants: number);
|
|
34
|
-
|
|
35
|
-
aggregate():
|
|
34
|
+
_add(nodeId: client.NodeID, contribution: WeightsContainer): void;
|
|
35
|
+
aggregate(): WeightsContainer;
|
|
36
36
|
makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
|
|
37
37
|
}
|
|
38
38
|
export {};
|
package/dist/aggregator/mean.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import createDebug from "debug";
|
|
2
|
-
import { AggregationStep,
|
|
2
|
+
import { AggregationStep, Aggregator } from "./aggregator.js";
|
|
3
3
|
import { aggregation } from "../index.js";
|
|
4
4
|
const debug = createDebug("discojs:aggregator:mean");
|
|
5
5
|
/**
|
|
@@ -54,7 +54,7 @@ export class MeanAggregator extends Aggregator {
|
|
|
54
54
|
else {
|
|
55
55
|
// Print a warning regarding the default behavior when thresholdType is not specified
|
|
56
56
|
if (thresholdType === undefined) {
|
|
57
|
-
// TODO enforce validity by splitting
|
|
57
|
+
// TODO enforce validity by splitting the different threshold types into separate classes instead of warning
|
|
58
58
|
debug("[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " +
|
|
59
59
|
"To instead wait for a single contribution, set thresholdType = 'absolute'");
|
|
60
60
|
this.#thresholdType = 'relative';
|
|
@@ -79,18 +79,9 @@ export class MeanAggregator extends Aggregator {
|
|
|
79
79
|
set minNbOfParticipants(minNbOfParticipants) {
|
|
80
80
|
this.#minNbOfParticipants = minNbOfParticipants;
|
|
81
81
|
}
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
throw new Error("only a single communication round");
|
|
85
|
-
if (!this.isValidContribution(nodeId, round))
|
|
86
|
-
return false;
|
|
87
|
-
this.log(this.contributions.hasIn([0, nodeId])
|
|
88
|
-
? AggregationStep.UPDATE
|
|
89
|
-
: AggregationStep.ADD, nodeId);
|
|
82
|
+
_add(nodeId, contribution) {
|
|
83
|
+
this.log(this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, nodeId);
|
|
90
84
|
this.contributions = this.contributions.setIn([0, nodeId], contribution);
|
|
91
|
-
if (this.isFull())
|
|
92
|
-
this.aggregate();
|
|
93
|
-
return true;
|
|
94
85
|
}
|
|
95
86
|
aggregate() {
|
|
96
87
|
const currentContributions = this.contributions.get(0);
|
|
@@ -98,8 +89,7 @@ export class MeanAggregator extends Aggregator {
|
|
|
98
89
|
throw new Error("aggregating without any contribution");
|
|
99
90
|
this.log(AggregationStep.AGGREGATE);
|
|
100
91
|
const result = aggregation.avg(currentContributions.values());
|
|
101
|
-
|
|
102
|
-
this.emit('aggregation', result);
|
|
92
|
+
return result;
|
|
103
93
|
}
|
|
104
94
|
makePayloads(weights) {
|
|
105
95
|
// Communicate our local weights to every other node, be it a peer or a server
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { Map, List } from "immutable";
|
|
2
|
-
import {
|
|
2
|
+
import { Aggregator } from "./aggregator.js";
|
|
3
3
|
import type { WeightsContainer, client } from "../index.js";
|
|
4
4
|
/**
|
|
5
5
|
* Aggregator implementing secure multi-party computation for decentralized learning.
|
|
@@ -8,11 +8,11 @@ import type { WeightsContainer, client } from "../index.js";
|
|
|
8
8
|
* - then, they sum their received shares and communicate the result.
|
|
9
9
|
* Finally, nodes are able to average the received partial sums to establish the aggregation result.
|
|
10
10
|
*/
|
|
11
|
-
export declare class SecureAggregator extends Aggregator
|
|
11
|
+
export declare class SecureAggregator extends Aggregator {
|
|
12
12
|
private readonly maxShareValue;
|
|
13
13
|
constructor(maxShareValue?: number);
|
|
14
|
-
aggregate():
|
|
15
|
-
|
|
14
|
+
aggregate(): WeightsContainer;
|
|
15
|
+
_add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound: number): void;
|
|
16
16
|
isFull(): boolean;
|
|
17
17
|
makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
|
|
18
18
|
/** Generate N additive shares that aggregate to the secret weights array, where N is the number of peers. */
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { Map, List, Range } from "immutable";
|
|
2
2
|
import * as tf from "@tensorflow/tfjs";
|
|
3
|
-
import { AggregationStep,
|
|
3
|
+
import { AggregationStep, Aggregator } from "./aggregator.js";
|
|
4
4
|
import { aggregation } from "../index.js";
|
|
5
5
|
/**
|
|
6
6
|
* Aggregator implementing secure multi-party computation for decentralized learning.
|
|
@@ -23,24 +23,20 @@ export class SecureAggregator extends Aggregator {
|
|
|
23
23
|
const currentContributions = this.contributions.get(0);
|
|
24
24
|
if (currentContributions === undefined)
|
|
25
25
|
throw new Error("aggregating without any contribution");
|
|
26
|
-
|
|
27
|
-
this.emit('aggregation', result);
|
|
28
|
-
break;
|
|
26
|
+
return aggregation.sum(currentContributions.values());
|
|
29
27
|
}
|
|
30
28
|
// Average the received partial sums
|
|
31
29
|
case 1: {
|
|
32
30
|
const currentContributions = this.contributions.get(1);
|
|
33
31
|
if (currentContributions === undefined)
|
|
34
32
|
throw new Error("aggregating without any contribution");
|
|
35
|
-
|
|
36
|
-
this.emit('aggregation', result);
|
|
37
|
-
break;
|
|
33
|
+
return aggregation.avg(currentContributions.values());
|
|
38
34
|
}
|
|
39
35
|
default:
|
|
40
36
|
throw new Error("communication round is out of bounds");
|
|
41
37
|
}
|
|
42
38
|
}
|
|
43
|
-
|
|
39
|
+
_add(nodeId, contribution, communicationRound) {
|
|
44
40
|
switch (communicationRound) {
|
|
45
41
|
case 0:
|
|
46
42
|
case 1:
|
|
@@ -48,15 +44,9 @@ export class SecureAggregator extends Aggregator {
|
|
|
48
44
|
default:
|
|
49
45
|
throw new Error("requires communication round to be 0 or 1");
|
|
50
46
|
}
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
this.log(this.contributions.hasIn([communicationRound, nodeId])
|
|
54
|
-
? AggregationStep.UPDATE
|
|
55
|
-
: AggregationStep.ADD, nodeId);
|
|
47
|
+
this.log(this.contributions.hasIn([communicationRound, nodeId]) ?
|
|
48
|
+
AggregationStep.UPDATE : AggregationStep.ADD, nodeId.slice(0, 4));
|
|
56
49
|
this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
|
|
57
|
-
if (this.isFull())
|
|
58
|
-
this.aggregate();
|
|
59
|
-
return true;
|
|
60
50
|
}
|
|
61
51
|
isFull() {
|
|
62
52
|
return ((this.contributions.get(this.communicationRound)?.size ?? 0) ===
|
|
@@ -66,7 +56,7 @@ export class SecureAggregator extends Aggregator {
|
|
|
66
56
|
switch (this.communicationRound) {
|
|
67
57
|
case 0: {
|
|
68
58
|
const shares = this.generateAllShares(weights);
|
|
69
|
-
//
|
|
59
|
+
// Arbitrarily assign our shares to the available nodes
|
|
70
60
|
return Map(List(this.nodes).zip(shares));
|
|
71
61
|
}
|
|
72
62
|
// Send our partial sum to every other nodes
|
package/dist/client/client.d.ts
CHANGED
|
@@ -13,21 +13,35 @@ export declare abstract class Client extends EventEmitter<{
|
|
|
13
13
|
readonly url: URL;
|
|
14
14
|
readonly task: Task;
|
|
15
15
|
readonly aggregator: Aggregator;
|
|
16
|
-
/**
|
|
17
|
-
* Own ID provided by the network's server.
|
|
18
|
-
*/
|
|
19
16
|
protected _ownId?: NodeID;
|
|
17
|
+
protected _server?: EventConnection;
|
|
18
|
+
protected aggregationResult?: Promise<WeightsContainer>;
|
|
20
19
|
/**
|
|
21
|
-
*
|
|
20
|
+
* When the server notifies clients to pause and wait until more
|
|
21
|
+
* participants join, we rely on this promise to wait
|
|
22
|
+
* until the server signals that the training can resume
|
|
22
23
|
*/
|
|
23
|
-
protected
|
|
24
|
+
protected promiseForMoreParticipants: Promise<void> | undefined;
|
|
24
25
|
/**
|
|
25
|
-
*
|
|
26
|
+
* When the server notifies the client that they can resume training
|
|
27
|
+
* after waiting for more participants, we want to be able to display what
|
|
28
|
+
* we were doing before waiting (training locally or updating our model).
|
|
29
|
+
* We use this attribute to store the status to rollback to when we stop waiting
|
|
26
30
|
*/
|
|
27
|
-
|
|
31
|
+
private previousStatus;
|
|
28
32
|
constructor(url: URL, // The network server's URL to connect to
|
|
29
33
|
task: Task, // The client's corresponding task
|
|
30
34
|
aggregator: Aggregator);
|
|
35
|
+
/**
|
|
36
|
+
* Communication callback called at the beginning of every training round.
|
|
37
|
+
*/
|
|
38
|
+
abstract onRoundBeginCommunication(): Promise<void>;
|
|
39
|
+
/**
|
|
40
|
+
* Communication callback called the end of every training round.
|
|
41
|
+
* @param weights The local weight update resulting for the current local training round
|
|
42
|
+
* @returns aggregated weights or the local weights upon error
|
|
43
|
+
*/
|
|
44
|
+
abstract onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
|
|
31
45
|
/**
|
|
32
46
|
* Handles the connection process from the client to any sort of network server.
|
|
33
47
|
* This method is overriden by the federated and decentralized clients
|
|
@@ -39,21 +53,61 @@ export declare abstract class Client extends EventEmitter<{
|
|
|
39
53
|
*/
|
|
40
54
|
disconnect(): Promise<void>;
|
|
41
55
|
/**
|
|
42
|
-
*
|
|
43
|
-
*
|
|
56
|
+
* Emits the round status specified. It also stores the status emitted such that
|
|
57
|
+
* if the server tells the client to wait for more participants, it can display
|
|
58
|
+
* the waiting status and once enough participants join, it can display the previous status again
|
|
44
59
|
*/
|
|
45
|
-
|
|
60
|
+
protected saveAndEmit(status: RoundStatus): void;
|
|
46
61
|
/**
|
|
47
|
-
*
|
|
62
|
+
* For both federated and decentralized clients, we listen to the server to tell
|
|
63
|
+
* us whether there are enough participants to train. If not, we pause until further notice.
|
|
64
|
+
* When a client connects to the server, the server answers with the session information (id,
|
|
65
|
+
* number of participants) and whether there are enough participants.
|
|
66
|
+
* When there are the server sends a new EnoughParticipant message to update the client.
|
|
67
|
+
*
|
|
68
|
+
* `setMessageInversionFlag` is used to address the following scenario:
|
|
69
|
+
* 1. Client 1 connect to the server
|
|
70
|
+
* 2. Server answers with message A containing "not enough participants"
|
|
71
|
+
* 3. Before A arrives a new client joins. There are enough participants now.
|
|
72
|
+
* 4. Server updates client 1 with message B saying "there are enough participants"
|
|
73
|
+
* 5. Due to network and message sizes message B can arrive before A.
|
|
74
|
+
* i.e. "there are enough participants" arrives before "not enough participants"
|
|
75
|
+
* ending up with client 1 thinking it needs to wait for more participants.
|
|
76
|
+
*
|
|
77
|
+
* To keep track of this message inversion, `setMessageInversionFlag`
|
|
78
|
+
* tells us whether a message inversion occurred (by setting a boolean to true)
|
|
79
|
+
*
|
|
80
|
+
* @param setMessageInversionFlag function flagging whether a message inversion occurred
|
|
81
|
+
* between a NewNodeInfo message and an EnoughParticipant message.
|
|
48
82
|
*/
|
|
49
|
-
|
|
83
|
+
protected setupServerCallbacks(setMessageInversionFlag: () => void): void;
|
|
50
84
|
/**
|
|
51
|
-
*
|
|
52
|
-
*
|
|
53
|
-
*
|
|
85
|
+
* Method called when the server notifies the client that there aren't enough
|
|
86
|
+
* participants (anymore) to start/continue training
|
|
87
|
+
* The method creates a promise that will resolve once the server notifies
|
|
88
|
+
* the client that the training can resume via a subsequent EnoughParticipants message
|
|
89
|
+
* @returns a promise which resolves when enough participants joined the session
|
|
54
90
|
*/
|
|
55
|
-
|
|
56
|
-
|
|
91
|
+
protected createPromiseForMoreParticipants(): Promise<void>;
|
|
92
|
+
protected waitForParticipantsIfNeeded(): Promise<void>;
|
|
93
|
+
/**
|
|
94
|
+
* Fetches the latest model available on the network's server, for the adequate task.
|
|
95
|
+
* @returns The latest model
|
|
96
|
+
*/
|
|
97
|
+
getLatestModel(): Promise<Model>;
|
|
98
|
+
/**
|
|
99
|
+
* Number of contributors to a collaborative session
|
|
100
|
+
* If decentralized, it should be the number of peers
|
|
101
|
+
* If federated, it should the number of participants excluding the server
|
|
102
|
+
* If local it should be 1
|
|
103
|
+
*/
|
|
104
|
+
abstract getNbOfParticipants(): number;
|
|
57
105
|
get ownId(): NodeID;
|
|
58
106
|
get server(): EventConnection;
|
|
107
|
+
/**
|
|
108
|
+
* Whether the client should wait until more
|
|
109
|
+
* participants join the session, i.e. a promise has been created
|
|
110
|
+
*/
|
|
111
|
+
get waitingForMoreParticipants(): boolean;
|
|
59
112
|
}
|
|
113
|
+
export declare function shortenId(id: string): string;
|