@epfml/discojs 2.2.1 → 3.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +9 -48
- package/dist/aggregator/base.js +8 -69
- package/dist/aggregator/get.d.ts +23 -11
- package/dist/aggregator/get.js +40 -23
- package/dist/aggregator/index.d.ts +1 -1
- package/dist/aggregator/index.js +1 -1
- package/dist/aggregator/mean.d.ts +25 -6
- package/dist/aggregator/mean.js +62 -17
- package/dist/aggregator/secure.d.ts +2 -2
- package/dist/aggregator/secure.js +4 -7
- package/dist/client/base.d.ts +3 -3
- package/dist/client/base.js +6 -8
- package/dist/client/decentralized/base.d.ts +27 -10
- package/dist/client/decentralized/base.js +123 -86
- package/dist/client/decentralized/peer.js +7 -12
- package/dist/client/decentralized/peer_pool.js +6 -2
- package/dist/client/event_connection.d.ts +1 -1
- package/dist/client/event_connection.js +3 -3
- package/dist/client/federated/base.d.ts +5 -21
- package/dist/client/federated/base.js +38 -61
- package/dist/client/federated/messages.d.ts +2 -10
- package/dist/client/federated/messages.js +0 -1
- package/dist/client/index.d.ts +1 -1
- package/dist/client/index.js +1 -1
- package/dist/client/local.d.ts +3 -1
- package/dist/client/local.js +4 -1
- package/dist/client/messages.d.ts +1 -2
- package/dist/client/messages.js +8 -3
- package/dist/client/utils.d.ts +4 -2
- package/dist/client/utils.js +18 -3
- package/dist/dataset/data/data.d.ts +1 -1
- package/dist/dataset/data/data.js +13 -2
- package/dist/dataset/data/preprocessing/image_preprocessing.js +6 -4
- package/dist/default_tasks/cifar10.js +1 -2
- package/dist/default_tasks/lus_covid.js +0 -5
- package/dist/default_tasks/mnist.js +15 -14
- package/dist/default_tasks/simple_face.js +0 -2
- package/dist/default_tasks/titanic.js +2 -4
- package/dist/default_tasks/wikitext.js +7 -1
- package/dist/index.d.ts +0 -1
- package/dist/index.js +0 -1
- package/dist/models/gpt/config.js +1 -1
- package/dist/privacy.d.ts +8 -10
- package/dist/privacy.js +25 -40
- package/dist/task/task_handler.js +10 -2
- package/dist/task/training_information.d.ts +7 -4
- package/dist/task/training_information.js +25 -6
- package/dist/training/disco.d.ts +30 -28
- package/dist/training/disco.js +75 -73
- package/dist/training/index.d.ts +1 -1
- package/dist/training/index.js +1 -0
- package/dist/training/trainer.d.ts +16 -0
- package/dist/training/trainer.js +72 -0
- package/dist/types.d.ts +0 -2
- package/dist/weights/weights_container.d.ts +0 -5
- package/dist/weights/weights_container.js +0 -7
- package/package.json +1 -1
- package/dist/async_informant.d.ts +0 -15
- package/dist/async_informant.js +0 -42
- package/dist/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/training/trainer/distributed_trainer.js +0 -41
- package/dist/training/trainer/local_trainer.d.ts +0 -12
- package/dist/training/trainer/local_trainer.js +0 -24
- package/dist/training/trainer/trainer.d.ts +0 -32
- package/dist/training/trainer/trainer.js +0 -61
- package/dist/training/trainer/trainer_builder.d.ts +0 -23
- package/dist/training/trainer/trainer_builder.js +0 -47
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import { Map, Set } from 'immutable';
|
|
2
|
-
import type { client
|
|
2
|
+
import type { client } from '../index.js';
|
|
3
|
+
import { EventEmitter } from '../utils/event_emitter.js';
|
|
3
4
|
export declare enum AggregationStep {
|
|
4
5
|
ADD = 0,
|
|
5
6
|
UPDATE = 1,
|
|
@@ -8,12 +9,13 @@ export declare enum AggregationStep {
|
|
|
8
9
|
/**
|
|
9
10
|
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
|
|
10
11
|
* a result based off their aggregation, whenever some defined condition is met.
|
|
12
|
+
*
|
|
13
|
+
* Emits an event whenever an aggregation step is performed.
|
|
14
|
+
* Users wait for this event to fetch the aggregation result.
|
|
11
15
|
*/
|
|
12
|
-
export declare abstract class Base<T> {
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
*/
|
|
16
|
-
protected _model?: Model | undefined;
|
|
16
|
+
export declare abstract class Base<T> extends EventEmitter<{
|
|
17
|
+
'aggregation': T;
|
|
18
|
+
}> {
|
|
17
19
|
/**
|
|
18
20
|
* The round cut-off for contributions.
|
|
19
21
|
*/
|
|
@@ -33,19 +35,6 @@ export declare abstract class Base<T> {
|
|
|
33
35
|
* of all active nodes, depending on the aggregation scheme.
|
|
34
36
|
*/
|
|
35
37
|
protected contributions: Map<number, Map<client.NodeID, T>>;
|
|
36
|
-
/**
|
|
37
|
-
* Emits the aggregation event whenever an aggregation step is performed.
|
|
38
|
-
* Triggers the resolve of the result promise and the preparation for the
|
|
39
|
-
* next aggregation round.
|
|
40
|
-
*/
|
|
41
|
-
private readonly eventEmitter;
|
|
42
|
-
protected informant?: AsyncInformant<T>;
|
|
43
|
-
/**
|
|
44
|
-
* The result promise which, on resolve, will contain the current aggregation result.
|
|
45
|
-
* This promise should be fetched by any object making use of an aggregator, in order
|
|
46
|
-
* to await upon aggregation.
|
|
47
|
-
*/
|
|
48
|
-
protected result: Promise<T>;
|
|
49
38
|
/**
|
|
50
39
|
* The current aggregation round, used for assessing whether a node contribution is recent enough
|
|
51
40
|
* or not.
|
|
@@ -59,10 +48,6 @@ export declare abstract class Base<T> {
|
|
|
59
48
|
*/
|
|
60
49
|
protected _communicationRound: number;
|
|
61
50
|
constructor(
|
|
62
|
-
/**
|
|
63
|
-
* The Model whose weights are updated on aggregation.
|
|
64
|
-
*/
|
|
65
|
-
_model?: Model | undefined,
|
|
66
51
|
/**
|
|
67
52
|
* The round cut-off for contributions.
|
|
68
53
|
*/
|
|
@@ -85,7 +70,6 @@ export declare abstract class Base<T> {
|
|
|
85
70
|
* Must store the aggregation's result in the aggregator's result promise.
|
|
86
71
|
*/
|
|
87
72
|
abstract aggregate(): void;
|
|
88
|
-
registerObserver(informant: AsyncInformant<T>): void;
|
|
89
73
|
/**
|
|
90
74
|
* Returns whether the given round is recent enough, dependent on the
|
|
91
75
|
* aggregator's round cutoff.
|
|
@@ -99,16 +83,12 @@ export declare abstract class Base<T> {
|
|
|
99
83
|
* @param from The node which triggered the logging message
|
|
100
84
|
*/
|
|
101
85
|
log(step: AggregationStep, from?: client.NodeID): void;
|
|
102
|
-
/**
|
|
103
|
-
* Sets the aggregator's TF.js model.
|
|
104
|
-
* @param model The new TF.js model
|
|
105
|
-
*/
|
|
106
|
-
setModel(model: Model): void;
|
|
107
86
|
/**
|
|
108
87
|
* Adds a node's id to the set of active nodes. A node represents an active neighbor
|
|
109
88
|
* peer/client within the network, whom we are communicating with during this aggregation
|
|
110
89
|
* round.
|
|
111
90
|
* @param nodeId The node to be added
|
|
91
|
+
* @returns True is the node wasn't already in the list of nodes, False if already included
|
|
112
92
|
*/
|
|
113
93
|
registerNode(nodeId: client.NodeID): boolean;
|
|
114
94
|
/**
|
|
@@ -129,27 +109,12 @@ export declare abstract class Base<T> {
|
|
|
129
109
|
* @param round The new round
|
|
130
110
|
*/
|
|
131
111
|
setRound(round: number): void;
|
|
132
|
-
/**
|
|
133
|
-
* Emits the event containing the aggregation result, which allows the result
|
|
134
|
-
* promise to resolve and for the next aggregation round to take place.
|
|
135
|
-
* @param aggregated The aggregation result
|
|
136
|
-
*/
|
|
137
|
-
protected emit(aggregated: T): void;
|
|
138
112
|
/**
|
|
139
113
|
* Updates the aggregator's state to proceed to the next communication round.
|
|
140
114
|
* If all communication rounds were performed, proceeds to the next aggregation round
|
|
141
115
|
* and empties the collection of stored contributions.
|
|
142
116
|
*/
|
|
143
117
|
nextRound(): void;
|
|
144
|
-
private makeResult;
|
|
145
|
-
/**
|
|
146
|
-
* Aggregation steps are performed asynchronously, yet can be awaited upon when required.
|
|
147
|
-
* This function gives access to the current aggregation result's promise, which will
|
|
148
|
-
* eventually resolve and contain the result of the very next aggregation step, at the
|
|
149
|
-
* time of the function call.
|
|
150
|
-
* @returns The promise containing the aggregation result
|
|
151
|
-
*/
|
|
152
|
-
receiveResult(): Promise<T>;
|
|
153
118
|
/**
|
|
154
119
|
* Constructs the payloads sent to other nodes as contribution.
|
|
155
120
|
* @param base Object from which the payload is computed
|
|
@@ -169,10 +134,6 @@ export declare abstract class Base<T> {
|
|
|
169
134
|
* the amount of all active nodes times the number of communication rounds.
|
|
170
135
|
*/
|
|
171
136
|
get size(): number;
|
|
172
|
-
/**
|
|
173
|
-
* The aggregator's current model.
|
|
174
|
-
*/
|
|
175
|
-
get model(): Model | undefined;
|
|
176
137
|
/**
|
|
177
138
|
* The current communication round.
|
|
178
139
|
*/
|
package/dist/aggregator/base.js
CHANGED
|
@@ -9,9 +9,11 @@ export var AggregationStep;
|
|
|
9
9
|
/**
|
|
10
10
|
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
|
|
11
11
|
* a result based off their aggregation, whenever some defined condition is met.
|
|
12
|
+
*
|
|
13
|
+
* Emits an event whenever an aggregation step is performed.
|
|
14
|
+
* Users wait for this event to fetch the aggregation result.
|
|
12
15
|
*/
|
|
13
|
-
export class Base {
|
|
14
|
-
_model;
|
|
16
|
+
export class Base extends EventEmitter {
|
|
15
17
|
roundCutoff;
|
|
16
18
|
communicationRounds;
|
|
17
19
|
/**
|
|
@@ -26,19 +28,6 @@ export class Base {
|
|
|
26
28
|
*/
|
|
27
29
|
// communication round -> NodeID -> T
|
|
28
30
|
contributions;
|
|
29
|
-
/**
|
|
30
|
-
* Emits the aggregation event whenever an aggregation step is performed.
|
|
31
|
-
* Triggers the resolve of the result promise and the preparation for the
|
|
32
|
-
* next aggregation round.
|
|
33
|
-
*/
|
|
34
|
-
eventEmitter = new EventEmitter();
|
|
35
|
-
informant;
|
|
36
|
-
/**
|
|
37
|
-
* The result promise which, on resolve, will contain the current aggregation result.
|
|
38
|
-
* This promise should be fetched by any object making use of an aggregator, in order
|
|
39
|
-
* to await upon aggregation.
|
|
40
|
-
*/
|
|
41
|
-
result;
|
|
42
31
|
/**
|
|
43
32
|
* The current aggregation round, used for assessing whether a node contribution is recent enough
|
|
44
33
|
* or not.
|
|
@@ -52,10 +41,6 @@ export class Base {
|
|
|
52
41
|
*/
|
|
53
42
|
_communicationRound = 0;
|
|
54
43
|
constructor(
|
|
55
|
-
/**
|
|
56
|
-
* The Model whose weights are updated on aggregation.
|
|
57
|
-
*/
|
|
58
|
-
_model,
|
|
59
44
|
/**
|
|
60
45
|
* The round cut-off for contributions.
|
|
61
46
|
*/
|
|
@@ -64,21 +49,14 @@ export class Base {
|
|
|
64
49
|
* The number of communication rounds occurring during any given aggregation round.
|
|
65
50
|
*/
|
|
66
51
|
communicationRounds = 1) {
|
|
67
|
-
|
|
52
|
+
super();
|
|
68
53
|
this.roundCutoff = roundCutoff;
|
|
69
54
|
this.communicationRounds = communicationRounds;
|
|
70
55
|
this.contributions = Map();
|
|
71
56
|
this._nodes = Set();
|
|
72
|
-
// Make the initial result promise
|
|
73
|
-
this.result = this.makeResult();
|
|
74
57
|
// On every aggregation, update the object's state to match the current aggregation
|
|
75
58
|
// and communication rounds.
|
|
76
|
-
this.
|
|
77
|
-
this.nextRound();
|
|
78
|
-
});
|
|
79
|
-
}
|
|
80
|
-
registerObserver(informant) {
|
|
81
|
-
this.informant = informant;
|
|
59
|
+
this.on('aggregation', () => this.nextRound());
|
|
82
60
|
}
|
|
83
61
|
/**
|
|
84
62
|
* Returns whether the given round is recent enough, dependent on the
|
|
@@ -97,7 +75,7 @@ export class Base {
|
|
|
97
75
|
log(step, from) {
|
|
98
76
|
switch (step) {
|
|
99
77
|
case AggregationStep.ADD:
|
|
100
|
-
console.log(
|
|
78
|
+
console.log(`Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
|
|
101
79
|
break;
|
|
102
80
|
case AggregationStep.UPDATE:
|
|
103
81
|
if (from === undefined) {
|
|
@@ -115,18 +93,12 @@ export class Base {
|
|
|
115
93
|
}
|
|
116
94
|
}
|
|
117
95
|
}
|
|
118
|
-
/**
|
|
119
|
-
* Sets the aggregator's TF.js model.
|
|
120
|
-
* @param model The new TF.js model
|
|
121
|
-
*/
|
|
122
|
-
setModel(model) {
|
|
123
|
-
this._model = model;
|
|
124
|
-
}
|
|
125
96
|
/**
|
|
126
97
|
* Adds a node's id to the set of active nodes. A node represents an active neighbor
|
|
127
98
|
* peer/client within the network, whom we are communicating with during this aggregation
|
|
128
99
|
* round.
|
|
129
100
|
* @param nodeId The node to be added
|
|
101
|
+
* @returns True is the node wasn't already in the list of nodes, False if already included
|
|
130
102
|
*/
|
|
131
103
|
registerNode(nodeId) {
|
|
132
104
|
if (!this.nodes.has(nodeId)) {
|
|
@@ -161,14 +133,6 @@ export class Base {
|
|
|
161
133
|
this._round = round;
|
|
162
134
|
}
|
|
163
135
|
}
|
|
164
|
-
/**
|
|
165
|
-
* Emits the event containing the aggregation result, which allows the result
|
|
166
|
-
* promise to resolve and for the next aggregation round to take place.
|
|
167
|
-
* @param aggregated The aggregation result
|
|
168
|
-
*/
|
|
169
|
-
emit(aggregated) {
|
|
170
|
-
this.eventEmitter.emit('aggregation', aggregated);
|
|
171
|
-
}
|
|
172
136
|
/**
|
|
173
137
|
* Updates the aggregator's state to proceed to the next communication round.
|
|
174
138
|
* If all communication rounds were performed, proceeds to the next aggregation round
|
|
@@ -180,25 +144,6 @@ export class Base {
|
|
|
180
144
|
this._round++;
|
|
181
145
|
this.contributions = Map();
|
|
182
146
|
}
|
|
183
|
-
this.result = this.makeResult();
|
|
184
|
-
this.informant?.update();
|
|
185
|
-
}
|
|
186
|
-
async makeResult() {
|
|
187
|
-
return await new Promise((resolve) => {
|
|
188
|
-
this.eventEmitter.once('aggregation', (w) => {
|
|
189
|
-
resolve(w);
|
|
190
|
-
});
|
|
191
|
-
});
|
|
192
|
-
}
|
|
193
|
-
/**
|
|
194
|
-
* Aggregation steps are performed asynchronously, yet can be awaited upon when required.
|
|
195
|
-
* This function gives access to the current aggregation result's promise, which will
|
|
196
|
-
* eventually resolve and contain the result of the very next aggregation step, at the
|
|
197
|
-
* time of the function call.
|
|
198
|
-
* @returns The promise containing the aggregation result
|
|
199
|
-
*/
|
|
200
|
-
async receiveResult() {
|
|
201
|
-
return await this.result;
|
|
202
147
|
}
|
|
203
148
|
/**
|
|
204
149
|
* The set of node ids, representing our neighbors within the network.
|
|
@@ -222,12 +167,6 @@ export class Base {
|
|
|
222
167
|
.map((m) => m.size)
|
|
223
168
|
.reduce((totalSize, size) => totalSize + size) ?? 0;
|
|
224
169
|
}
|
|
225
|
-
/**
|
|
226
|
-
* The aggregator's current model.
|
|
227
|
-
*/
|
|
228
|
-
get model() {
|
|
229
|
-
return this._model;
|
|
230
|
-
}
|
|
231
170
|
/**
|
|
232
171
|
* The current communication round.
|
|
233
172
|
*/
|
package/dist/aggregator/get.d.ts
CHANGED
|
@@ -1,16 +1,28 @@
|
|
|
1
1
|
import type { Task } from '../index.js';
|
|
2
2
|
import { aggregator } from '../index.js';
|
|
3
|
+
type AggregatorOptions = Partial<{
|
|
4
|
+
scheme: Task['trainingInformation']['scheme'];
|
|
5
|
+
roundCutOff: number;
|
|
6
|
+
threshold: number;
|
|
7
|
+
thresholdType: 'relative' | 'absolute';
|
|
8
|
+
}>;
|
|
3
9
|
/**
|
|
4
|
-
*
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
*
|
|
13
|
-
*
|
|
10
|
+
* Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
|
|
11
|
+
* Here is the ordered list of parameters used to define the aggregator and its default behavior:
|
|
12
|
+
* task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
|
|
13
|
+
*
|
|
14
|
+
* If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
|
|
15
|
+
* Otherwise, we default to a MeanAggregator for both training schemes.
|
|
16
|
+
*
|
|
17
|
+
* For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
|
|
18
|
+
* Unless specified otherwise, for federated learning or local training the aggregator default to waiting
|
|
19
|
+
* for a single contribution to trigger a model update.
|
|
20
|
+
* (the server's model update for federated learning or our own contribution if training locally)
|
|
21
|
+
* For decentralized learning the aggregator defaults to waiting for every nodes' contribution to trigger a model update.
|
|
22
|
+
*
|
|
23
|
+
* @param task The task object associated with the current training session
|
|
24
|
+
* @param options Options passed down to the aggregator's constructor
|
|
14
25
|
* @returns The aggregator
|
|
15
26
|
*/
|
|
16
|
-
export declare function getAggregator(task: Task): aggregator.Aggregator;
|
|
27
|
+
export declare function getAggregator(task: Task, options?: AggregatorOptions): aggregator.Aggregator;
|
|
28
|
+
export {};
|
package/dist/aggregator/get.js
CHANGED
|
@@ -1,31 +1,48 @@
|
|
|
1
1
|
import { aggregator } from '../index.js';
|
|
2
2
|
/**
|
|
3
|
-
*
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
*
|
|
13
|
-
*
|
|
3
|
+
* Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
|
|
4
|
+
* Here is the ordered list of parameters used to define the aggregator and its default behavior:
|
|
5
|
+
* task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
|
|
6
|
+
*
|
|
7
|
+
* If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
|
|
8
|
+
* Otherwise, we default to a MeanAggregator for both training schemes.
|
|
9
|
+
*
|
|
10
|
+
* For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
|
|
11
|
+
* Unless specified otherwise, for federated learning or local training the aggregator default to waiting
|
|
12
|
+
* for a single contribution to trigger a model update.
|
|
13
|
+
* (the server's model update for federated learning or our own contribution if training locally)
|
|
14
|
+
* For decentralized learning the aggregator defaults to waiting for every nodes' contribution to trigger a model update.
|
|
15
|
+
*
|
|
16
|
+
* @param task The task object associated with the current training session
|
|
17
|
+
* @param options Options passed down to the aggregator's constructor
|
|
14
18
|
* @returns The aggregator
|
|
15
19
|
*/
|
|
16
|
-
export function getAggregator(task) {
|
|
17
|
-
const
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
20
|
+
export function getAggregator(task, options = {}) {
|
|
21
|
+
const aggregatorType = task.trainingInformation.aggregator ?? 'mean';
|
|
22
|
+
const scheme = options.scheme ?? task.trainingInformation.scheme;
|
|
23
|
+
switch (aggregatorType) {
|
|
24
|
+
case 'mean':
|
|
25
|
+
if (scheme === 'decentralized') {
|
|
26
|
+
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
|
|
27
|
+
options = {
|
|
28
|
+
roundCutOff: undefined, threshold: 1, thresholdType: 'relative',
|
|
29
|
+
...options
|
|
30
|
+
};
|
|
31
|
+
}
|
|
32
|
+
else {
|
|
33
|
+
// If scheme == 'federated' then we only expect the server's contribution at each round
|
|
34
|
+
// so we set the aggregation threshold to 1 contribution
|
|
35
|
+
// If scheme == 'local' then we only expect our own contribution
|
|
36
|
+
options = {
|
|
37
|
+
roundCutOff: undefined, threshold: 1, thresholdType: 'absolute',
|
|
38
|
+
...options
|
|
39
|
+
};
|
|
40
|
+
}
|
|
41
|
+
return new aggregator.MeanAggregator(options.roundCutOff, options.threshold, options.thresholdType);
|
|
42
|
+
case 'secure':
|
|
43
|
+
if (scheme !== 'decentralized') {
|
|
25
44
|
throw new Error('secure aggregation is currently supported for decentralized only');
|
|
26
45
|
}
|
|
27
|
-
return new aggregator.SecureAggregator();
|
|
28
|
-
default:
|
|
29
|
-
return new aggregator.MeanAggregator();
|
|
46
|
+
return new aggregator.SecureAggregator(task.trainingInformation.maxShareValue);
|
|
30
47
|
}
|
|
31
48
|
}
|
|
@@ -3,5 +3,5 @@ import type { Base } from './base.js';
|
|
|
3
3
|
export { Base as AggregatorBase, AggregationStep } from './base.js';
|
|
4
4
|
export { MeanAggregator } from './mean.js';
|
|
5
5
|
export { SecureAggregator } from './secure.js';
|
|
6
|
-
export { getAggregator
|
|
6
|
+
export { getAggregator } from './get.js';
|
|
7
7
|
export type Aggregator = Base<WeightsContainer>;
|
package/dist/aggregator/index.js
CHANGED
|
@@ -1,18 +1,37 @@
|
|
|
1
1
|
import type { Map } from "immutable";
|
|
2
2
|
import { Base as Aggregator } from "./base.js";
|
|
3
|
-
import type {
|
|
4
|
-
|
|
3
|
+
import type { WeightsContainer, client } from "../index.js";
|
|
4
|
+
type ThresholdType = 'relative' | 'absolute';
|
|
5
|
+
/**
|
|
6
|
+
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
|
|
7
|
+
*
|
|
8
|
+
*/
|
|
5
9
|
export declare class MeanAggregator extends Aggregator<WeightsContainer> {
|
|
6
10
|
#private;
|
|
7
11
|
/**
|
|
8
|
-
*
|
|
9
|
-
*
|
|
10
|
-
*
|
|
12
|
+
* Create a mean aggregator that averages all weight updates received when a specified threshold is met.
|
|
13
|
+
* By default, initializes an aggregator that waits for 100% of the nodes' contributions and that
|
|
14
|
+
* only accepts contributions from the current round (drops contributions from previous rounds).
|
|
15
|
+
*
|
|
16
|
+
* @param threshold - how many contributions trigger an aggregation step.
|
|
17
|
+
* It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions.
|
|
18
|
+
* Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`.
|
|
19
|
+
* It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions
|
|
20
|
+
* Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update),
|
|
21
|
+
* set `threshold = 1` and `thresholdType = 'absolute'`
|
|
22
|
+
* @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1,
|
|
23
|
+
* If `threshold != 1` then the specified thresholdType is ignored and overwritten
|
|
24
|
+
* If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution
|
|
25
|
+
* if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions,
|
|
26
|
+
* @param roundCutoff - from how many past rounds do we still accept contributions.
|
|
27
|
+
* If 0 then only accept contributions from the current round,
|
|
28
|
+
* if 1 then the current round and the previous one, etc.
|
|
11
29
|
*/
|
|
12
|
-
constructor(
|
|
30
|
+
constructor(roundCutoff?: number, threshold?: number, thresholdType?: ThresholdType);
|
|
13
31
|
/** Checks whether the contributions buffer is full. */
|
|
14
32
|
isFull(): boolean;
|
|
15
33
|
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
|
|
16
34
|
aggregate(): void;
|
|
17
35
|
makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
|
|
18
36
|
}
|
|
37
|
+
export {};
|
package/dist/aggregator/mean.js
CHANGED
|
@@ -1,39 +1,86 @@
|
|
|
1
1
|
import { AggregationStep, Base as Aggregator } from "./base.js";
|
|
2
2
|
import { aggregation } from "../index.js";
|
|
3
|
-
/**
|
|
3
|
+
/**
|
|
4
|
+
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
|
|
5
|
+
*
|
|
6
|
+
*/
|
|
4
7
|
export class MeanAggregator extends Aggregator {
|
|
5
8
|
#threshold;
|
|
9
|
+
#thresholdType;
|
|
6
10
|
/**
|
|
7
|
-
*
|
|
8
|
-
*
|
|
9
|
-
*
|
|
11
|
+
* Create a mean aggregator that averages all weight updates received when a specified threshold is met.
|
|
12
|
+
* By default, initializes an aggregator that waits for 100% of the nodes' contributions and that
|
|
13
|
+
* only accepts contributions from the current round (drops contributions from previous rounds).
|
|
14
|
+
*
|
|
15
|
+
* @param threshold - how many contributions trigger an aggregation step.
|
|
16
|
+
* It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions.
|
|
17
|
+
* Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`.
|
|
18
|
+
* It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions
|
|
19
|
+
* Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update),
|
|
20
|
+
* set `threshold = 1` and `thresholdType = 'absolute'`
|
|
21
|
+
* @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1,
|
|
22
|
+
* If `threshold != 1` then the specified thresholdType is ignored and overwritten
|
|
23
|
+
* If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution
|
|
24
|
+
* if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions,
|
|
25
|
+
* @param roundCutoff - from how many past rounds do we still accept contributions.
|
|
26
|
+
* If 0 then only accept contributions from the current round,
|
|
27
|
+
* if 1 then the current round and the previous one, etc.
|
|
10
28
|
*/
|
|
11
|
-
|
|
12
|
-
constructor(model, roundCutoff = 0, threshold = 1) {
|
|
29
|
+
constructor(roundCutoff = 0, threshold = 1, thresholdType) {
|
|
13
30
|
if (threshold <= 0)
|
|
14
|
-
throw new Error("threshold must be
|
|
15
|
-
if (threshold > 1 && !Number.isInteger(threshold))
|
|
16
|
-
throw new Error("absolute thresholds must be
|
|
17
|
-
super(
|
|
31
|
+
throw new Error("threshold must be strictly positive");
|
|
32
|
+
if (threshold > 1 && (!Number.isInteger(threshold)))
|
|
33
|
+
throw new Error("absolute thresholds must be integral");
|
|
34
|
+
super(roundCutoff, 1);
|
|
18
35
|
this.#threshold = threshold;
|
|
36
|
+
if (threshold < 1) {
|
|
37
|
+
// Throw exception if threshold and thresholdType are conflicting
|
|
38
|
+
if (thresholdType === 'absolute') {
|
|
39
|
+
throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`);
|
|
40
|
+
}
|
|
41
|
+
this.#thresholdType = 'relative';
|
|
42
|
+
}
|
|
43
|
+
else if (threshold > 1) {
|
|
44
|
+
// Throw exception if threshold and thresholdType are conflicting
|
|
45
|
+
if (thresholdType === 'relative') {
|
|
46
|
+
throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`);
|
|
47
|
+
}
|
|
48
|
+
this.#thresholdType = 'absolute';
|
|
49
|
+
}
|
|
50
|
+
// remaining case: threshold == 1
|
|
51
|
+
else {
|
|
52
|
+
// Print a warning regarding the default behavior when thresholdType is not specified
|
|
53
|
+
if (thresholdType === undefined) {
|
|
54
|
+
console.warn("[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " +
|
|
55
|
+
"To instead wait for a single contribution, set thresholdType = 'absolute'");
|
|
56
|
+
this.#thresholdType = 'relative';
|
|
57
|
+
}
|
|
58
|
+
else {
|
|
59
|
+
this.#thresholdType = thresholdType;
|
|
60
|
+
}
|
|
61
|
+
}
|
|
19
62
|
}
|
|
20
63
|
/** Checks whether the contributions buffer is full. */
|
|
21
64
|
isFull() {
|
|
22
|
-
const
|
|
65
|
+
const thresholdValue = this.#thresholdType == 'relative'
|
|
23
66
|
? this.#threshold * this.nodes.size
|
|
24
67
|
: this.#threshold;
|
|
25
|
-
return (this.contributions.get(0)?.size ?? 0) >=
|
|
68
|
+
return (this.contributions.get(0)?.size ?? 0) >= thresholdValue;
|
|
26
69
|
}
|
|
27
70
|
add(nodeId, contribution, round, currentContributions = 0) {
|
|
28
71
|
if (currentContributions !== 0)
|
|
29
72
|
throw new Error("only a single communication round");
|
|
30
|
-
if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
|
|
73
|
+
if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) {
|
|
74
|
+
if (!this.nodes.has(nodeId))
|
|
75
|
+
console.warn("Contribution rejected because node id is not registered");
|
|
76
|
+
if (!this.isWithinRoundCutoff(round))
|
|
77
|
+
console.warn(`Contribution rejected because round ${round} is not within round cutoff`);
|
|
31
78
|
return false;
|
|
79
|
+
}
|
|
32
80
|
this.log(this.contributions.hasIn([0, nodeId])
|
|
33
81
|
? AggregationStep.UPDATE
|
|
34
82
|
: AggregationStep.ADD, nodeId);
|
|
35
83
|
this.contributions = this.contributions.setIn([0, nodeId], contribution);
|
|
36
|
-
this.informant?.update();
|
|
37
84
|
if (this.isFull())
|
|
38
85
|
this.aggregate();
|
|
39
86
|
return true;
|
|
@@ -44,9 +91,7 @@ export class MeanAggregator extends Aggregator {
|
|
|
44
91
|
throw new Error("aggregating without any contribution");
|
|
45
92
|
this.log(AggregationStep.AGGREGATE);
|
|
46
93
|
const result = aggregation.avg(currentContributions.values());
|
|
47
|
-
|
|
48
|
-
this.model.weights = result;
|
|
49
|
-
this.emit(result);
|
|
94
|
+
this.emit('aggregation', result);
|
|
50
95
|
}
|
|
51
96
|
makePayloads(weights) {
|
|
52
97
|
// Communicate our local weights to every other node, be it a peer or a server
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { Map, List } from "immutable";
|
|
2
2
|
import { Base as Aggregator } from "./base.js";
|
|
3
|
-
import type {
|
|
3
|
+
import type { WeightsContainer, client } from "../index.js";
|
|
4
4
|
/**
|
|
5
5
|
* Aggregator implementing secure multi-party computation for decentralized learning.
|
|
6
6
|
* An aggregation consists of two communication rounds:
|
|
@@ -10,7 +10,7 @@ import type { Model, WeightsContainer, client } from "../index.js";
|
|
|
10
10
|
*/
|
|
11
11
|
export declare class SecureAggregator extends Aggregator<WeightsContainer> {
|
|
12
12
|
private readonly maxShareValue;
|
|
13
|
-
constructor(
|
|
13
|
+
constructor(maxShareValue?: number);
|
|
14
14
|
aggregate(): void;
|
|
15
15
|
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound?: number): boolean;
|
|
16
16
|
isFull(): boolean;
|
|
@@ -11,8 +11,8 @@ import { aggregation } from "../index.js";
|
|
|
11
11
|
*/
|
|
12
12
|
export class SecureAggregator extends Aggregator {
|
|
13
13
|
maxShareValue;
|
|
14
|
-
constructor(
|
|
15
|
-
super(
|
|
14
|
+
constructor(maxShareValue = 100) {
|
|
15
|
+
super(0, 2);
|
|
16
16
|
this.maxShareValue = maxShareValue;
|
|
17
17
|
}
|
|
18
18
|
aggregate() {
|
|
@@ -24,7 +24,7 @@ export class SecureAggregator extends Aggregator {
|
|
|
24
24
|
if (currentContributions === undefined)
|
|
25
25
|
throw new Error("aggregating without any contribution");
|
|
26
26
|
const result = aggregation.sum(currentContributions.values());
|
|
27
|
-
this.emit(result);
|
|
27
|
+
this.emit('aggregation', result);
|
|
28
28
|
break;
|
|
29
29
|
}
|
|
30
30
|
// Average the received partial sums
|
|
@@ -33,9 +33,7 @@ export class SecureAggregator extends Aggregator {
|
|
|
33
33
|
if (currentContributions === undefined)
|
|
34
34
|
throw new Error("aggregating without any contribution");
|
|
35
35
|
const result = aggregation.avg(currentContributions.values());
|
|
36
|
-
|
|
37
|
-
this.model.weights = result;
|
|
38
|
-
this.emit(result);
|
|
36
|
+
this.emit('aggregation', result);
|
|
39
37
|
break;
|
|
40
38
|
}
|
|
41
39
|
default:
|
|
@@ -56,7 +54,6 @@ export class SecureAggregator extends Aggregator {
|
|
|
56
54
|
? AggregationStep.UPDATE
|
|
57
55
|
: AggregationStep.ADD, nodeId);
|
|
58
56
|
this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
|
|
59
|
-
this.informant?.update();
|
|
60
57
|
if (this.isFull())
|
|
61
58
|
this.aggregate();
|
|
62
59
|
return true;
|
package/dist/client/base.d.ts
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import type { Set } from 'immutable';
|
|
2
1
|
import type { Model, Task, WeightsContainer } from '../index.js';
|
|
3
2
|
import type { NodeID } from './types.js';
|
|
4
3
|
import type { EventConnection } from './event_connection.js';
|
|
@@ -68,9 +67,10 @@ export declare abstract class Base {
|
|
|
68
67
|
* Communication callback called the end of every training round.
|
|
69
68
|
* @param _weights The most recent local weight updates
|
|
70
69
|
* @param _round The current training round
|
|
70
|
+
* @returns aggregated weights or the local weights upon error
|
|
71
71
|
*/
|
|
72
|
-
onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<
|
|
73
|
-
get
|
|
72
|
+
abstract onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<WeightsContainer>;
|
|
73
|
+
get nbOfParticipants(): number;
|
|
74
74
|
get ownId(): NodeID;
|
|
75
75
|
get server(): EventConnection;
|
|
76
76
|
}
|
package/dist/client/base.js
CHANGED
|
@@ -64,14 +64,12 @@ export class Base {
|
|
|
64
64
|
* @param _round The current training round
|
|
65
65
|
*/
|
|
66
66
|
async onRoundBeginCommunication(_weights, _round) { }
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
get nodes() {
|
|
74
|
-
return this.aggregator.nodes;
|
|
67
|
+
// Number of contributors to a collaborative session
|
|
68
|
+
// If decentralized, it should be the number of peers
|
|
69
|
+
// If federated, it should the number of participants excluding the server
|
|
70
|
+
// If local it should be 1
|
|
71
|
+
get nbOfParticipants() {
|
|
72
|
+
return this.aggregator.nodes.size; // overriden by the federated client
|
|
75
73
|
}
|
|
76
74
|
get ownId() {
|
|
77
75
|
if (this._ownId === undefined) {
|