@epfml/discojs 3.0.1-p20241001093123.0 → 3.0.1-p20241014092014.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
- package/dist/aggregator/{base.js → aggregator.js} +48 -36
- package/dist/aggregator/get.d.ts +2 -2
- package/dist/aggregator/get.js +4 -4
- package/dist/aggregator/index.d.ts +1 -4
- package/dist/aggregator/index.js +1 -1
- package/dist/aggregator/mean.d.ts +4 -4
- package/dist/aggregator/mean.js +5 -15
- package/dist/aggregator/secure.d.ts +4 -4
- package/dist/aggregator/secure.js +7 -17
- package/dist/client/client.d.ts +71 -17
- package/dist/client/client.js +115 -14
- package/dist/client/decentralized/decentralized_client.d.ts +11 -13
- package/dist/client/decentralized/decentralized_client.js +121 -84
- package/dist/client/decentralized/messages.d.ts +10 -4
- package/dist/client/decentralized/messages.js +7 -6
- package/dist/client/federated/federated_client.d.ts +1 -13
- package/dist/client/federated/federated_client.js +15 -94
- package/dist/client/federated/messages.d.ts +2 -7
- package/dist/client/local_client.d.ts +1 -0
- package/dist/client/local_client.js +3 -0
- package/dist/client/messages.d.ts +14 -7
- package/dist/client/messages.js +13 -11
- package/dist/default_tasks/cifar10.js +1 -1
- package/dist/default_tasks/lus_covid.js +1 -0
- package/dist/default_tasks/mnist.js +1 -1
- package/dist/default_tasks/simple_face.js +1 -0
- package/dist/default_tasks/titanic.js +1 -0
- package/dist/default_tasks/wikitext.js +1 -0
- package/dist/task/training_information.d.ts +1 -2
- package/dist/task/training_information.js +6 -8
- package/dist/training/disco.d.ts +4 -1
- package/dist/training/trainer.js +1 -1
- package/dist/utils/event_emitter.d.ts +3 -3
- package/dist/utils/event_emitter.js +10 -9
- package/package.json +1 -1
package/dist/client/client.js
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
|
+
import createDebug from "debug";
|
|
1
2
|
import axios from 'axios';
|
|
2
3
|
import { serialization } from '../index.js';
|
|
3
4
|
import { EventEmitter } from '../utils/event_emitter.js';
|
|
5
|
+
import { type } from "./messages.js";
|
|
6
|
+
const debug = createDebug("discojs:client");
|
|
4
7
|
/**
|
|
5
8
|
* Main, abstract, class representing a Disco client in a network, which handles
|
|
6
9
|
* communication with other nodes, be it peers or a server.
|
|
@@ -9,18 +12,25 @@ export class Client extends EventEmitter {
|
|
|
9
12
|
url;
|
|
10
13
|
task;
|
|
11
14
|
aggregator;
|
|
12
|
-
|
|
13
|
-
* Own ID provided by the network's server.
|
|
14
|
-
*/
|
|
15
|
+
// Own ID provided by the network's server.
|
|
15
16
|
_ownId;
|
|
17
|
+
// The network's server.
|
|
18
|
+
_server;
|
|
19
|
+
// The aggregator's result produced after aggregation.
|
|
20
|
+
aggregationResult;
|
|
16
21
|
/**
|
|
17
|
-
*
|
|
22
|
+
* When the server notifies clients to pause and wait until more
|
|
23
|
+
* participants join, we rely on this promise to wait
|
|
24
|
+
* until the server signals that the training can resume
|
|
18
25
|
*/
|
|
19
|
-
|
|
26
|
+
promiseForMoreParticipants = undefined;
|
|
20
27
|
/**
|
|
21
|
-
*
|
|
28
|
+
* When the server notifies the client that they can resume training
|
|
29
|
+
* after waiting for more participants, we want to be able to display what
|
|
30
|
+
* we were doing before waiting (training locally or updating our model).
|
|
31
|
+
* We use this attribute to store the status to rollback to when we stop waiting
|
|
22
32
|
*/
|
|
23
|
-
|
|
33
|
+
previousStatus;
|
|
24
34
|
constructor(url, // The network server's URL to connect to
|
|
25
35
|
task, // The client's corresponding task
|
|
26
36
|
aggregator) {
|
|
@@ -41,6 +51,94 @@ export class Client extends EventEmitter {
|
|
|
41
51
|
* Handles the disconnection process of the client from any sort of network server.
|
|
42
52
|
*/
|
|
43
53
|
async disconnect() { }
|
|
54
|
+
/**
|
|
55
|
+
* Emits the round status specified. It also stores the status emitted such that
|
|
56
|
+
* if the server tells the client to wait for more participants, it can display
|
|
57
|
+
* the waiting status and once enough participants join, it can display the previous status again
|
|
58
|
+
*/
|
|
59
|
+
saveAndEmit(status) {
|
|
60
|
+
this.previousStatus = status;
|
|
61
|
+
this.emit("status", status);
|
|
62
|
+
}
|
|
63
|
+
/**
|
|
64
|
+
* For both federated and decentralized clients, we listen to the server to tell
|
|
65
|
+
* us whether there are enough participants to train. If not, we pause until further notice.
|
|
66
|
+
* When a client connects to the server, the server answers with the session information (id,
|
|
67
|
+
* number of participants) and whether there are enough participants.
|
|
68
|
+
* When there are enough participants, the server sends a new EnoughParticipants message to update the client.
|
|
69
|
+
*
|
|
70
|
+
* `setMessageInversionFlag` is used to address the following scenario:
|
|
71
|
+
* 1. Client 1 connects to the server
|
|
72
|
+
* 2. Server answers with message A containing "not enough participants"
|
|
73
|
+
* 3. Before A arrives a new client joins. There are enough participants now.
|
|
74
|
+
* 4. Server updates client 1 with message B saying "there are enough participants"
|
|
75
|
+
* 5. Due to network and message sizes message B can arrive before A.
|
|
76
|
+
* i.e. "there are enough participants" arrives before "not enough participants"
|
|
77
|
+
* ending up with client 1 thinking it needs to wait for more participants.
|
|
78
|
+
*
|
|
79
|
+
* To keep track of this message inversion, `setMessageInversionFlag`
|
|
80
|
+
* tells us whether a message inversion occurred (by setting a boolean to true)
|
|
81
|
+
*
|
|
82
|
+
* @param setMessageInversionFlag function flagging whether a message inversion occurred
|
|
83
|
+
* between a NewNodeInfo message and an EnoughParticipant message.
|
|
84
|
+
*/
|
|
85
|
+
setupServerCallbacks(setMessageInversionFlag) {
|
|
86
|
+
// Setup an event callback if the server signals that we should
|
|
87
|
+
// wait for more participants
|
|
88
|
+
this.server.on(type.WaitingForMoreParticipants, () => {
|
|
89
|
+
if (this.promiseForMoreParticipants !== undefined)
|
|
90
|
+
throw new Error("Server sent multiple WaitingForMoreParticipants messages");
|
|
91
|
+
debug(`[${shortenId(this.ownId)}] received WaitingForMoreParticipants message from server`);
|
|
92
|
+
// Display the waiting status right away
|
|
93
|
+
this.emit("status", "not enough participants");
|
|
94
|
+
// Upon receiving a WaitingForMoreParticipants message,
|
|
95
|
+
// the client will await for this promise to resolve before sending its
|
|
96
|
+
// local weight update
|
|
97
|
+
this.promiseForMoreParticipants = this.createPromiseForMoreParticipants();
|
|
98
|
+
});
|
|
99
|
+
// As an example assume we need at least 2 participants to train,
|
|
100
|
+
// When two participants join almost at the same time, the server
|
|
101
|
+
// sends a NewNodeInfo with waitForMoreParticipants=true to the first participant
|
|
102
|
+
// and directly follows with an EnoughParticipants message when the 2nd participant joins
|
|
103
|
+
// However, the EnoughParticipants can arrive before the NewNodeInfo (which can be much bigger)
|
|
104
|
+
// so we check whether we received the EnoughParticipants before being assigned a node ID
|
|
105
|
+
this.server.once(type.EnoughParticipants, () => {
|
|
106
|
+
if (this._ownId === undefined) {
|
|
107
|
+
debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
|
|
108
|
+
setMessageInversionFlag();
|
|
109
|
+
}
|
|
110
|
+
});
|
|
111
|
+
}
|
|
112
|
+
/**
|
|
113
|
+
* Method called when the server notifies the client that there aren't enough
|
|
114
|
+
* participants (anymore) to start/continue training
|
|
115
|
+
* The method creates a promise that will resolve once the server notifies
|
|
116
|
+
* the client that the training can resume via a subsequent EnoughParticipants message
|
|
117
|
+
* @returns a promise which resolves when enough participants joined the session
|
|
118
|
+
*/
|
|
119
|
+
async createPromiseForMoreParticipants() {
|
|
120
|
+
return new Promise((resolve) => {
|
|
121
|
+
// "once" is important because we can't resolve the same promise multiple times
|
|
122
|
+
this.server.once(type.EnoughParticipants, () => {
|
|
123
|
+
debug(`[${shortenId(this.ownId)}] received EnoughParticipants message from server`);
|
|
124
|
+
// Emit the last status emitted before waiting if defined
|
|
125
|
+
if (this.previousStatus !== undefined)
|
|
126
|
+
this.emit("status", this.previousStatus);
|
|
127
|
+
resolve();
|
|
128
|
+
});
|
|
129
|
+
});
|
|
130
|
+
}
|
|
131
|
+
async waitForParticipantsIfNeeded() {
|
|
132
|
+
// we check if we are waiting for more participants before sending our weight update
|
|
133
|
+
if (this.waitingForMoreParticipants) {
|
|
134
|
+
// wait for the promise to resolve, which takes as long as it takes for new participants to join
|
|
135
|
+
debug(`[${shortenId(this.ownId)}] is awaiting the promise for more participants`);
|
|
136
|
+
this.emit("status", "not enough participants");
|
|
137
|
+
await this.promiseForMoreParticipants;
|
|
138
|
+
// Make sure to set the promise back to undefined once resolved
|
|
139
|
+
this.promiseForMoreParticipants = undefined;
|
|
140
|
+
}
|
|
141
|
+
}
|
|
44
142
|
/**
|
|
45
143
|
* Fetches the latest model available on the network's server, for the adequate task.
|
|
46
144
|
* @returns The latest model
|
|
@@ -54,13 +152,6 @@ export class Client extends EventEmitter {
|
|
|
54
152
|
const response = await axios.get(url.href, { responseType: 'arraybuffer' });
|
|
55
153
|
return await serialization.model.decode(new Uint8Array(response.data));
|
|
56
154
|
}
|
|
57
|
-
// Number of contributors to a collaborative session
|
|
58
|
-
// If decentralized, it should be the number of peers
|
|
59
|
-
// If federated, it should the number of participants excluding the server
|
|
60
|
-
// If local it should be 1
|
|
61
|
-
get nbOfParticipants() {
|
|
62
|
-
return this.aggregator.nodes.size; // overriden by the federated client
|
|
63
|
-
}
|
|
64
155
|
get ownId() {
|
|
65
156
|
if (this._ownId === undefined) {
|
|
66
157
|
throw new Error('the node is not connected');
|
|
@@ -73,4 +164,14 @@ export class Client extends EventEmitter {
|
|
|
73
164
|
}
|
|
74
165
|
return this._server;
|
|
75
166
|
}
|
|
167
|
+
/**
|
|
168
|
+
* Whether the client should wait until more
|
|
169
|
+
* participants join the session, i.e. a promise has been created
|
|
170
|
+
*/
|
|
171
|
+
get waitingForMoreParticipants() {
|
|
172
|
+
return this.promiseForMoreParticipants !== undefined;
|
|
173
|
+
}
|
|
174
|
+
}
|
|
175
|
+
export function shortenId(id) {
|
|
176
|
+
return id.slice(0, 4);
|
|
76
177
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { Model, WeightsContainer } from "../../index.js";
|
|
2
|
-
import { Client } from '../
|
|
2
|
+
import { Client } from '../client.js';
|
|
3
3
|
/**
|
|
4
4
|
* Represents a decentralized client in a network of peers. Peers coordinate each other with the
|
|
5
5
|
* help of the network's server, yet only exchange payloads between each other. Communication
|
|
@@ -7,11 +7,8 @@ import { Client } from '../index.js';
|
|
|
7
7
|
* WebRTC for Node.js.
|
|
8
8
|
*/
|
|
9
9
|
export declare class DecentralizedClient extends Client {
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
*/
|
|
13
|
-
private pool?;
|
|
14
|
-
private connections?;
|
|
10
|
+
#private;
|
|
11
|
+
getNbOfParticipants(): number;
|
|
15
12
|
private get isDisconnected();
|
|
16
13
|
/**
|
|
17
14
|
* Public method called by disco.ts when starting training. This method sends
|
|
@@ -22,12 +19,6 @@ export declare class DecentralizedClient extends Client {
|
|
|
22
19
|
* peers network information.
|
|
23
20
|
*/
|
|
24
21
|
connect(): Promise<Model>;
|
|
25
|
-
/**
|
|
26
|
-
* Create a WebSocket connection with the server
|
|
27
|
-
* The client then waits for the server to forward it other client's network information.
|
|
28
|
-
* Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
|
|
29
|
-
*/
|
|
30
|
-
private connectServer;
|
|
31
22
|
disconnect(): Promise<void>;
|
|
32
23
|
/**
|
|
33
24
|
* At the beginning of a round, each peer tells the server it is ready to proceed
|
|
@@ -38,6 +29,13 @@ export declare class DecentralizedClient extends Client {
|
|
|
38
29
|
*
|
|
39
30
|
*/
|
|
40
31
|
onRoundBeginCommunication(): Promise<void>;
|
|
32
|
+
onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
|
|
33
|
+
/**
|
|
34
|
+
* Signal to the server that we are ready to exchange weights.
|
|
35
|
+
* Once enough peers are ready, the server sends the list of peers for this round
|
|
36
|
+
* and the peers can establish peer-to-peer connections with each other.
|
|
37
|
+
*/
|
|
38
|
+
private establishPeerConnections;
|
|
41
39
|
/**
|
|
42
40
|
* At each communication round, awaits peers' contributions and adds them to the client's aggregator.
|
|
43
41
|
* This method is used as callback by getPeers when connecting to the rounds' peers
|
|
@@ -45,5 +43,5 @@ export declare class DecentralizedClient extends Client {
|
|
|
45
43
|
* @param round
|
|
46
44
|
*/
|
|
47
45
|
private receivePayloads;
|
|
48
|
-
|
|
46
|
+
private exchangeWeightUpdates;
|
|
49
47
|
}
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import createDebug from "debug";
|
|
2
2
|
import { Map, Set } from 'immutable';
|
|
3
3
|
import { serialization } from "../../index.js";
|
|
4
|
-
import { Client } from '../
|
|
4
|
+
import { Client, shortenId } from '../client.js';
|
|
5
5
|
import { type } from '../messages.js';
|
|
6
6
|
import { timeout } from '../utils.js';
|
|
7
7
|
import { WebSocketServer, waitMessage, waitMessageWithTimeout } from '../event_connection.js';
|
|
@@ -18,8 +18,12 @@ export class DecentralizedClient extends Client {
|
|
|
18
18
|
/**
|
|
19
19
|
* The pool of peers to communicate with during the current training round.
|
|
20
20
|
*/
|
|
21
|
-
pool;
|
|
22
|
-
connections;
|
|
21
|
+
#pool;
|
|
22
|
+
#connections;
|
|
23
|
+
getNbOfParticipants() {
|
|
24
|
+
const nbOfParticipants = this.aggregator.nodes.size;
|
|
25
|
+
return nbOfParticipants === 0 ? 1 : nbOfParticipants;
|
|
26
|
+
}
|
|
23
27
|
// Used to handle timeouts and promise resolving after calling disconnect
|
|
24
28
|
get isDisconnected() {
|
|
25
29
|
return this._server === undefined;
|
|
@@ -46,42 +50,48 @@ export class DecentralizedClient extends Client {
|
|
|
46
50
|
throw new Error(`unknown protocol: ${this.url.protocol}`);
|
|
47
51
|
}
|
|
48
52
|
serverURL.pathname += `decentralized/${this.task.id}`;
|
|
49
|
-
|
|
53
|
+
// Create a WebSocket connection with the server
|
|
54
|
+
// The client then waits for the server to forward it other clients' network information.
|
|
55
|
+
// Upon receiving other peers' information, the clients establish a peer-to-peer WebRTC connection.
|
|
56
|
+
this._server = await WebSocketServer.connect(serverURL, messages.isMessageFromServer, messages.isMessageToServer);
|
|
57
|
+
this.server.on(type.SignalForPeer, (event) => {
|
|
58
|
+
if (this.#pool === undefined)
|
|
59
|
+
throw new Error('received signal but peer pool is undefined');
|
|
60
|
+
// Create a WebRTC connection with the peer
|
|
61
|
+
this.#pool.signal(event.peer, event.signal);
|
|
62
|
+
});
|
|
63
|
+
// c.f. setupServerCallbacks doc for explanation
|
|
64
|
+
let receivedEnoughParticipants = false;
|
|
65
|
+
this.setupServerCallbacks(() => receivedEnoughParticipants = true);
|
|
50
66
|
const msg = {
|
|
51
67
|
type: type.ClientConnected
|
|
52
68
|
};
|
|
53
69
|
this.server.send(msg);
|
|
54
|
-
const { id } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
|
|
55
|
-
|
|
70
|
+
const { id, waitForMoreParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
|
|
71
|
+
// This should come right after receiving the message to make sure
|
|
72
|
+
// we don't miss a subsequent message from the server
|
|
73
|
+
// We check if the server is telling us to wait for more participants
|
|
74
|
+
// and we also check if an EnoughParticipants message ended up arriving
|
|
75
|
+
// before the NewNodeInfo
|
|
76
|
+
if (waitForMoreParticipants && !receivedEnoughParticipants) {
|
|
77
|
+
// Create a promise that resolves when enough participants join
|
|
78
|
+
// The client will await this promise before sending its local weight update
|
|
79
|
+
this.promiseForMoreParticipants = this.createPromiseForMoreParticipants();
|
|
80
|
+
}
|
|
81
|
+
debug(`[${shortenId(id)}] assigned id generated by server`);
|
|
56
82
|
if (this._ownId !== undefined) {
|
|
57
83
|
throw new Error('received id from server but was already received');
|
|
58
84
|
}
|
|
59
85
|
this._ownId = id;
|
|
60
|
-
this
|
|
86
|
+
this.#pool = new PeerPool(id);
|
|
61
87
|
return model;
|
|
62
88
|
}
|
|
63
|
-
/**
|
|
64
|
-
* Create a WebSocket connection with the server
|
|
65
|
-
* The client then waits for the server to forward it other client's network information.
|
|
66
|
-
* Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
|
|
67
|
-
*/
|
|
68
|
-
async connectServer(url) {
|
|
69
|
-
const server = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer);
|
|
70
|
-
server.on(type.SignalForPeer, (event) => {
|
|
71
|
-
if (this.pool === undefined) {
|
|
72
|
-
throw new Error('received signal but peer pool is undefined');
|
|
73
|
-
}
|
|
74
|
-
// Create a WebRTC connection with the peer
|
|
75
|
-
this.pool.signal(event.peer, event.signal);
|
|
76
|
-
});
|
|
77
|
-
return server;
|
|
78
|
-
}
|
|
79
89
|
async disconnect() {
|
|
80
90
|
// Disconnect from peers
|
|
81
|
-
await this
|
|
82
|
-
this
|
|
83
|
-
if (this
|
|
84
|
-
const peers = this
|
|
91
|
+
await this.#pool?.shutdown();
|
|
92
|
+
this.#pool = undefined;
|
|
93
|
+
if (this.#connections !== undefined) {
|
|
94
|
+
const peers = this.#connections.keySeq().toSet();
|
|
85
95
|
this.aggregator.setNodes(this.aggregator.nodes.subtract(peers));
|
|
86
96
|
}
|
|
87
97
|
// Disconnect from server
|
|
@@ -99,13 +109,41 @@ export class DecentralizedClient extends Client {
|
|
|
99
109
|
*
|
|
100
110
|
*/
|
|
101
111
|
async onRoundBeginCommunication() {
|
|
112
|
+
// Notify the server we want to join the next round so that the server
|
|
113
|
+
// waits for us to be ready before sending the list of peers for the round
|
|
114
|
+
this.server.send({ type: type.JoinRound });
|
|
115
|
+
// Store the promise for the current round's aggregation result.
|
|
116
|
+
// We will await for it to resolve at the end of the round when exchanging weight updates.
|
|
117
|
+
this.aggregationResult = this.aggregator.getPromiseForAggregation();
|
|
118
|
+
this.saveAndEmit("local training");
|
|
119
|
+
return Promise.resolve();
|
|
120
|
+
}
|
|
121
|
+
async onRoundEndCommunication(weights) {
|
|
122
|
+
if (this.aggregationResult === undefined) {
|
|
123
|
+
throw new TypeError('aggregation result promise is undefined');
|
|
124
|
+
}
|
|
125
|
+
// Save the status in case participants leave and we switch to waiting for more participants
|
|
126
|
+
// Once enough new participants join we can display the previous status again
|
|
127
|
+
this.saveAndEmit("connecting to peers");
|
|
128
|
+
// First we check if we are waiting for more participants before sending our weight update
|
|
129
|
+
await this.waitForParticipantsIfNeeded();
|
|
130
|
+
// Create peer-to-peer connections with all peers for the round
|
|
131
|
+
await this.establishPeerConnections();
|
|
132
|
+
// Exchange weight updates with peers and return aggregated weights
|
|
133
|
+
return await this.exchangeWeightUpdates(weights);
|
|
134
|
+
}
|
|
135
|
+
/**
|
|
136
|
+
* Signal to the server that we are ready to exchange weights.
|
|
137
|
+
* Once enough peers are ready, the server sends the list of peers for this round
|
|
138
|
+
* and the peers can establish peer-to-peer connections with each other.
|
|
139
|
+
*/
|
|
140
|
+
async establishPeerConnections() {
|
|
102
141
|
if (this.server === undefined) {
|
|
103
142
|
throw new Error("peer's server is undefined, make sure to call `client.connect()` first");
|
|
104
143
|
}
|
|
105
|
-
if (this
|
|
144
|
+
if (this.#pool === undefined) {
|
|
106
145
|
throw new Error('peer pool is undefined, make sure to call `client.connect()` first');
|
|
107
146
|
}
|
|
108
|
-
this.emit("status", "Retrieving peers' information");
|
|
109
147
|
// Reset peers list at each round of training to make sure client works with an updated peers
|
|
110
148
|
// list, maintained by the server. Adds any received weights to the aggregator.
|
|
111
149
|
// Tell the server we are ready for the next round
|
|
@@ -113,33 +151,29 @@ export class DecentralizedClient extends Client {
|
|
|
113
151
|
this.server.send(readyMessage);
|
|
114
152
|
// Wait for the server to answer with the list of peers for the round
|
|
115
153
|
try {
|
|
116
|
-
debug(`[${this.ownId}] is waiting for peer list for round ${this.aggregator.round}`);
|
|
117
|
-
const receivedMessage = await
|
|
154
|
+
debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`);
|
|
155
|
+
const receivedMessage = await waitMessage(this.server, type.PeersForRound);
|
|
118
156
|
const peers = Set(receivedMessage.peers);
|
|
119
157
|
if (this.ownId !== undefined && peers.has(this.ownId)) {
|
|
120
158
|
throw new Error('received peer list contains our own id');
|
|
121
159
|
}
|
|
122
160
|
// Store the list of peers for the current round including ourselves
|
|
123
161
|
this.aggregator.setNodes(peers.add(this.ownId));
|
|
162
|
+
this.aggregator.setRound(receivedMessage.aggregationRound); // the server gives us the round number
|
|
124
163
|
// Initiate peer to peer connections with each peer
|
|
125
164
|
// When connected, create a promise waiting for each peer's round contribution
|
|
126
|
-
const connections = await this
|
|
127
|
-
// Init receipt of peers weights
|
|
128
|
-
//
|
|
129
|
-
|
|
130
|
-
(
|
|
131
|
-
|
|
132
|
-
this.connections = connections;
|
|
165
|
+
const connections = await this.#pool.getPeers(peers, this.server,
|
|
166
|
+
// Init receipt of peers weights. this awaits the peer's
|
|
167
|
+
// weight update and adds it to our aggregator upon reception
|
|
168
|
+
(conn) => this.receivePayloads(conn));
|
|
169
|
+
debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
|
|
170
|
+
this.#connections = connections;
|
|
133
171
|
}
|
|
134
172
|
catch (e) {
|
|
135
|
-
debug(`Error for [${this.ownId}] while beginning round: %o`, e);
|
|
173
|
+
debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e);
|
|
136
174
|
this.aggregator.setNodes(Set(this.ownId));
|
|
137
|
-
this
|
|
175
|
+
this.#connections = Map();
|
|
138
176
|
}
|
|
139
|
-
// Store the promise for the current round's aggregation result.
|
|
140
|
-
// We will await for it to resolve at the end of the round when exchanging weight updates.
|
|
141
|
-
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
142
|
-
this.emit("status", "Training the model on the data you connected");
|
|
143
177
|
}
|
|
144
178
|
/**
|
|
145
179
|
* At each communication round, awaits peers' contributions and adds them to the client's aggregator.
|
|
@@ -147,66 +181,71 @@ export class DecentralizedClient extends Client {
|
|
|
147
181
|
* @param connections
|
|
148
182
|
* @param round
|
|
149
183
|
*/
|
|
150
|
-
receivePayloads(connections
|
|
184
|
+
receivePayloads(connections) {
|
|
151
185
|
connections.forEach(async (connection, peerId) => {
|
|
152
|
-
let currentCommunicationRounds = 0;
|
|
153
186
|
debug(`waiting for peer ${peerId}`);
|
|
154
|
-
|
|
187
|
+
for (let r = 0; r < this.aggregator.communicationRounds; r++) {
|
|
155
188
|
try {
|
|
156
189
|
const message = await waitMessageWithTimeout(connection, type.Payload, 60_000, "Timeout waiting for a contribution from peer " + peerId);
|
|
157
190
|
const decoded = serialization.weights.decode(message.payload);
|
|
158
|
-
if (!this.aggregator.
|
|
159
|
-
debug(`[${this.ownId}] failed to add contribution from peer ${peerId}`);
|
|
191
|
+
if (!this.aggregator.isValidContribution(peerId, message.aggregationRound)) {
|
|
192
|
+
debug(`[${shortenId(this.ownId)}] failed to add contribution from peer ${shortenId(peerId)}`);
|
|
193
|
+
}
|
|
194
|
+
else {
|
|
195
|
+
debug(`[${shortenId(this.ownId)}] received payload from peer ${shortenId(peerId)}` +
|
|
196
|
+
` for round (%d, %d)`, message.aggregationRound, message.communicationRound);
|
|
197
|
+
this.aggregator.once("aggregation", () => debug(`[${shortenId(this.ownId)}] aggregated the model` +
|
|
198
|
+
` for round (%d, %d)`, message.aggregationRound, message.communicationRound));
|
|
199
|
+
this.aggregator.add(peerId, decoded, message.aggregationRound, message.communicationRound);
|
|
160
200
|
}
|
|
161
201
|
}
|
|
162
202
|
catch (e) {
|
|
163
203
|
if (this.isDisconnected)
|
|
164
204
|
return;
|
|
165
|
-
debug(`Error for [${this.ownId}] while receiving payloads: %o`, e);
|
|
205
|
+
debug(`Error for [${shortenId(this.ownId)}] while receiving payloads: %o`, e);
|
|
166
206
|
}
|
|
167
|
-
}
|
|
207
|
+
}
|
|
168
208
|
});
|
|
169
209
|
}
|
|
170
|
-
async
|
|
210
|
+
async exchangeWeightUpdates(weights) {
|
|
171
211
|
if (this.aggregationResult === undefined) {
|
|
172
212
|
throw new TypeError('aggregation result promise is undefined');
|
|
173
213
|
}
|
|
174
|
-
this.
|
|
214
|
+
this.saveAndEmit("updating model");
|
|
175
215
|
// Perform the required communication rounds. Each communication round consists in sending our local payload,
|
|
176
216
|
// followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator.
|
|
177
217
|
// A communication round's payload is the aggregation result of the previous communication round. The first
|
|
178
218
|
// communication round simply sends our training result, i.e. model weights updates. This scheme allows for
|
|
179
219
|
// the aggregator to define any complex multi-round aggregation mechanism.
|
|
180
220
|
let result = weights;
|
|
181
|
-
for (let
|
|
221
|
+
for (let communicationRound = 0; communicationRound < this.aggregator.communicationRounds; communicationRound++) {
|
|
222
|
+
const connections = this.#connections;
|
|
223
|
+
if (connections === undefined)
|
|
224
|
+
throw new Error("peer's connections is undefined");
|
|
182
225
|
// Generate our payloads for this communication round and send them to all ready connected peers
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
}
|
|
190
|
-
else {
|
|
191
|
-
const peer = this.connections?.get(id);
|
|
192
|
-
if (peer !== undefined) {
|
|
193
|
-
const encoded = await serialization.weights.encode(payload);
|
|
194
|
-
const msg = {
|
|
195
|
-
type: type.Payload,
|
|
196
|
-
peer: id,
|
|
197
|
-
round: r,
|
|
198
|
-
payload: encoded
|
|
199
|
-
};
|
|
200
|
-
peer.send(msg);
|
|
201
|
-
debug(`[${this.ownId}] send weight update to peer ${msg.peer}: %O`, msg);
|
|
202
|
-
}
|
|
203
|
-
}
|
|
204
|
-
}));
|
|
226
|
+
const payloads = this.aggregator.makePayloads(result);
|
|
227
|
+
payloads.forEach(async (payload, id) => {
|
|
228
|
+
// add our own contribution to the aggregator
|
|
229
|
+
if (id === this.ownId) {
|
|
230
|
+
this.aggregator.add(this.ownId, payload, this.aggregator.round, communicationRound);
|
|
231
|
+
return;
|
|
205
232
|
}
|
|
206
|
-
|
|
207
|
-
|
|
233
|
+
// Send our payload to each peer
|
|
234
|
+
const peer = connections.get(id);
|
|
235
|
+
if (peer !== undefined) {
|
|
236
|
+
const encoded = await serialization.weights.encode(payload);
|
|
237
|
+
const msg = {
|
|
238
|
+
type: type.Payload,
|
|
239
|
+
peer: id,
|
|
240
|
+
aggregationRound: this.aggregator.round,
|
|
241
|
+
communicationRound,
|
|
242
|
+
payload: encoded
|
|
243
|
+
};
|
|
244
|
+
peer.send(msg);
|
|
245
|
+
debug(`[${shortenId(this.ownId)}] send weight update to peer ${shortenId(msg.peer)}` +
|
|
246
|
+
` for round (%d, %d)`, this.aggregator.round, communicationRound);
|
|
208
247
|
}
|
|
209
|
-
}
|
|
248
|
+
});
|
|
210
249
|
// Wait for aggregation before proceeding to the next communication round.
|
|
211
250
|
// The current result will be used as payload for the eventual next communication round.
|
|
212
251
|
try {
|
|
@@ -219,17 +258,15 @@ export class DecentralizedClient extends Client {
|
|
|
219
258
|
if (this.isDisconnected) {
|
|
220
259
|
return weights;
|
|
221
260
|
}
|
|
222
|
-
debug(`[${this.ownId}] while waiting for aggregation: %o`, e);
|
|
261
|
+
debug(`[${shortenId(this.ownId)}] while waiting for aggregation: %o`, e);
|
|
223
262
|
break;
|
|
224
263
|
}
|
|
225
264
|
// There is at least one communication round remaining
|
|
226
|
-
if (
|
|
265
|
+
if (communicationRound < this.aggregator.communicationRounds - 1) {
|
|
227
266
|
// Reuse the aggregation result
|
|
228
|
-
this.aggregationResult =
|
|
267
|
+
this.aggregationResult = this.aggregator.getPromiseForAggregation();
|
|
229
268
|
}
|
|
230
269
|
}
|
|
231
|
-
// Reset the peers list for the next round
|
|
232
|
-
this.aggregator.resetNodes();
|
|
233
270
|
return await this.aggregationResult;
|
|
234
271
|
}
|
|
235
272
|
}
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import { weights } from '../../serialization/index.js';
|
|
2
2
|
import { type SignalData } from './peer.js';
|
|
3
3
|
import { type NodeID } from '../types.js';
|
|
4
|
-
import { type
|
|
4
|
+
import { type } from '../messages.js';
|
|
5
|
+
import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js';
|
|
5
6
|
export interface NewDecentralizedNodeInfo {
|
|
6
7
|
type: type.NewDecentralizedNodeInfo;
|
|
7
8
|
id: NodeID;
|
|
@@ -12,21 +13,26 @@ export interface SignalForPeer {
|
|
|
12
13
|
peer: NodeID;
|
|
13
14
|
signal: SignalData;
|
|
14
15
|
}
|
|
16
|
+
export interface JoinRound {
|
|
17
|
+
type: type.JoinRound;
|
|
18
|
+
}
|
|
15
19
|
export interface PeerIsReady {
|
|
16
20
|
type: type.PeerIsReady;
|
|
17
21
|
}
|
|
18
22
|
export interface PeersForRound {
|
|
19
23
|
type: type.PeersForRound;
|
|
20
24
|
peers: NodeID[];
|
|
25
|
+
aggregationRound: number;
|
|
21
26
|
}
|
|
22
27
|
export interface Payload {
|
|
23
28
|
type: type.Payload;
|
|
24
29
|
peer: NodeID;
|
|
25
|
-
|
|
30
|
+
aggregationRound: number;
|
|
31
|
+
communicationRound: number;
|
|
26
32
|
payload: weights.Encoded;
|
|
27
33
|
}
|
|
28
|
-
export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | PeersForRound;
|
|
29
|
-
export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady;
|
|
34
|
+
export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | PeersForRound | WaitingForMoreParticipants | EnoughParticipants;
|
|
35
|
+
export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | JoinRound;
|
|
30
36
|
export type PeerMessage = Payload;
|
|
31
37
|
export declare function isMessageFromServer(o: unknown): o is MessageFromServer;
|
|
32
38
|
export declare function isMessageToServer(o: unknown): o is MessageToServer;
|
|
@@ -2,9 +2,8 @@ import { weights } from '../../serialization/index.js';
|
|
|
2
2
|
import { isNodeID } from '../types.js';
|
|
3
3
|
import { type, hasMessageType } from '../messages.js';
|
|
4
4
|
export function isMessageFromServer(o) {
|
|
5
|
-
if (!hasMessageType(o))
|
|
5
|
+
if (!hasMessageType(o))
|
|
6
6
|
return false;
|
|
7
|
-
}
|
|
8
7
|
switch (o.type) {
|
|
9
8
|
case type.NewDecentralizedNodeInfo:
|
|
10
9
|
return 'id' in o && isNodeID(o.id) &&
|
|
@@ -15,28 +14,30 @@ export function isMessageFromServer(o) {
|
|
|
15
14
|
'signal' in o; // TODO check signal content?
|
|
16
15
|
case type.PeersForRound:
|
|
17
16
|
return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID);
|
|
17
|
+
case type.WaitingForMoreParticipants:
|
|
18
|
+
case type.EnoughParticipants:
|
|
19
|
+
return true;
|
|
18
20
|
}
|
|
19
21
|
return false;
|
|
20
22
|
}
|
|
21
23
|
export function isMessageToServer(o) {
|
|
22
|
-
if (!hasMessageType(o))
|
|
24
|
+
if (!hasMessageType(o))
|
|
23
25
|
return false;
|
|
24
|
-
}
|
|
25
26
|
switch (o.type) {
|
|
26
27
|
case type.ClientConnected:
|
|
27
28
|
return true;
|
|
28
29
|
case type.SignalForPeer:
|
|
29
30
|
return 'peer' in o && isNodeID(o.peer) &&
|
|
30
31
|
'signal' in o; // TODO check signal content?
|
|
32
|
+
case type.JoinRound:
|
|
31
33
|
case type.PeerIsReady:
|
|
32
34
|
return true;
|
|
33
35
|
}
|
|
34
36
|
return false;
|
|
35
37
|
}
|
|
36
38
|
export function isPeerMessage(o) {
|
|
37
|
-
if (!hasMessageType(o))
|
|
39
|
+
if (!hasMessageType(o))
|
|
38
40
|
return false;
|
|
39
|
-
}
|
|
40
41
|
switch (o.type) {
|
|
41
42
|
case type.Payload:
|
|
42
43
|
return ('peer' in o && isNodeID(o.peer) &&
|
|
@@ -6,25 +6,13 @@ import { Client } from "../client.js";
|
|
|
6
6
|
*/
|
|
7
7
|
export declare class FederatedClient extends Client {
|
|
8
8
|
#private;
|
|
9
|
-
|
|
10
|
-
/**
|
|
11
|
-
* Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
12
|
-
*/
|
|
13
|
-
private connectServer;
|
|
9
|
+
getNbOfParticipants(): number;
|
|
14
10
|
/**
|
|
15
11
|
* Initializes the connection to the server, gets our node ID
|
|
16
12
|
* as well as the latest training information: latest global model, current round and
|
|
17
13
|
* whether we are waiting for more participants.
|
|
18
14
|
*/
|
|
19
15
|
connect(): Promise<Model>;
|
|
20
|
-
/**
|
|
21
|
-
* Method called when the server notifies the client that there aren't enough
|
|
22
|
-
* participants (anymore) to start/continue training
|
|
23
|
-
* The method creates a promise that will resolve once the server notifies
|
|
24
|
-
* the client that the training can resume via a subsequent EnoughParticipants message
|
|
25
|
-
* @returns a promise which resolves when enough participants joined the session
|
|
26
|
-
*/
|
|
27
|
-
private waitForMoreParticipants;
|
|
28
16
|
/**
|
|
29
17
|
* Disconnection process when user quits the task.
|
|
30
18
|
*/
|