@epfml/discojs 2.2.2-p20240703101552.0 → 3.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +9 -48
- package/dist/aggregator/base.js +8 -69
- package/dist/aggregator/get.d.ts +23 -11
- package/dist/aggregator/get.js +40 -23
- package/dist/aggregator/index.d.ts +1 -1
- package/dist/aggregator/index.js +1 -1
- package/dist/aggregator/mean.d.ts +25 -6
- package/dist/aggregator/mean.js +62 -17
- package/dist/aggregator/secure.d.ts +2 -2
- package/dist/aggregator/secure.js +4 -7
- package/dist/client/base.d.ts +3 -3
- package/dist/client/base.js +6 -8
- package/dist/client/decentralized/base.d.ts +27 -10
- package/dist/client/decentralized/base.js +123 -86
- package/dist/client/decentralized/peer.js +7 -12
- package/dist/client/decentralized/peer_pool.js +6 -2
- package/dist/client/event_connection.d.ts +1 -1
- package/dist/client/event_connection.js +3 -3
- package/dist/client/federated/base.d.ts +5 -21
- package/dist/client/federated/base.js +38 -61
- package/dist/client/federated/messages.d.ts +2 -10
- package/dist/client/federated/messages.js +0 -1
- package/dist/client/index.d.ts +1 -1
- package/dist/client/index.js +1 -1
- package/dist/client/local.d.ts +3 -1
- package/dist/client/local.js +4 -1
- package/dist/client/messages.d.ts +1 -2
- package/dist/client/messages.js +8 -3
- package/dist/client/utils.d.ts +4 -2
- package/dist/client/utils.js +18 -3
- package/dist/dataset/data/data.d.ts +1 -1
- package/dist/dataset/data/data.js +13 -2
- package/dist/dataset/data/preprocessing/image_preprocessing.js +6 -4
- package/dist/default_tasks/cifar10.js +1 -2
- package/dist/default_tasks/lus_covid.js +0 -5
- package/dist/default_tasks/mnist.js +15 -14
- package/dist/default_tasks/simple_face.js +0 -2
- package/dist/default_tasks/titanic.js +2 -4
- package/dist/default_tasks/wikitext.js +7 -1
- package/dist/index.d.ts +0 -1
- package/dist/index.js +0 -1
- package/dist/models/gpt/config.js +1 -1
- package/dist/privacy.d.ts +8 -10
- package/dist/privacy.js +25 -40
- package/dist/task/task_handler.js +10 -2
- package/dist/task/training_information.d.ts +7 -4
- package/dist/task/training_information.js +25 -6
- package/dist/training/disco.d.ts +30 -28
- package/dist/training/disco.js +75 -73
- package/dist/training/index.d.ts +1 -1
- package/dist/training/index.js +1 -0
- package/dist/training/trainer.d.ts +16 -0
- package/dist/training/trainer.js +72 -0
- package/dist/types.d.ts +0 -2
- package/dist/weights/weights_container.d.ts +0 -5
- package/dist/weights/weights_container.js +0 -7
- package/package.json +1 -1
- package/dist/async_informant.d.ts +0 -15
- package/dist/async_informant.js +0 -42
- package/dist/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/training/trainer/distributed_trainer.js +0 -41
- package/dist/training/trainer/local_trainer.d.ts +0 -12
- package/dist/training/trainer/local_trainer.js +0 -24
- package/dist/training/trainer/trainer.d.ts +0 -32
- package/dist/training/trainer/trainer.js +0 -61
- package/dist/training/trainer/trainer_builder.d.ts +0 -23
- package/dist/training/trainer/trainer_builder.js +0 -47
|
@@ -1,7 +1,5 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import type { WeightsContainer } from "../../index.js";
|
|
2
2
|
import { Client } from '../index.js';
|
|
3
|
-
import { type PeerConnection } from '../event_connection.js';
|
|
4
|
-
import * as messages from './messages.js';
|
|
5
3
|
/**
|
|
6
4
|
* Represents a decentralized client in a network of peers. Peers coordinate each other with the
|
|
7
5
|
* help of the network's server, yet only exchange payloads between each other. Communication
|
|
@@ -14,19 +12,38 @@ export declare class Base extends Client {
|
|
|
14
12
|
*/
|
|
15
13
|
private pool?;
|
|
16
14
|
private connections?;
|
|
15
|
+
private get isDisconnected();
|
|
17
16
|
/**
|
|
18
|
-
*
|
|
17
|
+
* Public method called by disco.ts when starting training. This method sends
|
|
18
|
+
* a message to the server asking to join the task and be assigned a client ID.
|
|
19
|
+
*
|
|
20
|
+
* The peer also establishes a WebSocket connection with the server to then
|
|
21
|
+
* create peer-to-peer WebRTC connections with peers. The server is used to exchange
|
|
22
|
+
* peers network information.
|
|
19
23
|
*/
|
|
20
|
-
|
|
21
|
-
protected sendMessagetoPeer(peer: PeerConnection, msg: messages.PeerMessage): void;
|
|
24
|
+
connect(): Promise<void>;
|
|
22
25
|
/**
|
|
23
|
-
*
|
|
24
|
-
*
|
|
26
|
+
* Create a WebSocket connection with the server
|
|
27
|
+
* The client then waits for the server to forward it other client's network information.
|
|
28
|
+
* Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
|
|
25
29
|
*/
|
|
26
30
|
private connectServer;
|
|
27
|
-
connect(): Promise<void>;
|
|
28
31
|
disconnect(): Promise<void>;
|
|
32
|
+
/**
|
|
33
|
+
* At the beginning of a round, each peer tells the server it is ready to proceed
|
|
34
|
+
* The server answers with the list of all peers connected for the round
|
|
35
|
+
* Given the list, the peers then create peer-to-peer connections with each other.
|
|
36
|
+
* When connected, one peer creates a promise for every other peer's weight update
|
|
37
|
+
* and waits for it to resolve.
|
|
38
|
+
*
|
|
39
|
+
*/
|
|
29
40
|
onRoundBeginCommunication(_: WeightsContainer, round: number): Promise<void>;
|
|
30
|
-
|
|
41
|
+
/**
|
|
42
|
+
* At each communication rounds, awaits peers contributions and add them to the client's aggregator.
|
|
43
|
+
* This method is used as callback by getPeers when connecting to the rounds' peers
|
|
44
|
+
* @param connections
|
|
45
|
+
* @param round
|
|
46
|
+
*/
|
|
31
47
|
private receivePayloads;
|
|
48
|
+
onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
|
|
32
49
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { Map, Set } from 'immutable';
|
|
2
|
-
import { serialization } from
|
|
2
|
+
import { serialization } from "../../index.js";
|
|
3
3
|
import { Client } from '../index.js';
|
|
4
4
|
import { type } from '../messages.js';
|
|
5
5
|
import { timeout } from '../utils.js';
|
|
@@ -18,63 +18,18 @@ export class Base extends Client {
|
|
|
18
18
|
*/
|
|
19
19
|
pool;
|
|
20
20
|
connections;
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
async waitForPeers(round) {
|
|
25
|
-
console.info(`[${this.ownId}] is ready for round`, round);
|
|
26
|
-
// Broadcast our readiness
|
|
27
|
-
const readyMessage = { type: type.PeerIsReady };
|
|
28
|
-
if (this.server === undefined) {
|
|
29
|
-
throw new Error('server undefined, could not connect peers');
|
|
30
|
-
}
|
|
31
|
-
this.server.send(readyMessage);
|
|
32
|
-
// Wait for peers to be connected before sending any update information
|
|
33
|
-
try {
|
|
34
|
-
const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound);
|
|
35
|
-
if (this.nodes.size > 0) {
|
|
36
|
-
throw new Error('got new peer list from server but was already received for this round');
|
|
37
|
-
}
|
|
38
|
-
const peers = Set(receivedMessage.peers);
|
|
39
|
-
console.info(`[${this.ownId}] received peers for round:`, peers.toJS());
|
|
40
|
-
if (this.ownId !== undefined && peers.has(this.ownId)) {
|
|
41
|
-
throw new Error('received peer list contains our own id');
|
|
42
|
-
}
|
|
43
|
-
this.aggregator.setNodes(peers.add(this.ownId));
|
|
44
|
-
if (this.pool === undefined) {
|
|
45
|
-
throw new Error('waiting for peers but peer pool is undefined');
|
|
46
|
-
}
|
|
47
|
-
const connections = await this.pool.getPeers(peers, this.server,
|
|
48
|
-
// Init receipt of peers weights
|
|
49
|
-
(conn) => { this.receivePayloads(conn, round); });
|
|
50
|
-
console.info(`[${this.ownId}] received peers for round ${round}:`, connections.keySeq().toJS());
|
|
51
|
-
return connections;
|
|
52
|
-
}
|
|
53
|
-
catch (e) {
|
|
54
|
-
console.error(e);
|
|
55
|
-
this.aggregator.setNodes(Set(this.ownId));
|
|
56
|
-
return Map();
|
|
57
|
-
}
|
|
58
|
-
}
|
|
59
|
-
sendMessagetoPeer(peer, msg) {
|
|
60
|
-
console.info(`[${this.ownId}] send message to peer`, msg.peer, msg);
|
|
61
|
-
peer.send(msg);
|
|
21
|
+
// Used to handle timeouts and promise resolving after calling disconnect
|
|
22
|
+
get isDisconnected() {
|
|
23
|
+
return this._server === undefined;
|
|
62
24
|
}
|
|
63
25
|
/**
|
|
64
|
-
*
|
|
65
|
-
*
|
|
26
|
+
* Public method called by disco.ts when starting training. This method sends
|
|
27
|
+
* a message to the server asking to join the task and be assigned a client ID.
|
|
28
|
+
*
|
|
29
|
+
* The peer also establishes a WebSocket connection with the server to then
|
|
30
|
+
* create peer-to-peer WebRTC connections with peers. The server is used to exchange
|
|
31
|
+
* peers network information.
|
|
66
32
|
*/
|
|
67
|
-
async connectServer(url) {
|
|
68
|
-
const server = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer);
|
|
69
|
-
server.on(type.SignalForPeer, (event) => {
|
|
70
|
-
console.info(`[${this.ownId}] received signal from`, event.peer);
|
|
71
|
-
if (this.pool === undefined) {
|
|
72
|
-
throw new Error('received signal but peer pool is undefined');
|
|
73
|
-
}
|
|
74
|
-
this.pool.signal(event.peer, event.signal);
|
|
75
|
-
});
|
|
76
|
-
return server;
|
|
77
|
-
}
|
|
78
33
|
async connect() {
|
|
79
34
|
const serverURL = new URL('', this.url.href);
|
|
80
35
|
switch (this.url.protocol) {
|
|
@@ -94,13 +49,29 @@ export class Base extends Client {
|
|
|
94
49
|
};
|
|
95
50
|
this.server.send(msg);
|
|
96
51
|
const peerIdMsg = await waitMessage(this.server, type.AssignNodeID);
|
|
97
|
-
console.
|
|
52
|
+
console.log(`[${peerIdMsg.id}] assigned id generated by server`);
|
|
98
53
|
if (this._ownId !== undefined) {
|
|
99
54
|
throw new Error('received id from server but was already received');
|
|
100
55
|
}
|
|
101
56
|
this._ownId = peerIdMsg.id;
|
|
102
57
|
this.pool = new PeerPool(peerIdMsg.id);
|
|
103
58
|
}
|
|
59
|
+
/**
|
|
60
|
+
* Create a WebSocket connection with the server
|
|
61
|
+
* The client then waits for the server to forward it other client's network information.
|
|
62
|
+
* Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
|
|
63
|
+
*/
|
|
64
|
+
async connectServer(url) {
|
|
65
|
+
const server = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer);
|
|
66
|
+
server.on(type.SignalForPeer, (event) => {
|
|
67
|
+
if (this.pool === undefined) {
|
|
68
|
+
throw new Error('received signal but peer pool is undefined');
|
|
69
|
+
}
|
|
70
|
+
// Create a WebRTC connection with the peer
|
|
71
|
+
this.pool.signal(event.peer, event.signal);
|
|
72
|
+
});
|
|
73
|
+
return server;
|
|
74
|
+
}
|
|
104
75
|
async disconnect() {
|
|
105
76
|
// Disconnect from peers
|
|
106
77
|
await this.pool?.shutdown();
|
|
@@ -115,20 +86,92 @@ export class Base extends Client {
|
|
|
115
86
|
this._ownId = undefined;
|
|
116
87
|
return Promise.resolve();
|
|
117
88
|
}
|
|
89
|
+
/**
|
|
90
|
+
* At the beginning of a round, each peer tells the server it is ready to proceed
|
|
91
|
+
* The server answers with the list of all peers connected for the round
|
|
92
|
+
* Given the list, the peers then create peer-to-peer connections with each other.
|
|
93
|
+
* When connected, one peer creates a promise for every other peer's weight update
|
|
94
|
+
* and waits for it to resolve.
|
|
95
|
+
*
|
|
96
|
+
*/
|
|
118
97
|
async onRoundBeginCommunication(_, round) {
|
|
98
|
+
if (this.server === undefined) {
|
|
99
|
+
throw new Error("peer's server is undefined, make sure to call `client.connect()` first");
|
|
100
|
+
}
|
|
101
|
+
if (this.pool === undefined) {
|
|
102
|
+
throw new Error('peer pool is undefined, make sure to call `client.connect()` first');
|
|
103
|
+
}
|
|
119
104
|
// Reset peers list at each round of training to make sure client works with an updated peers
|
|
120
105
|
// list, maintained by the server. Adds any received weights to the aggregator.
|
|
121
|
-
this.connections = await this.waitForPeers(round)
|
|
106
|
+
// this.connections = await this.waitForPeers(round)
|
|
107
|
+
// Tell the server we are ready for the next round
|
|
108
|
+
const readyMessage = { type: type.PeerIsReady };
|
|
109
|
+
this.server.send(readyMessage);
|
|
110
|
+
// Wait for the server to answer with the list of peers for the round
|
|
111
|
+
try {
|
|
112
|
+
const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound, undefined, "Timeout waiting for the round's peer list");
|
|
113
|
+
const peers = Set(receivedMessage.peers);
|
|
114
|
+
if (this.ownId !== undefined && peers.has(this.ownId)) {
|
|
115
|
+
throw new Error('received peer list contains our own id');
|
|
116
|
+
}
|
|
117
|
+
// Store the list of peers for the current round including ourselves
|
|
118
|
+
this.aggregator.setNodes(peers.add(this.ownId));
|
|
119
|
+
// Initiate peer to peer connections with each peer
|
|
120
|
+
// When connected, create a promise waiting for each peer's round contribution
|
|
121
|
+
const connections = await this.pool.getPeers(peers, this.server,
|
|
122
|
+
// Init receipt of peers weights
|
|
123
|
+
// this awaits the peer's weight update and adds it to
|
|
124
|
+
// our aggregator upon reception
|
|
125
|
+
(conn) => { this.receivePayloads(conn, round); });
|
|
126
|
+
console.log(`[${this.ownId}] received peers for round ${round}:`, connections.keySeq().toJS());
|
|
127
|
+
this.connections = connections;
|
|
128
|
+
}
|
|
129
|
+
catch (e) {
|
|
130
|
+
console.error(e);
|
|
131
|
+
this.aggregator.setNodes(Set(this.ownId));
|
|
132
|
+
this.connections = Map();
|
|
133
|
+
}
|
|
122
134
|
// Store the promise for the current round's aggregation result.
|
|
123
|
-
|
|
135
|
+
// We will await for it to resolve at the end of the round when exchanging weight updates.
|
|
136
|
+
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
137
|
+
}
|
|
138
|
+
/**
|
|
139
|
+
* At each communication rounds, awaits peers contributions and add them to the client's aggregator.
|
|
140
|
+
* This method is used as callback by getPeers when connecting to the rounds' peers
|
|
141
|
+
* @param connections
|
|
142
|
+
* @param round
|
|
143
|
+
*/
|
|
144
|
+
receivePayloads(connections, round) {
|
|
145
|
+
connections.forEach(async (connection, peerId) => {
|
|
146
|
+
let currentCommunicationRounds = 0;
|
|
147
|
+
console.log(`waiting for peer ${peerId}`);
|
|
148
|
+
do {
|
|
149
|
+
try {
|
|
150
|
+
const message = await waitMessageWithTimeout(connection, type.Payload, 60_000, "Timeout waiting for a contribution from peer " + peerId);
|
|
151
|
+
const decoded = serialization.weights.decode(message.payload);
|
|
152
|
+
if (!this.aggregator.add(peerId, decoded, round, message.round)) {
|
|
153
|
+
console.warn(`[${this.ownId}] Failed to add contribution from peer ${peerId}`);
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
catch (e) {
|
|
157
|
+
if (this.isDisconnected) {
|
|
158
|
+
return;
|
|
159
|
+
}
|
|
160
|
+
console.error(e instanceof Error ? e.message : e);
|
|
161
|
+
}
|
|
162
|
+
} while (++currentCommunicationRounds < this.aggregator.communicationRounds);
|
|
163
|
+
});
|
|
124
164
|
}
|
|
125
165
|
async onRoundEndCommunication(weights, round) {
|
|
126
|
-
|
|
166
|
+
if (this.aggregationResult === undefined) {
|
|
167
|
+
throw new TypeError('aggregation result promise is undefined');
|
|
168
|
+
}
|
|
127
169
|
// Perform the required communication rounds. Each communication round consists in sending our local payload,
|
|
128
170
|
// followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator.
|
|
129
171
|
// A communication round's payload is the aggregation result of the previous communication round. The first
|
|
130
172
|
// communication round simply sends our training result, i.e. model weights updates. This scheme allows for
|
|
131
173
|
// the aggregator to define any complex multi-round aggregation mechanism.
|
|
174
|
+
let result = weights;
|
|
132
175
|
for (let r = 0; r < this.aggregator.communicationRounds; r++) {
|
|
133
176
|
// Generate our payloads for this communication round and send them to all ready connected peers
|
|
134
177
|
if (this.connections !== undefined) {
|
|
@@ -139,15 +182,17 @@ export class Base extends Client {
|
|
|
139
182
|
this.aggregator.add(this.ownId, payload, round, r);
|
|
140
183
|
}
|
|
141
184
|
else {
|
|
142
|
-
const
|
|
143
|
-
if (
|
|
185
|
+
const peer = this.connections?.get(id);
|
|
186
|
+
if (peer !== undefined) {
|
|
144
187
|
const encoded = await serialization.weights.encode(payload);
|
|
145
|
-
|
|
188
|
+
const msg = {
|
|
146
189
|
type: type.Payload,
|
|
147
190
|
peer: id,
|
|
148
191
|
round: r,
|
|
149
192
|
payload: encoded
|
|
150
|
-
}
|
|
193
|
+
};
|
|
194
|
+
peer.send(msg);
|
|
195
|
+
console.log(`[${this.ownId}] send weight update to peer`, msg.peer, msg);
|
|
151
196
|
}
|
|
152
197
|
}
|
|
153
198
|
}));
|
|
@@ -156,37 +201,29 @@ export class Base extends Client {
|
|
|
156
201
|
throw new Error('error while sending weights');
|
|
157
202
|
}
|
|
158
203
|
}
|
|
159
|
-
if (this.aggregationResult === undefined) {
|
|
160
|
-
throw new TypeError('aggregation result promise is undefined');
|
|
161
|
-
}
|
|
162
204
|
// Wait for aggregation before proceeding to the next communication round.
|
|
163
205
|
// The current result will be used as payload for the eventual next communication round.
|
|
164
|
-
|
|
206
|
+
try {
|
|
207
|
+
result = await Promise.race([
|
|
208
|
+
this.aggregationResult,
|
|
209
|
+
timeout(undefined, "Timeout waiting on the aggregation result promise to resolve")
|
|
210
|
+
]);
|
|
211
|
+
}
|
|
212
|
+
catch (e) {
|
|
213
|
+
if (this.isDisconnected) {
|
|
214
|
+
return weights;
|
|
215
|
+
}
|
|
216
|
+
console.error(e);
|
|
217
|
+
break;
|
|
218
|
+
}
|
|
165
219
|
// There is at least one communication round remaining
|
|
166
220
|
if (r < this.aggregator.communicationRounds - 1) {
|
|
167
221
|
// Reuse the aggregation result
|
|
168
|
-
this.aggregationResult = this.aggregator.
|
|
222
|
+
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
169
223
|
}
|
|
170
224
|
}
|
|
171
225
|
// Reset the peers list for the next round
|
|
172
226
|
this.aggregator.resetNodes();
|
|
173
|
-
|
|
174
|
-
receivePayloads(connections, round) {
|
|
175
|
-
console.info(`[${this.ownId}] Accepting new contributions for round ${round}`);
|
|
176
|
-
connections.forEach(async (connection, peerId) => {
|
|
177
|
-
let receivedPayloads = 0;
|
|
178
|
-
do {
|
|
179
|
-
try {
|
|
180
|
-
const message = await waitMessageWithTimeout(connection, type.Payload);
|
|
181
|
-
const decoded = serialization.weights.decode(message.payload);
|
|
182
|
-
if (!this.aggregator.add(peerId, decoded, round, message.round)) {
|
|
183
|
-
console.warn(`[${this.ownId}] Failed to add contribution from peer ${peerId}`);
|
|
184
|
-
}
|
|
185
|
-
}
|
|
186
|
-
catch (e) {
|
|
187
|
-
console.warn(e instanceof Error ? e.message : e);
|
|
188
|
-
}
|
|
189
|
-
} while (++receivedPayloads < this.aggregator.communicationRounds);
|
|
190
|
-
});
|
|
227
|
+
return await this.aggregationResult;
|
|
191
228
|
}
|
|
192
229
|
}
|
|
@@ -58,13 +58,7 @@ export class Peer {
|
|
|
58
58
|
if (this.bufferSize === undefined) {
|
|
59
59
|
throw new Error('chunk without known buffer size');
|
|
60
60
|
}
|
|
61
|
-
|
|
62
|
-
// we would return this.bufferSize
|
|
63
|
-
// sadly, we are not there yet
|
|
64
|
-
//
|
|
65
|
-
// based on MDN, taking 16K seems to be a pretty safe
|
|
66
|
-
// and widely supported buffer size
|
|
67
|
-
return 16 * (1 << 10);
|
|
61
|
+
return this.bufferSize;
|
|
68
62
|
}
|
|
69
63
|
chunk(b) {
|
|
70
64
|
const messageID = this.sendCounter;
|
|
@@ -79,7 +73,8 @@ export class Peer {
|
|
|
79
73
|
}
|
|
80
74
|
const totalChunkCount = 1 + tail.count();
|
|
81
75
|
if (totalChunkCount > 0xFF) {
|
|
82
|
-
throw new Error(
|
|
76
|
+
throw new Error(`The payload is too big: ${totalChunkCount * this.maxChunkSize} bytes > 255,` +
|
|
77
|
+
' consider reducing the model size or increasing the chunk size');
|
|
83
78
|
}
|
|
84
79
|
const firstChunk = Buffer.alloc((b.length > this.maxChunkSize - FIRST_HEADER_SIZE)
|
|
85
80
|
? this.maxChunkSize
|
|
@@ -106,7 +101,7 @@ export class Peer {
|
|
|
106
101
|
});
|
|
107
102
|
}
|
|
108
103
|
signal(signal) {
|
|
109
|
-
// extract max buffer size
|
|
104
|
+
// extract max buffer size from the signal
|
|
110
105
|
if (signal.type === 'offer' || signal.type === 'answer') {
|
|
111
106
|
if (signal.sdp === undefined) {
|
|
112
107
|
throw new Error('signal answer|offer without session description');
|
|
@@ -138,8 +133,8 @@ export class Peer {
|
|
|
138
133
|
if (!Buffer.isBuffer(data) || data.length < HEADER_SIZE) {
|
|
139
134
|
throw new Error('received invalid message type');
|
|
140
135
|
}
|
|
141
|
-
const messageID = data.readUint16BE()
|
|
142
|
-
const chunkID = data.
|
|
136
|
+
const messageID = data.readUInt16BE(); //readUint16BE (case sensitive) fails at runtime
|
|
137
|
+
const chunkID = data.readUInt8(2); // same for readUint8
|
|
143
138
|
const received = this.receiving.get(messageID, {
|
|
144
139
|
total: undefined,
|
|
145
140
|
chunks: Map()
|
|
@@ -161,7 +156,7 @@ export class Peer {
|
|
|
161
156
|
if (total !== undefined) {
|
|
162
157
|
throw new Error('first header received twice');
|
|
163
158
|
}
|
|
164
|
-
const readTotal = data.
|
|
159
|
+
const readTotal = data.readUInt8(3);
|
|
165
160
|
total = readTotal;
|
|
166
161
|
chunk = Buffer.alloc(data.length - FIRST_HEADER_SIZE);
|
|
167
162
|
data.copy(chunk, 0, FIRST_HEADER_SIZE);
|
|
@@ -9,8 +9,12 @@ export class PeerPool {
|
|
|
9
9
|
this.id = id;
|
|
10
10
|
}
|
|
11
11
|
async shutdown() {
|
|
12
|
-
console.info(`[${this.id}]
|
|
13
|
-
|
|
12
|
+
console.info(`[${this.id}] is shutting down all its connections`);
|
|
13
|
+
// Add a timeout o.w. the promise hangs forever if the other peer is already disconnected
|
|
14
|
+
await Promise.race([
|
|
15
|
+
Promise.all(this.peers.valueSeq().map((peer) => peer.disconnect())),
|
|
16
|
+
new Promise((res, _) => setTimeout(res, 1000)) // Wait for other peers to finish
|
|
17
|
+
]);
|
|
14
18
|
this.peers = Map();
|
|
15
19
|
}
|
|
16
20
|
signal(peerId, signal) {
|
|
@@ -9,7 +9,7 @@ export interface EventConnection {
|
|
|
9
9
|
disconnect: () => Promise<void>;
|
|
10
10
|
}
|
|
11
11
|
export declare function waitMessage<T extends type>(connection: EventConnection, type: T): Promise<NarrowMessage<T>>;
|
|
12
|
-
export declare function waitMessageWithTimeout<T extends type>(connection: EventConnection, type: T, timeoutMs?: number): Promise<NarrowMessage<T>>;
|
|
12
|
+
export declare function waitMessageWithTimeout<T extends type>(connection: EventConnection, type: T, timeoutMs?: number, errorMsg?: string): Promise<NarrowMessage<T>>;
|
|
13
13
|
export declare class PeerConnection extends EventEmitter<{
|
|
14
14
|
[K in type]: NarrowMessage<K>;
|
|
15
15
|
}> implements EventConnection {
|
|
@@ -12,8 +12,8 @@ export async function waitMessage(connection, type) {
|
|
|
12
12
|
});
|
|
13
13
|
});
|
|
14
14
|
}
|
|
15
|
-
export async function waitMessageWithTimeout(connection, type, timeoutMs) {
|
|
16
|
-
return await Promise.race([waitMessage(connection, type), timeout(timeoutMs)]);
|
|
15
|
+
export async function waitMessageWithTimeout(connection, type, timeoutMs, errorMsg = 'timeout') {
|
|
16
|
+
return await Promise.race([waitMessage(connection, type), timeout(timeoutMs, errorMsg)]);
|
|
17
17
|
}
|
|
18
18
|
export class PeerConnection extends EventEmitter {
|
|
19
19
|
_ownId;
|
|
@@ -41,7 +41,7 @@ export class PeerConnection extends EventEmitter {
|
|
|
41
41
|
}
|
|
42
42
|
this.emit(msg.type, msg);
|
|
43
43
|
});
|
|
44
|
-
this.peer.on('close', () => { console.warn('peer', this.peer.id, 'closed connection'); });
|
|
44
|
+
this.peer.on('close', () => { console.warn('From', this._ownId, ': peer', this.peer.id, 'closed connection'); });
|
|
45
45
|
await new Promise((resolve) => {
|
|
46
46
|
this.peer.on('connect', resolve);
|
|
47
47
|
});
|
|
@@ -1,22 +1,18 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { type MetadataKey, type MetadataValue, type WeightsContainer } from "../../index.js";
|
|
3
|
-
import { type NodeID } from "../types.js";
|
|
1
|
+
import { type WeightsContainer } from "../../index.js";
|
|
4
2
|
import { Base as Client } from "../base.js";
|
|
5
3
|
/**
|
|
6
4
|
* Client class that communicates with a centralized, federated server, when training
|
|
7
5
|
* a specific task in the federated setting.
|
|
8
6
|
*/
|
|
9
7
|
export declare class Base extends Client {
|
|
8
|
+
#private;
|
|
10
9
|
/**
|
|
11
10
|
* Arbitrary node id assigned to the federated server which we are communicating with.
|
|
12
11
|
* Indeed, the server acts as a node within the network. In the federated setting described
|
|
13
12
|
* by this client class, the server is the only node which we are communicating with.
|
|
14
13
|
*/
|
|
15
14
|
static readonly SERVER_NODE_ID = "federated-server-node-id";
|
|
16
|
-
|
|
17
|
-
* Map of metadata values for each node id.
|
|
18
|
-
*/
|
|
19
|
-
private metadataMap?;
|
|
15
|
+
get nbOfParticipants(): number;
|
|
20
16
|
/**
|
|
21
17
|
* Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
22
18
|
*/
|
|
@@ -31,24 +27,12 @@ export declare class Base extends Client {
|
|
|
31
27
|
* Disconnection process when user quits the task.
|
|
32
28
|
*/
|
|
33
29
|
disconnect(): Promise<void>;
|
|
30
|
+
onRoundBeginCommunication(): Promise<void>;
|
|
31
|
+
onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
|
|
34
32
|
/**
|
|
35
33
|
* Send a message containing our local weight updates to the federated server.
|
|
36
34
|
* And waits for the server to reply with the most recent aggregated weights
|
|
37
35
|
* @param payload The weight updates to send
|
|
38
36
|
*/
|
|
39
37
|
private sendPayloadAndReceiveResult;
|
|
40
|
-
/**
|
|
41
|
-
* Waits for the server's result for its current (most recent) round and add it to our aggregator.
|
|
42
|
-
* Updates the aggregator's round if it's behind the server's.
|
|
43
|
-
*/
|
|
44
|
-
private receiveResult;
|
|
45
|
-
/**
|
|
46
|
-
* Fetch the metadata values maintained by the federated server, for a given metadata key.
|
|
47
|
-
* The values are indexed by node id.
|
|
48
|
-
* @param key The metadata key
|
|
49
|
-
* @returns The map of node id to metadata value
|
|
50
|
-
*/
|
|
51
|
-
receiveMetadataMap(key: MetadataKey): Promise<Map<NodeID, MetadataValue> | undefined>;
|
|
52
|
-
onRoundBeginCommunication(): Promise<void>;
|
|
53
|
-
onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<void>;
|
|
54
38
|
}
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import { Map } from "immutable";
|
|
2
1
|
import { serialization, } from "../../index.js";
|
|
3
2
|
import { Base as Client } from "../base.js";
|
|
4
3
|
import { type } from "../messages.js";
|
|
@@ -15,10 +14,13 @@ export class Base extends Client {
|
|
|
15
14
|
* by this client class, the server is the only node which we are communicating with.
|
|
16
15
|
*/
|
|
17
16
|
static SERVER_NODE_ID = "federated-server-node-id";
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
17
|
+
// Total number of other federated contributors, including this client, excluding the server
|
|
18
|
+
// E.g., if 3 users are training a federated model, nbOfParticipants is 3
|
|
19
|
+
#nbOfParticipants = 1;
|
|
20
|
+
// the number of participants excluding the server
|
|
21
|
+
get nbOfParticipants() {
|
|
22
|
+
return this.#nbOfParticipants;
|
|
23
|
+
}
|
|
22
24
|
/**
|
|
23
25
|
* Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
24
26
|
*/
|
|
@@ -64,6 +66,32 @@ export class Base extends Client {
|
|
|
64
66
|
this.aggregator.setNodes(this.aggregator.nodes.delete(Base.SERVER_NODE_ID));
|
|
65
67
|
return Promise.resolve();
|
|
66
68
|
}
|
|
69
|
+
onRoundBeginCommunication() {
|
|
70
|
+
// Prepare the result promise for the incoming round
|
|
71
|
+
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
72
|
+
return Promise.resolve();
|
|
73
|
+
}
|
|
74
|
+
async onRoundEndCommunication(weights, round) {
|
|
75
|
+
// NB: For now, we suppose a fully-federated setting.
|
|
76
|
+
if (this.aggregationResult === undefined) {
|
|
77
|
+
throw new Error("local aggregation result was not set");
|
|
78
|
+
}
|
|
79
|
+
// Send our local contribution to the server
|
|
80
|
+
// and receive the most recent weights as an answer to our contribution
|
|
81
|
+
const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first());
|
|
82
|
+
if (serverResult !== undefined &&
|
|
83
|
+
this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) {
|
|
84
|
+
// Regular case: the server sends us its aggregation result which will serve our
|
|
85
|
+
// own aggregation result.
|
|
86
|
+
}
|
|
87
|
+
else {
|
|
88
|
+
// Unexpected case: for some reason, the server result is stale.
|
|
89
|
+
// We proceed to the next round without its result.
|
|
90
|
+
console.info(`[${this.ownId}] Server result is either stale or not received`);
|
|
91
|
+
this.aggregator.nextRound();
|
|
92
|
+
}
|
|
93
|
+
return await this.aggregationResult;
|
|
94
|
+
}
|
|
67
95
|
/**
|
|
68
96
|
* Send a message containing our local weight updates to the federated server.
|
|
69
97
|
* And waits for the server to reply with the most recent aggregated weights
|
|
@@ -76,17 +104,13 @@ export class Base extends Client {
|
|
|
76
104
|
round: this.aggregator.round,
|
|
77
105
|
};
|
|
78
106
|
this.server.send(msg);
|
|
79
|
-
//
|
|
80
|
-
|
|
81
|
-
}
|
|
82
|
-
/**
|
|
83
|
-
* Waits for the server's result for its current (most recent) round and add it to our aggregator.
|
|
84
|
-
* Updates the aggregator's round if it's behind the server's.
|
|
85
|
-
*/
|
|
86
|
-
async receiveResult() {
|
|
107
|
+
// Waits for the server's result for its current (most recent) round and add it to our aggregator.
|
|
108
|
+
// Updates the aggregator's round if it's behind the server's.
|
|
87
109
|
try {
|
|
88
|
-
|
|
110
|
+
// It is important than the client immediately awaits the server result or it may miss it
|
|
111
|
+
const { payload, round, nbOfParticipants } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload);
|
|
89
112
|
const serverRound = round;
|
|
113
|
+
this.#nbOfParticipants = nbOfParticipants; // Save the current participants
|
|
90
114
|
// Store the server result only if it is not stale
|
|
91
115
|
if (this.aggregator.round <= round) {
|
|
92
116
|
const serverResult = serialization.weights.decode(payload);
|
|
@@ -101,51 +125,4 @@ export class Base extends Client {
|
|
|
101
125
|
console.error(e);
|
|
102
126
|
}
|
|
103
127
|
}
|
|
104
|
-
/**
|
|
105
|
-
* Fetch the metadata values maintained by the federated server, for a given metadata key.
|
|
106
|
-
* The values are indexed by node id.
|
|
107
|
-
* @param key The metadata key
|
|
108
|
-
* @returns The map of node id to metadata value
|
|
109
|
-
*/
|
|
110
|
-
async receiveMetadataMap(key) {
|
|
111
|
-
this.metadataMap = undefined;
|
|
112
|
-
const msg = {
|
|
113
|
-
type: type.ReceiveServerMetadata,
|
|
114
|
-
taskId: this.task.id,
|
|
115
|
-
nodeId: this.ownId,
|
|
116
|
-
round: this.aggregator.round,
|
|
117
|
-
key,
|
|
118
|
-
};
|
|
119
|
-
this.server.send(msg);
|
|
120
|
-
const received = await waitMessageWithTimeout(this.server, type.ReceiveServerMetadata);
|
|
121
|
-
if (received.metadataMap !== undefined) {
|
|
122
|
-
this.metadataMap = Map(received.metadataMap.filter(([_, v]) => v !== undefined));
|
|
123
|
-
}
|
|
124
|
-
return this.metadataMap;
|
|
125
|
-
}
|
|
126
|
-
onRoundBeginCommunication() {
|
|
127
|
-
// Prepare the result promise for the incoming round
|
|
128
|
-
this.aggregationResult = this.aggregator.receiveResult();
|
|
129
|
-
return Promise.resolve();
|
|
130
|
-
}
|
|
131
|
-
async onRoundEndCommunication(weights, round) {
|
|
132
|
-
// NB: For now, we suppose a fully-federated setting.
|
|
133
|
-
if (this.aggregationResult === undefined) {
|
|
134
|
-
throw new Error("local aggregation result was not set");
|
|
135
|
-
}
|
|
136
|
-
// Send our local contribution to the server
|
|
137
|
-
// and receive the most recent weights as an answer to our contribution
|
|
138
|
-
const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first());
|
|
139
|
-
if (serverResult !== undefined &&
|
|
140
|
-
this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) {
|
|
141
|
-
// Regular case: the server sends us its aggregation result which will serve our
|
|
142
|
-
// own aggregation result.
|
|
143
|
-
}
|
|
144
|
-
else {
|
|
145
|
-
// Unexpected case: for some reason, the server result is stale.
|
|
146
|
-
// We proceed to the next round without its result.
|
|
147
|
-
console.info(`[${this.ownId}] Server result is either stale or not received`);
|
|
148
|
-
this.aggregator.nextRound();
|
|
149
|
-
}
|
|
150
|
-
}
|
|
151
128
|
}
|
|
@@ -1,7 +1,6 @@
|
|
|
1
|
-
import { type client, type MetadataKey, type MetadataValue } from '../../index.js';
|
|
2
1
|
import { type weights } from '../../serialization/index.js';
|
|
3
2
|
import { type, type AssignNodeID, type ClientConnected } from '../messages.js';
|
|
4
|
-
export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload |
|
|
3
|
+
export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | AssignNodeID;
|
|
5
4
|
export interface SendPayload {
|
|
6
5
|
type: type.SendPayload;
|
|
7
6
|
payload: weights.Encoded;
|
|
@@ -11,13 +10,6 @@ export interface ReceiveServerPayload {
|
|
|
11
10
|
type: type.ReceiveServerPayload;
|
|
12
11
|
payload: weights.Encoded;
|
|
13
12
|
round: number;
|
|
14
|
-
|
|
15
|
-
export interface ReceiveServerMetadata {
|
|
16
|
-
type: type.ReceiveServerMetadata;
|
|
17
|
-
nodeId: client.NodeID;
|
|
18
|
-
taskId: string;
|
|
19
|
-
round: number;
|
|
20
|
-
key: MetadataKey;
|
|
21
|
-
metadataMap?: Array<[client.NodeID, MetadataValue | undefined]>;
|
|
13
|
+
nbOfParticipants: number;
|
|
22
14
|
}
|
|
23
15
|
export declare function isMessageFederated(raw: unknown): raw is MessageFederated;
|
package/dist/client/index.d.ts
CHANGED
|
@@ -4,5 +4,5 @@ export * as aggregator from '../aggregator/index.js';
|
|
|
4
4
|
export * as decentralized from './decentralized/index.js';
|
|
5
5
|
export * as federated from './federated/index.js';
|
|
6
6
|
export * as messages from './messages.js';
|
|
7
|
-
export
|
|
7
|
+
export { getClient, timeout } from './utils.js';
|
|
8
8
|
export { Local } from './local.js';
|