@epfml/discojs 2.1.2-p20240723143623.0 → 2.1.2-p20240723160120.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +8 -48
- package/dist/aggregator/base.js +6 -68
- package/dist/aggregator/get.d.ts +0 -2
- package/dist/aggregator/get.js +4 -4
- package/dist/aggregator/mean.d.ts +2 -2
- package/dist/aggregator/mean.js +3 -6
- package/dist/aggregator/secure.d.ts +2 -2
- package/dist/aggregator/secure.js +4 -7
- package/dist/client/base.d.ts +2 -1
- package/dist/client/base.js +0 -6
- package/dist/client/decentralized/base.d.ts +2 -2
- package/dist/client/decentralized/base.js +9 -8
- package/dist/client/federated/base.d.ts +1 -1
- package/dist/client/federated/base.js +2 -1
- package/dist/client/local.d.ts +3 -1
- package/dist/client/local.js +4 -1
- package/dist/default_tasks/cifar10.js +1 -2
- package/dist/default_tasks/mnist.js +0 -2
- package/dist/default_tasks/simple_face.js +0 -2
- package/dist/default_tasks/titanic.js +0 -2
- package/dist/index.d.ts +0 -1
- package/dist/index.js +0 -1
- package/dist/privacy.d.ts +8 -10
- package/dist/privacy.js +25 -40
- package/dist/task/training_information.d.ts +6 -2
- package/dist/task/training_information.js +17 -5
- package/dist/training/disco.d.ts +30 -28
- package/dist/training/disco.js +76 -61
- package/dist/training/index.d.ts +1 -1
- package/dist/training/index.js +1 -0
- package/dist/training/trainer.d.ts +16 -0
- package/dist/training/trainer.js +72 -0
- package/dist/weights/weights_container.d.ts +0 -5
- package/dist/weights/weights_container.js +0 -7
- package/package.json +1 -1
- package/dist/async_informant.d.ts +0 -15
- package/dist/async_informant.js +0 -42
- package/dist/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/training/trainer/distributed_trainer.js +0 -41
- package/dist/training/trainer/local_trainer.d.ts +0 -12
- package/dist/training/trainer/local_trainer.js +0 -24
- package/dist/training/trainer/trainer.d.ts +0 -32
- package/dist/training/trainer/trainer.js +0 -61
- package/dist/training/trainer/trainer_builder.d.ts +0 -23
- package/dist/training/trainer/trainer_builder.js +0 -47
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import { Map, Set } from 'immutable';
|
|
2
|
-
import type { client
|
|
2
|
+
import type { client } from '../index.js';
|
|
3
|
+
import { EventEmitter } from '../utils/event_emitter.js';
|
|
3
4
|
export declare enum AggregationStep {
|
|
4
5
|
ADD = 0,
|
|
5
6
|
UPDATE = 1,
|
|
@@ -8,12 +9,13 @@ export declare enum AggregationStep {
|
|
|
8
9
|
/**
|
|
9
10
|
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
|
|
10
11
|
* a result based off their aggregation, whenever some defined condition is met.
|
|
12
|
+
*
|
|
13
|
+
* Emits an event whenever an aggregation step is performed.
|
|
14
|
+
* Users wait for this event to fetch the aggregation result.
|
|
11
15
|
*/
|
|
12
|
-
export declare abstract class Base<T> {
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
*/
|
|
16
|
-
protected _model?: Model | undefined;
|
|
16
|
+
export declare abstract class Base<T> extends EventEmitter<{
|
|
17
|
+
'aggregation': T;
|
|
18
|
+
}> {
|
|
17
19
|
/**
|
|
18
20
|
* The round cut-off for contributions.
|
|
19
21
|
*/
|
|
@@ -33,19 +35,6 @@ export declare abstract class Base<T> {
|
|
|
33
35
|
* of all active nodes, depending on the aggregation scheme.
|
|
34
36
|
*/
|
|
35
37
|
protected contributions: Map<number, Map<client.NodeID, T>>;
|
|
36
|
-
/**
|
|
37
|
-
* Emits the aggregation event whenever an aggregation step is performed.
|
|
38
|
-
* Triggers the resolve of the result promise and the preparation for the
|
|
39
|
-
* next aggregation round.
|
|
40
|
-
*/
|
|
41
|
-
private readonly eventEmitter;
|
|
42
|
-
protected informant?: AsyncInformant<T>;
|
|
43
|
-
/**
|
|
44
|
-
* The result promise which, on resolve, will contain the current aggregation result.
|
|
45
|
-
* This promise should be fetched by any object making use of an aggregator, in order
|
|
46
|
-
* to await upon aggregation.
|
|
47
|
-
*/
|
|
48
|
-
protected result: Promise<T>;
|
|
49
38
|
/**
|
|
50
39
|
* The current aggregation round, used for assessing whether a node contribution is recent enough
|
|
51
40
|
* or not.
|
|
@@ -59,10 +48,6 @@ export declare abstract class Base<T> {
|
|
|
59
48
|
*/
|
|
60
49
|
protected _communicationRound: number;
|
|
61
50
|
constructor(
|
|
62
|
-
/**
|
|
63
|
-
* The Model whose weights are updated on aggregation.
|
|
64
|
-
*/
|
|
65
|
-
_model?: Model | undefined,
|
|
66
51
|
/**
|
|
67
52
|
* The round cut-off for contributions.
|
|
68
53
|
*/
|
|
@@ -85,7 +70,6 @@ export declare abstract class Base<T> {
|
|
|
85
70
|
* Must store the aggregation's result in the aggregator's result promise.
|
|
86
71
|
*/
|
|
87
72
|
abstract aggregate(): void;
|
|
88
|
-
registerObserver(informant: AsyncInformant<T>): void;
|
|
89
73
|
/**
|
|
90
74
|
* Returns whether the given round is recent enough, dependent on the
|
|
91
75
|
* aggregator's round cutoff.
|
|
@@ -99,11 +83,6 @@ export declare abstract class Base<T> {
|
|
|
99
83
|
* @param from The node which triggered the logging message
|
|
100
84
|
*/
|
|
101
85
|
log(step: AggregationStep, from?: client.NodeID): void;
|
|
102
|
-
/**
|
|
103
|
-
* Sets the aggregator's model.
|
|
104
|
-
* @param model The new model
|
|
105
|
-
*/
|
|
106
|
-
setModel(model: Model): void;
|
|
107
86
|
/**
|
|
108
87
|
* Adds a node's id to the set of active nodes. A node represents an active neighbor
|
|
109
88
|
* peer/client within the network, whom we are communicating with during this aggregation
|
|
@@ -130,27 +109,12 @@ export declare abstract class Base<T> {
|
|
|
130
109
|
* @param round The new round
|
|
131
110
|
*/
|
|
132
111
|
setRound(round: number): void;
|
|
133
|
-
/**
|
|
134
|
-
* Emits the event containing the aggregation result, which allows the result
|
|
135
|
-
* promise to resolve and for the next aggregation round to take place.
|
|
136
|
-
* @param aggregated The aggregation result
|
|
137
|
-
*/
|
|
138
|
-
protected emit(aggregated: T): void;
|
|
139
112
|
/**
|
|
140
113
|
* Updates the aggregator's state to proceed to the next communication round.
|
|
141
114
|
* If all communication rounds were performed, proceeds to the next aggregation round
|
|
142
115
|
* and empties the collection of stored contributions.
|
|
143
116
|
*/
|
|
144
117
|
nextRound(): void;
|
|
145
|
-
private makeResult;
|
|
146
|
-
/**
|
|
147
|
-
* Aggregation steps are performed asynchronously, yet can be awaited upon when required.
|
|
148
|
-
* This function gives access to the current aggregation result's promise, which will
|
|
149
|
-
* eventually resolve and contain the result of the very next aggregation step, at the
|
|
150
|
-
* time of the function call.
|
|
151
|
-
* @returns The promise containing the aggregation result
|
|
152
|
-
*/
|
|
153
|
-
receiveResult(): Promise<T>;
|
|
154
118
|
/**
|
|
155
119
|
* Constructs the payloads sent to other nodes as contribution.
|
|
156
120
|
* @param base Object from which the payload is computed
|
|
@@ -170,10 +134,6 @@ export declare abstract class Base<T> {
|
|
|
170
134
|
* the amount of all active nodes times the number of communication rounds.
|
|
171
135
|
*/
|
|
172
136
|
get size(): number;
|
|
173
|
-
/**
|
|
174
|
-
* The aggregator's current model.
|
|
175
|
-
*/
|
|
176
|
-
get model(): Model | undefined;
|
|
177
137
|
/**
|
|
178
138
|
* The current communication round.
|
|
179
139
|
*/
|
package/dist/aggregator/base.js
CHANGED
|
@@ -9,9 +9,11 @@ export var AggregationStep;
|
|
|
9
9
|
/**
|
|
10
10
|
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
|
|
11
11
|
* a result based off their aggregation, whenever some defined condition is met.
|
|
12
|
+
*
|
|
13
|
+
* Emits an event whenever an aggregation step is performed.
|
|
14
|
+
* Users wait for this event to fetch the aggregation result.
|
|
12
15
|
*/
|
|
13
|
-
export class Base {
|
|
14
|
-
_model;
|
|
16
|
+
export class Base extends EventEmitter {
|
|
15
17
|
roundCutoff;
|
|
16
18
|
communicationRounds;
|
|
17
19
|
/**
|
|
@@ -26,19 +28,6 @@ export class Base {
|
|
|
26
28
|
*/
|
|
27
29
|
// communication round -> NodeID -> T
|
|
28
30
|
contributions;
|
|
29
|
-
/**
|
|
30
|
-
* Emits the aggregation event whenever an aggregation step is performed.
|
|
31
|
-
* Triggers the resolve of the result promise and the preparation for the
|
|
32
|
-
* next aggregation round.
|
|
33
|
-
*/
|
|
34
|
-
eventEmitter = new EventEmitter();
|
|
35
|
-
informant;
|
|
36
|
-
/**
|
|
37
|
-
* The result promise which, on resolve, will contain the current aggregation result.
|
|
38
|
-
* This promise should be fetched by any object making use of an aggregator, in order
|
|
39
|
-
* to await upon aggregation.
|
|
40
|
-
*/
|
|
41
|
-
result;
|
|
42
31
|
/**
|
|
43
32
|
* The current aggregation round, used for assessing whether a node contribution is recent enough
|
|
44
33
|
* or not.
|
|
@@ -52,10 +41,6 @@ export class Base {
|
|
|
52
41
|
*/
|
|
53
42
|
_communicationRound = 0;
|
|
54
43
|
constructor(
|
|
55
|
-
/**
|
|
56
|
-
* The Model whose weights are updated on aggregation.
|
|
57
|
-
*/
|
|
58
|
-
_model,
|
|
59
44
|
/**
|
|
60
45
|
* The round cut-off for contributions.
|
|
61
46
|
*/
|
|
@@ -64,21 +49,14 @@ export class Base {
|
|
|
64
49
|
* The number of communication rounds occurring during any given aggregation round.
|
|
65
50
|
*/
|
|
66
51
|
communicationRounds = 1) {
|
|
67
|
-
|
|
52
|
+
super();
|
|
68
53
|
this.roundCutoff = roundCutoff;
|
|
69
54
|
this.communicationRounds = communicationRounds;
|
|
70
55
|
this.contributions = Map();
|
|
71
56
|
this._nodes = Set();
|
|
72
|
-
// Make the initial result promise
|
|
73
|
-
this.result = this.makeResult();
|
|
74
57
|
// On every aggregation, update the object's state to match the current aggregation
|
|
75
58
|
// and communication rounds.
|
|
76
|
-
this.
|
|
77
|
-
this.nextRound();
|
|
78
|
-
});
|
|
79
|
-
}
|
|
80
|
-
registerObserver(informant) {
|
|
81
|
-
this.informant = informant;
|
|
59
|
+
this.on('aggregation', () => this.nextRound());
|
|
82
60
|
}
|
|
83
61
|
/**
|
|
84
62
|
* Returns whether the given round is recent enough, dependent on the
|
|
@@ -115,13 +93,6 @@ export class Base {
|
|
|
115
93
|
}
|
|
116
94
|
}
|
|
117
95
|
}
|
|
118
|
-
/**
|
|
119
|
-
* Sets the aggregator's model.
|
|
120
|
-
* @param model The new model
|
|
121
|
-
*/
|
|
122
|
-
setModel(model) {
|
|
123
|
-
this._model = model;
|
|
124
|
-
}
|
|
125
96
|
/**
|
|
126
97
|
* Adds a node's id to the set of active nodes. A node represents an active neighbor
|
|
127
98
|
* peer/client within the network, whom we are communicating with during this aggregation
|
|
@@ -162,14 +133,6 @@ export class Base {
|
|
|
162
133
|
this._round = round;
|
|
163
134
|
}
|
|
164
135
|
}
|
|
165
|
-
/**
|
|
166
|
-
* Emits the event containing the aggregation result, which allows the result
|
|
167
|
-
* promise to resolve and for the next aggregation round to take place.
|
|
168
|
-
* @param aggregated The aggregation result
|
|
169
|
-
*/
|
|
170
|
-
emit(aggregated) {
|
|
171
|
-
this.eventEmitter.emit('aggregation', aggregated);
|
|
172
|
-
}
|
|
173
136
|
/**
|
|
174
137
|
* Updates the aggregator's state to proceed to the next communication round.
|
|
175
138
|
* If all communication rounds were performed, proceeds to the next aggregation round
|
|
@@ -181,25 +144,6 @@ export class Base {
|
|
|
181
144
|
this._round++;
|
|
182
145
|
this.contributions = Map();
|
|
183
146
|
}
|
|
184
|
-
this.result = this.makeResult();
|
|
185
|
-
this.informant?.update();
|
|
186
|
-
}
|
|
187
|
-
async makeResult() {
|
|
188
|
-
return await new Promise((resolve) => {
|
|
189
|
-
this.eventEmitter.once('aggregation', (w) => {
|
|
190
|
-
resolve(w);
|
|
191
|
-
});
|
|
192
|
-
});
|
|
193
|
-
}
|
|
194
|
-
/**
|
|
195
|
-
* Aggregation steps are performed asynchronously, yet can be awaited upon when required.
|
|
196
|
-
* This function gives access to the current aggregation result's promise, which will
|
|
197
|
-
* eventually resolve and contain the result of the very next aggregation step, at the
|
|
198
|
-
* time of the function call.
|
|
199
|
-
* @returns The promise containing the aggregation result
|
|
200
|
-
*/
|
|
201
|
-
async receiveResult() {
|
|
202
|
-
return await this.result;
|
|
203
147
|
}
|
|
204
148
|
/**
|
|
205
149
|
* The set of node ids, representing our neighbors within the network.
|
|
@@ -223,12 +167,6 @@ export class Base {
|
|
|
223
167
|
.map((m) => m.size)
|
|
224
168
|
.reduce((totalSize, size) => totalSize + size) ?? 0;
|
|
225
169
|
}
|
|
226
|
-
/**
|
|
227
|
-
* The aggregator's current model.
|
|
228
|
-
*/
|
|
229
|
-
get model() {
|
|
230
|
-
return this._model;
|
|
231
|
-
}
|
|
232
170
|
/**
|
|
233
171
|
* The current communication round.
|
|
234
172
|
*/
|
package/dist/aggregator/get.d.ts
CHANGED
|
@@ -1,8 +1,6 @@
|
|
|
1
1
|
import type { Task } from '../index.js';
|
|
2
2
|
import { aggregator } from '../index.js';
|
|
3
|
-
import type { Model } from "../index.js";
|
|
4
3
|
type AggregatorOptions = Partial<{
|
|
5
|
-
model: Model;
|
|
6
4
|
scheme: Task['trainingInformation']['scheme'];
|
|
7
5
|
roundCutOff: number;
|
|
8
6
|
threshold: number;
|
package/dist/aggregator/get.js
CHANGED
|
@@ -25,7 +25,7 @@ export function getAggregator(task, options = {}) {
|
|
|
25
25
|
if (scheme === 'decentralized') {
|
|
26
26
|
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
|
|
27
27
|
options = {
|
|
28
|
-
|
|
28
|
+
roundCutOff: undefined, threshold: 1, thresholdType: 'relative',
|
|
29
29
|
...options
|
|
30
30
|
};
|
|
31
31
|
}
|
|
@@ -34,15 +34,15 @@ export function getAggregator(task, options = {}) {
|
|
|
34
34
|
// so we set the aggregation threshold to 1 contribution
|
|
35
35
|
// If scheme == 'local' then we only expect our own contribution
|
|
36
36
|
options = {
|
|
37
|
-
|
|
37
|
+
roundCutOff: undefined, threshold: 1, thresholdType: 'absolute',
|
|
38
38
|
...options
|
|
39
39
|
};
|
|
40
40
|
}
|
|
41
|
-
return new aggregator.MeanAggregator(options.
|
|
41
|
+
return new aggregator.MeanAggregator(options.roundCutOff, options.threshold, options.thresholdType);
|
|
42
42
|
case 'secure':
|
|
43
43
|
if (scheme !== 'decentralized') {
|
|
44
44
|
throw new Error('secure aggregation is currently supported for decentralized only');
|
|
45
45
|
}
|
|
46
|
-
return new aggregator.SecureAggregator(
|
|
46
|
+
return new aggregator.SecureAggregator(task.trainingInformation.maxShareValue);
|
|
47
47
|
}
|
|
48
48
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import type { Map } from "immutable";
|
|
2
2
|
import { Base as Aggregator } from "./base.js";
|
|
3
|
-
import type {
|
|
3
|
+
import type { WeightsContainer, client } from "../index.js";
|
|
4
4
|
type ThresholdType = 'relative' | 'absolute';
|
|
5
5
|
/**
|
|
6
6
|
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
|
|
@@ -27,7 +27,7 @@ export declare class MeanAggregator extends Aggregator<WeightsContainer> {
|
|
|
27
27
|
* If 0 then only accept contributions from the current round,
|
|
28
28
|
* if 1 then the current round and the previous one, etc.
|
|
29
29
|
*/
|
|
30
|
-
constructor(
|
|
30
|
+
constructor(roundCutoff?: number, threshold?: number, thresholdType?: ThresholdType);
|
|
31
31
|
/** Checks whether the contributions buffer is full. */
|
|
32
32
|
isFull(): boolean;
|
|
33
33
|
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
|
package/dist/aggregator/mean.js
CHANGED
|
@@ -26,12 +26,12 @@ export class MeanAggregator extends Aggregator {
|
|
|
26
26
|
* If 0 then only accept contributions from the current round,
|
|
27
27
|
* if 1 then the current round and the previous one, etc.
|
|
28
28
|
*/
|
|
29
|
-
constructor(
|
|
29
|
+
constructor(roundCutoff = 0, threshold = 1, thresholdType) {
|
|
30
30
|
if (threshold <= 0)
|
|
31
31
|
throw new Error("threshold must be strictly positive");
|
|
32
32
|
if (threshold > 1 && (!Number.isInteger(threshold)))
|
|
33
33
|
throw new Error("absolute thresholds must be integral");
|
|
34
|
-
super(
|
|
34
|
+
super(roundCutoff, 1);
|
|
35
35
|
this.#threshold = threshold;
|
|
36
36
|
if (threshold < 1) {
|
|
37
37
|
// Throw exception if threshold and thresholdType are conflicting
|
|
@@ -81,7 +81,6 @@ export class MeanAggregator extends Aggregator {
|
|
|
81
81
|
? AggregationStep.UPDATE
|
|
82
82
|
: AggregationStep.ADD, nodeId);
|
|
83
83
|
this.contributions = this.contributions.setIn([0, nodeId], contribution);
|
|
84
|
-
this.informant?.update();
|
|
85
84
|
if (this.isFull())
|
|
86
85
|
this.aggregate();
|
|
87
86
|
return true;
|
|
@@ -92,9 +91,7 @@ export class MeanAggregator extends Aggregator {
|
|
|
92
91
|
throw new Error("aggregating without any contribution");
|
|
93
92
|
this.log(AggregationStep.AGGREGATE);
|
|
94
93
|
const result = aggregation.avg(currentContributions.values());
|
|
95
|
-
|
|
96
|
-
this.model.weights = result;
|
|
97
|
-
this.emit(result);
|
|
94
|
+
this.emit('aggregation', result);
|
|
98
95
|
}
|
|
99
96
|
makePayloads(weights) {
|
|
100
97
|
// Communicate our local weights to every other node, be it a peer or a server
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { Map, List } from "immutable";
|
|
2
2
|
import { Base as Aggregator } from "./base.js";
|
|
3
|
-
import type {
|
|
3
|
+
import type { WeightsContainer, client } from "../index.js";
|
|
4
4
|
/**
|
|
5
5
|
* Aggregator implementing secure multi-party computation for decentralized learning.
|
|
6
6
|
* An aggregation consists of two communication rounds:
|
|
@@ -10,7 +10,7 @@ import type { Model, WeightsContainer, client } from "../index.js";
|
|
|
10
10
|
*/
|
|
11
11
|
export declare class SecureAggregator extends Aggregator<WeightsContainer> {
|
|
12
12
|
private readonly maxShareValue;
|
|
13
|
-
constructor(
|
|
13
|
+
constructor(maxShareValue?: number);
|
|
14
14
|
aggregate(): void;
|
|
15
15
|
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound?: number): boolean;
|
|
16
16
|
isFull(): boolean;
|
|
@@ -11,8 +11,8 @@ import { aggregation } from "../index.js";
|
|
|
11
11
|
*/
|
|
12
12
|
export class SecureAggregator extends Aggregator {
|
|
13
13
|
maxShareValue;
|
|
14
|
-
constructor(
|
|
15
|
-
super(
|
|
14
|
+
constructor(maxShareValue = 100) {
|
|
15
|
+
super(0, 2);
|
|
16
16
|
this.maxShareValue = maxShareValue;
|
|
17
17
|
}
|
|
18
18
|
aggregate() {
|
|
@@ -24,7 +24,7 @@ export class SecureAggregator extends Aggregator {
|
|
|
24
24
|
if (currentContributions === undefined)
|
|
25
25
|
throw new Error("aggregating without any contribution");
|
|
26
26
|
const result = aggregation.sum(currentContributions.values());
|
|
27
|
-
this.emit(result);
|
|
27
|
+
this.emit('aggregation', result);
|
|
28
28
|
break;
|
|
29
29
|
}
|
|
30
30
|
// Average the received partial sums
|
|
@@ -33,9 +33,7 @@ export class SecureAggregator extends Aggregator {
|
|
|
33
33
|
if (currentContributions === undefined)
|
|
34
34
|
throw new Error("aggregating without any contribution");
|
|
35
35
|
const result = aggregation.avg(currentContributions.values());
|
|
36
|
-
|
|
37
|
-
this.model.weights = result;
|
|
38
|
-
this.emit(result);
|
|
36
|
+
this.emit('aggregation', result);
|
|
39
37
|
break;
|
|
40
38
|
}
|
|
41
39
|
default:
|
|
@@ -56,7 +54,6 @@ export class SecureAggregator extends Aggregator {
|
|
|
56
54
|
? AggregationStep.UPDATE
|
|
57
55
|
: AggregationStep.ADD, nodeId);
|
|
58
56
|
this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
|
|
59
|
-
this.informant?.update();
|
|
60
57
|
if (this.isFull())
|
|
61
58
|
this.aggregate();
|
|
62
59
|
return true;
|
package/dist/client/base.d.ts
CHANGED
|
@@ -67,8 +67,9 @@ export declare abstract class Base {
|
|
|
67
67
|
* Communication callback called the end of every training round.
|
|
68
68
|
* @param _weights The most recent local weight updates
|
|
69
69
|
* @param _round The current training round
|
|
70
|
+
* @returns aggregated weights or the local weights upon error
|
|
70
71
|
*/
|
|
71
|
-
onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<
|
|
72
|
+
abstract onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<WeightsContainer>;
|
|
72
73
|
get nbOfParticipants(): number;
|
|
73
74
|
get ownId(): NodeID;
|
|
74
75
|
get server(): EventConnection;
|
package/dist/client/base.js
CHANGED
|
@@ -64,12 +64,6 @@ export class Base {
|
|
|
64
64
|
* @param _round The current training round
|
|
65
65
|
*/
|
|
66
66
|
async onRoundBeginCommunication(_weights, _round) { }
|
|
67
|
-
/**
|
|
68
|
-
* Communication callback called the end of every training round.
|
|
69
|
-
* @param _weights The most recent local weight updates
|
|
70
|
-
* @param _round The current training round
|
|
71
|
-
*/
|
|
72
|
-
async onRoundEndCommunication(_weights, _round) { }
|
|
73
67
|
// Number of contributors to a collaborative session
|
|
74
68
|
// If decentralized, it should be the number of peers
|
|
75
69
|
// If federated, it should the number of participants excluding the server
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import type { WeightsContainer } from "../../index.js";
|
|
2
2
|
import { Client } from '../index.js';
|
|
3
3
|
/**
|
|
4
4
|
* Represents a decentralized client in a network of peers. Peers coordinate each other with the
|
|
@@ -45,5 +45,5 @@ export declare class Base extends Client {
|
|
|
45
45
|
* @param round
|
|
46
46
|
*/
|
|
47
47
|
private receivePayloads;
|
|
48
|
-
onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<
|
|
48
|
+
onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
|
|
49
49
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { Map, Set } from 'immutable';
|
|
2
|
-
import { serialization } from
|
|
2
|
+
import { serialization } from "../../index.js";
|
|
3
3
|
import { Client } from '../index.js';
|
|
4
4
|
import { type } from '../messages.js';
|
|
5
5
|
import { timeout } from '../utils.js';
|
|
@@ -133,7 +133,7 @@ export class Base extends Client {
|
|
|
133
133
|
}
|
|
134
134
|
// Store the promise for the current round's aggregation result.
|
|
135
135
|
// We will await for it to resolve at the end of the round when exchanging weight updates.
|
|
136
|
-
this.aggregationResult = this.aggregator.
|
|
136
|
+
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
137
137
|
}
|
|
138
138
|
/**
|
|
139
139
|
* At each communication rounds, awaits peers contributions and add them to the client's aggregator.
|
|
@@ -163,12 +163,15 @@ export class Base extends Client {
|
|
|
163
163
|
});
|
|
164
164
|
}
|
|
165
165
|
async onRoundEndCommunication(weights, round) {
|
|
166
|
-
|
|
166
|
+
if (this.aggregationResult === undefined) {
|
|
167
|
+
throw new TypeError('aggregation result promise is undefined');
|
|
168
|
+
}
|
|
167
169
|
// Perform the required communication rounds. Each communication round consists in sending our local payload,
|
|
168
170
|
// followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator.
|
|
169
171
|
// A communication round's payload is the aggregation result of the previous communication round. The first
|
|
170
172
|
// communication round simply sends our training result, i.e. model weights updates. This scheme allows for
|
|
171
173
|
// the aggregator to define any complex multi-round aggregation mechanism.
|
|
174
|
+
let result = weights;
|
|
172
175
|
for (let r = 0; r < this.aggregator.communicationRounds; r++) {
|
|
173
176
|
// Generate our payloads for this communication round and send them to all ready connected peers
|
|
174
177
|
if (this.connections !== undefined) {
|
|
@@ -198,9 +201,6 @@ export class Base extends Client {
|
|
|
198
201
|
throw new Error('error while sending weights');
|
|
199
202
|
}
|
|
200
203
|
}
|
|
201
|
-
if (this.aggregationResult === undefined) {
|
|
202
|
-
throw new TypeError('aggregation result promise is undefined');
|
|
203
|
-
}
|
|
204
204
|
// Wait for aggregation before proceeding to the next communication round.
|
|
205
205
|
// The current result will be used as payload for the eventual next communication round.
|
|
206
206
|
try {
|
|
@@ -211,7 +211,7 @@ export class Base extends Client {
|
|
|
211
211
|
}
|
|
212
212
|
catch (e) {
|
|
213
213
|
if (this.isDisconnected) {
|
|
214
|
-
return;
|
|
214
|
+
return weights;
|
|
215
215
|
}
|
|
216
216
|
console.error(e);
|
|
217
217
|
break;
|
|
@@ -219,10 +219,11 @@ export class Base extends Client {
|
|
|
219
219
|
// There is at least one communication round remaining
|
|
220
220
|
if (r < this.aggregator.communicationRounds - 1) {
|
|
221
221
|
// Reuse the aggregation result
|
|
222
|
-
this.aggregationResult = this.aggregator.
|
|
222
|
+
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
223
223
|
}
|
|
224
224
|
}
|
|
225
225
|
// Reset the peers list for the next round
|
|
226
226
|
this.aggregator.resetNodes();
|
|
227
|
+
return await this.aggregationResult;
|
|
227
228
|
}
|
|
228
229
|
}
|
|
@@ -28,7 +28,7 @@ export declare class Base extends Client {
|
|
|
28
28
|
*/
|
|
29
29
|
disconnect(): Promise<void>;
|
|
30
30
|
onRoundBeginCommunication(): Promise<void>;
|
|
31
|
-
onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<
|
|
31
|
+
onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
|
|
32
32
|
/**
|
|
33
33
|
* Send a message containing our local weight updates to the federated server.
|
|
34
34
|
* And waits for the server to reply with the most recent aggregated weights
|
|
@@ -68,7 +68,7 @@ export class Base extends Client {
|
|
|
68
68
|
}
|
|
69
69
|
onRoundBeginCommunication() {
|
|
70
70
|
// Prepare the result promise for the incoming round
|
|
71
|
-
this.aggregationResult = this.aggregator.
|
|
71
|
+
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
72
72
|
return Promise.resolve();
|
|
73
73
|
}
|
|
74
74
|
async onRoundEndCommunication(weights, round) {
|
|
@@ -90,6 +90,7 @@ export class Base extends Client {
|
|
|
90
90
|
console.info(`[${this.ownId}] Server result is either stale or not received`);
|
|
91
91
|
this.aggregator.nextRound();
|
|
92
92
|
}
|
|
93
|
+
return await this.aggregationResult;
|
|
93
94
|
}
|
|
94
95
|
/**
|
|
95
96
|
* Send a message containing our local weight updates to the federated server.
|
package/dist/client/local.d.ts
CHANGED
package/dist/client/local.js
CHANGED
|
@@ -30,8 +30,7 @@ export const cifar10 = {
|
|
|
30
30
|
IMAGE_W: 224,
|
|
31
31
|
LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
|
|
32
32
|
scheme: 'decentralized',
|
|
33
|
-
noiseScale:
|
|
34
|
-
clippingRadius: 20,
|
|
33
|
+
privacy: { clippingRadius: 20, noiseScale: 1 },
|
|
35
34
|
decentralizedSecure: true,
|
|
36
35
|
minimumReadyPeers: 3,
|
|
37
36
|
maxShareValue: 100,
|
|
@@ -30,8 +30,6 @@ export const mnist = {
|
|
|
30
30
|
preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
|
|
31
31
|
LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
|
|
32
32
|
scheme: 'decentralized',
|
|
33
|
-
noiseScale: undefined,
|
|
34
|
-
clippingRadius: undefined,
|
|
35
33
|
decentralizedSecure: true,
|
|
36
34
|
minimumReadyPeers: 3,
|
|
37
35
|
maxShareValue: 100,
|
package/dist/index.d.ts
CHANGED
|
@@ -5,7 +5,6 @@ export * as privacy from './privacy.js';
|
|
|
5
5
|
export * as client from './client/index.js';
|
|
6
6
|
export * as aggregator from './aggregator/index.js';
|
|
7
7
|
export { WeightsContainer, aggregation } from './weights/index.js';
|
|
8
|
-
export { AsyncInformant } from './async_informant.js';
|
|
9
8
|
export { Logger, ConsoleLogger } from './logging/index.js';
|
|
10
9
|
export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
|
|
11
10
|
export { Disco, RoundLogs } from './training/index.js';
|
package/dist/index.js
CHANGED
|
@@ -5,7 +5,6 @@ export * as privacy from './privacy.js';
|
|
|
5
5
|
export * as client from './client/index.js';
|
|
6
6
|
export * as aggregator from './aggregator/index.js';
|
|
7
7
|
export { WeightsContainer, aggregation } from './weights/index.js';
|
|
8
|
-
export { AsyncInformant } from './async_informant.js';
|
|
9
8
|
export { ConsoleLogger } from './logging/index.js';
|
|
10
9
|
export { Memory, Empty as EmptyMemory } from './memory/index.js';
|
|
11
10
|
export { Disco } from './training/index.js';
|
package/dist/privacy.d.ts
CHANGED
|
@@ -1,11 +1,9 @@
|
|
|
1
|
-
import type {
|
|
1
|
+
import type { WeightsContainer } from "./index.js";
|
|
2
|
+
/** Scramble weights */
|
|
3
|
+
export declare function addNoise(weights: WeightsContainer, deviation: number): WeightsContainer;
|
|
2
4
|
/**
|
|
3
|
-
*
|
|
4
|
-
*
|
|
5
|
-
*
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
* @param task the task
|
|
9
|
-
* @returns the noised weights for the current round
|
|
10
|
-
*/
|
|
11
|
-
export declare function addDifferentialPrivacy(updatedWeights: WeightsContainer, staleWeights: WeightsContainer, task: Task): WeightsContainer;
|
|
5
|
+
* Keep weights' norm within radius
|
|
6
|
+
*
|
|
7
|
+
* @param radius maximum norm
|
|
8
|
+
**/
|
|
9
|
+
export declare function clipNorm(weights: WeightsContainer, radius: number): Promise<WeightsContainer>;
|