@epfml/discojs 3.0.1-p20241001093123.0 → 3.0.1-p20241014092014.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
  2. package/dist/aggregator/{base.js → aggregator.js} +48 -36
  3. package/dist/aggregator/get.d.ts +2 -2
  4. package/dist/aggregator/get.js +4 -4
  5. package/dist/aggregator/index.d.ts +1 -4
  6. package/dist/aggregator/index.js +1 -1
  7. package/dist/aggregator/mean.d.ts +4 -4
  8. package/dist/aggregator/mean.js +5 -15
  9. package/dist/aggregator/secure.d.ts +4 -4
  10. package/dist/aggregator/secure.js +7 -17
  11. package/dist/client/client.d.ts +71 -17
  12. package/dist/client/client.js +115 -14
  13. package/dist/client/decentralized/decentralized_client.d.ts +11 -13
  14. package/dist/client/decentralized/decentralized_client.js +121 -84
  15. package/dist/client/decentralized/messages.d.ts +10 -4
  16. package/dist/client/decentralized/messages.js +7 -6
  17. package/dist/client/federated/federated_client.d.ts +1 -13
  18. package/dist/client/federated/federated_client.js +15 -94
  19. package/dist/client/federated/messages.d.ts +2 -7
  20. package/dist/client/local_client.d.ts +1 -0
  21. package/dist/client/local_client.js +3 -0
  22. package/dist/client/messages.d.ts +14 -7
  23. package/dist/client/messages.js +13 -11
  24. package/dist/default_tasks/cifar10.js +1 -1
  25. package/dist/default_tasks/lus_covid.js +1 -0
  26. package/dist/default_tasks/mnist.js +1 -1
  27. package/dist/default_tasks/simple_face.js +1 -0
  28. package/dist/default_tasks/titanic.js +1 -0
  29. package/dist/default_tasks/wikitext.js +1 -0
  30. package/dist/task/training_information.d.ts +1 -2
  31. package/dist/task/training_information.js +6 -8
  32. package/dist/training/disco.d.ts +4 -1
  33. package/dist/training/trainer.js +1 -1
  34. package/dist/utils/event_emitter.d.ts +3 -3
  35. package/dist/utils/event_emitter.js +10 -9
  36. package/package.json +1 -1
@@ -1,6 +1,9 @@
1
+ import createDebug from "debug";
1
2
  import axios from 'axios';
2
3
  import { serialization } from '../index.js';
3
4
  import { EventEmitter } from '../utils/event_emitter.js';
5
+ import { type } from "./messages.js";
6
+ const debug = createDebug("discojs:client");
4
7
  /**
5
8
  * Main, abstract, class representing a Disco client in a network, which handles
6
9
  * communication with other nodes, be it peers or a server.
@@ -9,18 +12,25 @@ export class Client extends EventEmitter {
9
12
  url;
10
13
  task;
11
14
  aggregator;
12
- /**
13
- * Own ID provided by the network's server.
14
- */
15
+ // Own ID provided by the network's server.
15
16
  _ownId;
17
+ // The network's server.
18
+ _server;
19
+ // The aggregator's result produced after aggregation.
20
+ aggregationResult;
16
21
  /**
17
- * The network's server.
22
+ * When the server notifies clients to pause and wait until more
23
+ * participants join, we rely on this promise to wait
24
+ * until the server signals that the training can resume
18
25
  */
19
- _server;
26
+ promiseForMoreParticipants = undefined;
20
27
  /**
21
- * The aggregator's result produced after aggregation.
28
+ * When the server notifies the client that they can resume training
29
+ * after waiting for more participants, we want to be able to display what
30
+ * we were doing before waiting (training locally or updating our model).
31
+ * We use this attribute to store the status to rollback to when we stop waiting
22
32
  */
23
- aggregationResult;
33
+ previousStatus;
24
34
  constructor(url, // The network server's URL to connect to
25
35
  task, // The client's corresponding task
26
36
  aggregator) {
@@ -41,6 +51,94 @@ export class Client extends EventEmitter {
41
51
  * Handles the disconnection process of the client from any sort of network server.
42
52
  */
43
53
  async disconnect() { }
54
+ /**
55
+ * Emits the round status specified. It also stores the status emitted such that
56
+ * if the server tells the client to wait for more participants, it can display
57
+ * the waiting status and once enough participants join, it can display the previous status again
58
+ */
59
+ saveAndEmit(status) {
60
+ this.previousStatus = status;
61
+ this.emit("status", status);
62
+ }
63
+ /**
64
+ * For both federated and decentralized clients, we listen to the server to tell
65
+ * us whether there are enough participants to train. If not, we pause until further notice.
66
+ * When a client connects to the server, the server answers with the session information (id,
67
+ * number of participants) and whether there are enough participants.
68
+ * When there are the server sends a new EnoughParticipant message to update the client.
69
+ *
70
+ * `setMessageInversionFlag` is used to address the following scenario:
71
+ * 1. Client 1 connect to the server
72
+ * 2. Server answers with message A containing "not enough participants"
73
+ * 3. Before A arrives a new client joins. There are enough participants now.
74
+ * 4. Server updates client 1 with message B saying "there are enough participants"
75
+ * 5. Due to network and message sizes message B can arrive before A.
76
+ * i.e. "there are enough participants" arrives before "not enough participants"
77
+ * ending up with client 1 thinking it needs to wait for more participants.
78
+ *
79
+ * To keep track of this message inversion, `setMessageInversionFlag`
80
+ * tells us whether a message inversion occurred (by setting a boolean to true)
81
+ *
82
+ * @param setMessageInversionFlag function flagging whether a message inversion occurred
83
+ * between a NewNodeInfo message and an EnoughParticipant message.
84
+ */
85
+ setupServerCallbacks(setMessageInversionFlag) {
86
+ // Setup an event callback if the server signals that we should
87
+ // wait for more participants
88
+ this.server.on(type.WaitingForMoreParticipants, () => {
89
+ if (this.promiseForMoreParticipants !== undefined)
90
+ throw new Error("Server sent multiple WaitingForMoreParticipants messages");
91
+ debug(`[${shortenId(this.ownId)}] received WaitingForMoreParticipants message from server`);
92
+ // Display the waiting status right away
93
+ this.emit("status", "not enough participants");
94
+ // Upon receiving a WaitingForMoreParticipants message,
95
+ // the client will await for this promise to resolve before sending its
96
+ // local weight update
97
+ this.promiseForMoreParticipants = this.createPromiseForMoreParticipants();
98
+ });
99
+ // As an example assume we need at least 2 participants to train,
100
+ // When two participants join almost at the same time, the server
101
+ // sends a NewNodeInfo with waitForMoreParticipants=true to the first participant
102
+ // and directly follows with an EnoughParticipants message when the 2nd participant joins
103
+ // However, the EnoughParticipants can arrive before the NewNodeInfo (which can be much bigger)
104
+ // so we check whether we received the EnoughParticipants before being assigned a node ID
105
+ this.server.once(type.EnoughParticipants, () => {
106
+ if (this._ownId === undefined) {
107
+ debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
108
+ setMessageInversionFlag();
109
+ }
110
+ });
111
+ }
112
+ /**
113
+ * Method called when the server notifies the client that there aren't enough
114
+ * participants (anymore) to start/continue training
115
+ * The method creates a promise that will resolve once the server notifies
116
+ * the client that the training can resume via a subsequent EnoughParticipants message
117
+ * @returns a promise which resolves when enough participants joined the session
118
+ */
119
+ async createPromiseForMoreParticipants() {
120
+ return new Promise((resolve) => {
121
+ // "once" is important because we can't resolve the same promise multiple times
122
+ this.server.once(type.EnoughParticipants, () => {
123
+ debug(`[${shortenId(this.ownId)}] received EnoughParticipants message from server`);
124
+ // Emit the last status emitted before waiting if defined
125
+ if (this.previousStatus !== undefined)
126
+ this.emit("status", this.previousStatus);
127
+ resolve();
128
+ });
129
+ });
130
+ }
131
+ async waitForParticipantsIfNeeded() {
132
+ // we check if we are waiting for more participants before sending our weight update
133
+ if (this.waitingForMoreParticipants) {
134
+ // wait for the promise to resolve, which takes as long as it takes for new participants to join
135
+ debug(`[${shortenId(this.ownId)}] is awaiting the promise for more participants`);
136
+ this.emit("status", "not enough participants");
137
+ await this.promiseForMoreParticipants;
138
+ // Make sure to set the promise back to undefined once resolved
139
+ this.promiseForMoreParticipants = undefined;
140
+ }
141
+ }
44
142
  /**
45
143
  * Fetches the latest model available on the network's server, for the adequate task.
46
144
  * @returns The latest model
@@ -54,13 +152,6 @@ export class Client extends EventEmitter {
54
152
  const response = await axios.get(url.href, { responseType: 'arraybuffer' });
55
153
  return await serialization.model.decode(new Uint8Array(response.data));
56
154
  }
57
- // Number of contributors to a collaborative session
58
- // If decentralized, it should be the number of peers
59
- // If federated, it should the number of participants excluding the server
60
- // If local it should be 1
61
- get nbOfParticipants() {
62
- return this.aggregator.nodes.size; // overriden by the federated client
63
- }
64
155
  get ownId() {
65
156
  if (this._ownId === undefined) {
66
157
  throw new Error('the node is not connected');
@@ -73,4 +164,14 @@ export class Client extends EventEmitter {
73
164
  }
74
165
  return this._server;
75
166
  }
167
+ /**
168
+ * Whether the client should wait until more
169
+ * participants join the session, i.e. a promise has been created
170
+ */
171
+ get waitingForMoreParticipants() {
172
+ return this.promiseForMoreParticipants !== undefined;
173
+ }
174
+ }
175
+ export function shortenId(id) {
176
+ return id.slice(0, 4);
76
177
  }
@@ -1,5 +1,5 @@
1
1
  import type { Model, WeightsContainer } from "../../index.js";
2
- import { Client } from '../index.js';
2
+ import { Client } from '../client.js';
3
3
  /**
4
4
  * Represents a decentralized client in a network of peers. Peers coordinate each other with the
5
5
  * help of the network's server, yet only exchange payloads between each other. Communication
@@ -7,11 +7,8 @@ import { Client } from '../index.js';
7
7
  * WebRTC for Node.js.
8
8
  */
9
9
  export declare class DecentralizedClient extends Client {
10
- /**
11
- * The pool of peers to communicate with during the current training round.
12
- */
13
- private pool?;
14
- private connections?;
10
+ #private;
11
+ getNbOfParticipants(): number;
15
12
  private get isDisconnected();
16
13
  /**
17
14
  * Public method called by disco.ts when starting training. This method sends
@@ -22,12 +19,6 @@ export declare class DecentralizedClient extends Client {
22
19
  * peers network information.
23
20
  */
24
21
  connect(): Promise<Model>;
25
- /**
26
- * Create a WebSocket connection with the server
27
- * The client then waits for the server to forward it other client's network information.
28
- * Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
29
- */
30
- private connectServer;
31
22
  disconnect(): Promise<void>;
32
23
  /**
33
24
  * At the beginning of a round, each peer tells the server it is ready to proceed
@@ -38,6 +29,13 @@ export declare class DecentralizedClient extends Client {
38
29
  *
39
30
  */
40
31
  onRoundBeginCommunication(): Promise<void>;
32
+ onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
33
+ /**
34
+ * Signal to the server that we are ready to exchange weights.
35
+ * Once enough peers are ready, the server sends the list of peers for this round
36
+ * and the peers can establish peer-to-peer connections with each other.
37
+ */
38
+ private establishPeerConnections;
41
39
  /**
42
40
  * At each communication rounds, awaits peers contributions and add them to the client's aggregator.
43
41
  * This method is used as callback by getPeers when connecting to the rounds' peers
@@ -45,5 +43,5 @@ export declare class DecentralizedClient extends Client {
45
43
  * @param round
46
44
  */
47
45
  private receivePayloads;
48
- onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
46
+ private exchangeWeightUpdates;
49
47
  }
@@ -1,7 +1,7 @@
1
1
  import createDebug from "debug";
2
2
  import { Map, Set } from 'immutable';
3
3
  import { serialization } from "../../index.js";
4
- import { Client } from '../index.js';
4
+ import { Client, shortenId } from '../client.js';
5
5
  import { type } from '../messages.js';
6
6
  import { timeout } from '../utils.js';
7
7
  import { WebSocketServer, waitMessage, waitMessageWithTimeout } from '../event_connection.js';
@@ -18,8 +18,12 @@ export class DecentralizedClient extends Client {
18
18
  /**
19
19
  * The pool of peers to communicate with during the current training round.
20
20
  */
21
- pool;
22
- connections;
21
+ #pool;
22
+ #connections;
23
+ getNbOfParticipants() {
24
+ const nbOfParticipants = this.aggregator.nodes.size;
25
+ return nbOfParticipants === 0 ? 1 : nbOfParticipants;
26
+ }
23
27
  // Used to handle timeouts and promise resolving after calling disconnect
24
28
  get isDisconnected() {
25
29
  return this._server === undefined;
@@ -46,42 +50,48 @@ export class DecentralizedClient extends Client {
46
50
  throw new Error(`unknown protocol: ${this.url.protocol}`);
47
51
  }
48
52
  serverURL.pathname += `decentralized/${this.task.id}`;
49
- this._server = await this.connectServer(serverURL);
53
+ // Create a WebSocket connection with the server
54
+ // The client then waits for the server to forward it other client's network information.
55
+ // Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
56
+ this._server = await WebSocketServer.connect(serverURL, messages.isMessageFromServer, messages.isMessageToServer);
57
+ this.server.on(type.SignalForPeer, (event) => {
58
+ if (this.#pool === undefined)
59
+ throw new Error('received signal but peer pool is undefined');
60
+ // Create a WebRTC connection with the peer
61
+ this.#pool.signal(event.peer, event.signal);
62
+ });
63
+ // c.f. setupServerCallbacks doc for explanation
64
+ let receivedEnoughParticipants = false;
65
+ this.setupServerCallbacks(() => receivedEnoughParticipants = true);
50
66
  const msg = {
51
67
  type: type.ClientConnected
52
68
  };
53
69
  this.server.send(msg);
54
- const { id } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
55
- debug(`[${id}] assigned id generated by server`);
70
+ const { id, waitForMoreParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
71
+ // This should come right after receiving the message to make sure
72
+ // we don't miss a subsequent message from the server
73
+ // We check if the server is telling us to wait for more participants
74
+ // and we also check if a EnoughParticipant message ended up arriving
75
+ // before the NewNodeInfo
76
+ if (waitForMoreParticipants && !receivedEnoughParticipants) {
77
+ // Create a promise that resolves when enough participants join
78
+ // The client will await this promise before sending its local weight update
79
+ this.promiseForMoreParticipants = this.createPromiseForMoreParticipants();
80
+ }
81
+ debug(`[${shortenId(id)}] assigned id generated by server`);
56
82
  if (this._ownId !== undefined) {
57
83
  throw new Error('received id from server but was already received');
58
84
  }
59
85
  this._ownId = id;
60
- this.pool = new PeerPool(id);
86
+ this.#pool = new PeerPool(id);
61
87
  return model;
62
88
  }
63
- /**
64
- * Create a WebSocket connection with the server
65
- * The client then waits for the server to forward it other client's network information.
66
- * Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
67
- */
68
- async connectServer(url) {
69
- const server = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer);
70
- server.on(type.SignalForPeer, (event) => {
71
- if (this.pool === undefined) {
72
- throw new Error('received signal but peer pool is undefined');
73
- }
74
- // Create a WebRTC connection with the peer
75
- this.pool.signal(event.peer, event.signal);
76
- });
77
- return server;
78
- }
79
89
  async disconnect() {
80
90
  // Disconnect from peers
81
- await this.pool?.shutdown();
82
- this.pool = undefined;
83
- if (this.connections !== undefined) {
84
- const peers = this.connections.keySeq().toSet();
91
+ await this.#pool?.shutdown();
92
+ this.#pool = undefined;
93
+ if (this.#connections !== undefined) {
94
+ const peers = this.#connections.keySeq().toSet();
85
95
  this.aggregator.setNodes(this.aggregator.nodes.subtract(peers));
86
96
  }
87
97
  // Disconnect from server
@@ -99,13 +109,41 @@ export class DecentralizedClient extends Client {
99
109
  *
100
110
  */
101
111
  async onRoundBeginCommunication() {
112
+ // Notify the server we want to join the next round so that the server
113
+ // waits for us to be ready before sending the list of peers for the round
114
+ this.server.send({ type: type.JoinRound });
115
+ // Store the promise for the current round's aggregation result.
116
+ // We will await for it to resolve at the end of the round when exchanging weight updates.
117
+ this.aggregationResult = this.aggregator.getPromiseForAggregation();
118
+ this.saveAndEmit("local training");
119
+ return Promise.resolve();
120
+ }
121
+ async onRoundEndCommunication(weights) {
122
+ if (this.aggregationResult === undefined) {
123
+ throw new TypeError('aggregation result promise is undefined');
124
+ }
125
+ // Save the status in case participants leave and we switch to waiting for more participants
126
+ // Once enough new participants join we can display the previous status again
127
+ this.saveAndEmit("connecting to peers");
128
+ // First we check if we are waiting for more participants before sending our weight update
129
+ await this.waitForParticipantsIfNeeded();
130
+ // Create peer-to-peer connections with all peers for the round
131
+ await this.establishPeerConnections();
132
+ // Exchange weight updates with peers and return aggregated weights
133
+ return await this.exchangeWeightUpdates(weights);
134
+ }
135
+ /**
136
+ * Signal to the server that we are ready to exchange weights.
137
+ * Once enough peers are ready, the server sends the list of peers for this round
138
+ * and the peers can establish peer-to-peer connections with each other.
139
+ */
140
+ async establishPeerConnections() {
102
141
  if (this.server === undefined) {
103
142
  throw new Error("peer's server is undefined, make sure to call `client.connect()` first");
104
143
  }
105
- if (this.pool === undefined) {
144
+ if (this.#pool === undefined) {
106
145
  throw new Error('peer pool is undefined, make sure to call `client.connect()` first');
107
146
  }
108
- this.emit("status", "Retrieving peers' information");
109
147
  // Reset peers list at each round of training to make sure client works with an updated peers
110
148
  // list, maintained by the server. Adds any received weights to the aggregator.
111
149
  // Tell the server we are ready for the next round
@@ -113,33 +151,29 @@ export class DecentralizedClient extends Client {
113
151
  this.server.send(readyMessage);
114
152
  // Wait for the server to answer with the list of peers for the round
115
153
  try {
116
- debug(`[${this.ownId}] is waiting for peer list for round ${this.aggregator.round}`);
117
- const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound, undefined, "Timeout waiting for the round's peer list");
154
+ debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`);
155
+ const receivedMessage = await waitMessage(this.server, type.PeersForRound);
118
156
  const peers = Set(receivedMessage.peers);
119
157
  if (this.ownId !== undefined && peers.has(this.ownId)) {
120
158
  throw new Error('received peer list contains our own id');
121
159
  }
122
160
  // Store the list of peers for the current round including ourselves
123
161
  this.aggregator.setNodes(peers.add(this.ownId));
162
+ this.aggregator.setRound(receivedMessage.aggregationRound); // the server gives us the round number
124
163
  // Initiate peer to peer connections with each peer
125
164
  // When connected, create a promise waiting for each peer's round contribution
126
- const connections = await this.pool.getPeers(peers, this.server,
127
- // Init receipt of peers weights
128
- // this awaits the peer's weight update and adds it to
129
- // our aggregator upon reception
130
- (conn) => { this.receivePayloads(conn, this.aggregator.round); });
131
- debug(`[${this.ownId}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
132
- this.connections = connections;
165
+ const connections = await this.#pool.getPeers(peers, this.server,
166
+ // Init receipt of peers weights. this awaits the peer's
167
+ // weight update and adds it to our aggregator upon reception
168
+ (conn) => this.receivePayloads(conn));
169
+ debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
170
+ this.#connections = connections;
133
171
  }
134
172
  catch (e) {
135
- debug(`Error for [${this.ownId}] while beginning round: %o`, e);
173
+ debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e);
136
174
  this.aggregator.setNodes(Set(this.ownId));
137
- this.connections = Map();
175
+ this.#connections = Map();
138
176
  }
139
- // Store the promise for the current round's aggregation result.
140
- // We will await for it to resolve at the end of the round when exchanging weight updates.
141
- this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
142
- this.emit("status", "Training the model on the data you connected");
143
177
  }
144
178
  /**
145
179
  * At each communication rounds, awaits peers contributions and add them to the client's aggregator.
@@ -147,66 +181,71 @@ export class DecentralizedClient extends Client {
147
181
  * @param connections
148
182
  * @param round
149
183
  */
150
- receivePayloads(connections, round) {
184
+ receivePayloads(connections) {
151
185
  connections.forEach(async (connection, peerId) => {
152
- let currentCommunicationRounds = 0;
153
186
  debug(`waiting for peer ${peerId}`);
154
- do {
187
+ for (let r = 0; r < this.aggregator.communicationRounds; r++) {
155
188
  try {
156
189
  const message = await waitMessageWithTimeout(connection, type.Payload, 60_000, "Timeout waiting for a contribution from peer " + peerId);
157
190
  const decoded = serialization.weights.decode(message.payload);
158
- if (!this.aggregator.add(peerId, decoded, round, message.round)) {
159
- debug(`[${this.ownId}] failed to add contribution from peer ${peerId}`);
191
+ if (!this.aggregator.isValidContribution(peerId, message.aggregationRound)) {
192
+ debug(`[${shortenId(this.ownId)}] failed to add contribution from peer ${shortenId(peerId)}`);
193
+ }
194
+ else {
195
+ debug(`[${shortenId(this.ownId)}] received payload from peer ${shortenId(peerId)}` +
196
+ ` for round (%d, %d)`, message.aggregationRound, message.communicationRound);
197
+ this.aggregator.once("aggregation", () => debug(`[${shortenId(this.ownId)}] aggregated the model` +
198
+ ` for round (%d, %d)`, message.aggregationRound, message.communicationRound));
199
+ this.aggregator.add(peerId, decoded, message.aggregationRound, message.communicationRound);
160
200
  }
161
201
  }
162
202
  catch (e) {
163
203
  if (this.isDisconnected)
164
204
  return;
165
- debug(`Error for [${this.ownId}] while receiving payloads: %o`, e);
205
+ debug(`Error for [${shortenId(this.ownId)}] while receiving payloads: %o`, e);
166
206
  }
167
- } while (++currentCommunicationRounds < this.aggregator.communicationRounds);
207
+ }
168
208
  });
169
209
  }
170
- async onRoundEndCommunication(weights) {
210
+ async exchangeWeightUpdates(weights) {
171
211
  if (this.aggregationResult === undefined) {
172
212
  throw new TypeError('aggregation result promise is undefined');
173
213
  }
174
- this.emit("status", "Updating the model with other participants' models");
214
+ this.saveAndEmit("updating model");
175
215
  // Perform the required communication rounds. Each communication round consists in sending our local payload,
176
216
  // followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator.
177
217
  // A communication round's payload is the aggregation result of the previous communication round. The first
178
218
  // communication round simply sends our training result, i.e. model weights updates. This scheme allows for
179
219
  // the aggregator to define any complex multi-round aggregation mechanism.
180
220
  let result = weights;
181
- for (let r = 0; r < this.aggregator.communicationRounds; r++) {
221
+ for (let communicationRound = 0; communicationRound < this.aggregator.communicationRounds; communicationRound++) {
222
+ const connections = this.#connections;
223
+ if (connections === undefined)
224
+ throw new Error("peer's connections is undefined");
182
225
  // Generate our payloads for this communication round and send them to all ready connected peers
183
- if (this.connections !== undefined) {
184
- const payloads = this.aggregator.makePayloads(result);
185
- try {
186
- await Promise.all(payloads.map(async (payload, id) => {
187
- if (id === this.ownId) {
188
- this.aggregator.add(this.ownId, payload, this.aggregator.round, r);
189
- }
190
- else {
191
- const peer = this.connections?.get(id);
192
- if (peer !== undefined) {
193
- const encoded = await serialization.weights.encode(payload);
194
- const msg = {
195
- type: type.Payload,
196
- peer: id,
197
- round: r,
198
- payload: encoded
199
- };
200
- peer.send(msg);
201
- debug(`[${this.ownId}] send weight update to peer ${msg.peer}: %O`, msg);
202
- }
203
- }
204
- }));
226
+ const payloads = this.aggregator.makePayloads(result);
227
+ payloads.forEach(async (payload, id) => {
228
+ // add our own contribution to the aggregator
229
+ if (id === this.ownId) {
230
+ this.aggregator.add(this.ownId, payload, this.aggregator.round, communicationRound);
231
+ return;
205
232
  }
206
- catch (cause) {
207
- throw new Error('error while sending weights', { cause });
233
+ // Send our payload to each peer
234
+ const peer = connections.get(id);
235
+ if (peer !== undefined) {
236
+ const encoded = await serialization.weights.encode(payload);
237
+ const msg = {
238
+ type: type.Payload,
239
+ peer: id,
240
+ aggregationRound: this.aggregator.round,
241
+ communicationRound,
242
+ payload: encoded
243
+ };
244
+ peer.send(msg);
245
+ debug(`[${shortenId(this.ownId)}] send weight update to peer ${shortenId(msg.peer)}` +
246
+ ` for round (%d, %d)`, this.aggregator.round, communicationRound);
208
247
  }
209
- }
248
+ });
210
249
  // Wait for aggregation before proceeding to the next communication round.
211
250
  // The current result will be used as payload for the eventual next communication round.
212
251
  try {
@@ -219,17 +258,15 @@ export class DecentralizedClient extends Client {
219
258
  if (this.isDisconnected) {
220
259
  return weights;
221
260
  }
222
- debug(`[${this.ownId}] while waiting for aggregation: %o`, e);
261
+ debug(`[${shortenId(this.ownId)}] while waiting for aggregation: %o`, e);
223
262
  break;
224
263
  }
225
264
  // There is at least one communication round remaining
226
- if (r < this.aggregator.communicationRounds - 1) {
265
+ if (communicationRound < this.aggregator.communicationRounds - 1) {
227
266
  // Reuse the aggregation result
228
- this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
267
+ this.aggregationResult = this.aggregator.getPromiseForAggregation();
229
268
  }
230
269
  }
231
- // Reset the peers list for the next round
232
- this.aggregator.resetNodes();
233
270
  return await this.aggregationResult;
234
271
  }
235
272
  }
@@ -1,7 +1,8 @@
1
1
  import { weights } from '../../serialization/index.js';
2
2
  import { type SignalData } from './peer.js';
3
3
  import { type NodeID } from '../types.js';
4
- import { type, type ClientConnected } from '../messages.js';
4
+ import { type } from '../messages.js';
5
+ import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js';
5
6
  export interface NewDecentralizedNodeInfo {
6
7
  type: type.NewDecentralizedNodeInfo;
7
8
  id: NodeID;
@@ -12,21 +13,26 @@ export interface SignalForPeer {
12
13
  peer: NodeID;
13
14
  signal: SignalData;
14
15
  }
16
+ export interface JoinRound {
17
+ type: type.JoinRound;
18
+ }
15
19
  export interface PeerIsReady {
16
20
  type: type.PeerIsReady;
17
21
  }
18
22
  export interface PeersForRound {
19
23
  type: type.PeersForRound;
20
24
  peers: NodeID[];
25
+ aggregationRound: number;
21
26
  }
22
27
  export interface Payload {
23
28
  type: type.Payload;
24
29
  peer: NodeID;
25
- round: number;
30
+ aggregationRound: number;
31
+ communicationRound: number;
26
32
  payload: weights.Encoded;
27
33
  }
28
- export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | PeersForRound;
29
- export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady;
34
+ export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | PeersForRound | WaitingForMoreParticipants | EnoughParticipants;
35
+ export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | JoinRound;
30
36
  export type PeerMessage = Payload;
31
37
  export declare function isMessageFromServer(o: unknown): o is MessageFromServer;
32
38
  export declare function isMessageToServer(o: unknown): o is MessageToServer;
@@ -2,9 +2,8 @@ import { weights } from '../../serialization/index.js';
2
2
  import { isNodeID } from '../types.js';
3
3
  import { type, hasMessageType } from '../messages.js';
4
4
  export function isMessageFromServer(o) {
5
- if (!hasMessageType(o)) {
5
+ if (!hasMessageType(o))
6
6
  return false;
7
- }
8
7
  switch (o.type) {
9
8
  case type.NewDecentralizedNodeInfo:
10
9
  return 'id' in o && isNodeID(o.id) &&
@@ -15,28 +14,30 @@ export function isMessageFromServer(o) {
15
14
  'signal' in o; // TODO check signal content?
16
15
  case type.PeersForRound:
17
16
  return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID);
17
+ case type.WaitingForMoreParticipants:
18
+ case type.EnoughParticipants:
19
+ return true;
18
20
  }
19
21
  return false;
20
22
  }
21
23
  export function isMessageToServer(o) {
22
- if (!hasMessageType(o)) {
24
+ if (!hasMessageType(o))
23
25
  return false;
24
- }
25
26
  switch (o.type) {
26
27
  case type.ClientConnected:
27
28
  return true;
28
29
  case type.SignalForPeer:
29
30
  return 'peer' in o && isNodeID(o.peer) &&
30
31
  'signal' in o; // TODO check signal content?
32
+ case type.JoinRound:
31
33
  case type.PeerIsReady:
32
34
  return true;
33
35
  }
34
36
  return false;
35
37
  }
36
38
  export function isPeerMessage(o) {
37
- if (!hasMessageType(o)) {
39
+ if (!hasMessageType(o))
38
40
  return false;
39
- }
40
41
  switch (o.type) {
41
42
  case type.Payload:
42
43
  return ('peer' in o && isNodeID(o.peer) &&
@@ -6,25 +6,13 @@ import { Client } from "../client.js";
6
6
  */
7
7
  export declare class FederatedClient extends Client {
8
8
  #private;
9
- get nbOfParticipants(): number;
10
- /**
11
- * Opens a new WebSocket connection with the server and listens to new messages over the channel
12
- */
13
- private connectServer;
9
+ getNbOfParticipants(): number;
14
10
  /**
15
11
  * Initializes the connection to the server, gets our node ID
16
12
  * as well as the latest training information: latest global model, current round and
17
13
  * whether we are waiting for more participants.
18
14
  */
19
15
  connect(): Promise<Model>;
20
- /**
21
- * Method called when the server notifies the client that there aren't enough
22
- * participants (anymore) to start/continue training
23
- * The method creates a promise that will resolve once the server notifies
24
- * the client that the training can resume via a subsequent EnoughParticipants message
25
- * @returns a promise which resolves when enough participants joined the session
26
- */
27
- private waitForMoreParticipants;
28
16
  /**
29
17
  * Disconnection process when user quits the task.
30
18
  */