@epfml/discojs 3.0.1-p20241007204240.0 → 3.0.1-p20241024094708.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
  2. package/dist/aggregator/{base.js → aggregator.js} +48 -36
  3. package/dist/aggregator/get.d.ts +2 -2
  4. package/dist/aggregator/get.js +4 -4
  5. package/dist/aggregator/index.d.ts +1 -4
  6. package/dist/aggregator/index.js +1 -1
  7. package/dist/aggregator/mean.d.ts +4 -4
  8. package/dist/aggregator/mean.js +5 -15
  9. package/dist/aggregator/secure.d.ts +4 -4
  10. package/dist/aggregator/secure.js +7 -17
  11. package/dist/client/client.d.ts +71 -17
  12. package/dist/client/client.js +118 -17
  13. package/dist/client/decentralized/decentralized_client.d.ts +11 -13
  14. package/dist/client/decentralized/decentralized_client.js +121 -84
  15. package/dist/client/decentralized/messages.d.ts +12 -6
  16. package/dist/client/decentralized/messages.js +9 -8
  17. package/dist/client/event_connection.js +2 -2
  18. package/dist/client/federated/federated_client.d.ts +1 -13
  19. package/dist/client/federated/federated_client.js +15 -94
  20. package/dist/client/federated/messages.d.ts +6 -11
  21. package/dist/client/local_client.d.ts +1 -0
  22. package/dist/client/local_client.js +3 -0
  23. package/dist/client/messages.d.ts +14 -7
  24. package/dist/client/messages.js +13 -11
  25. package/dist/default_tasks/cifar10.js +1 -1
  26. package/dist/default_tasks/lus_covid.js +1 -0
  27. package/dist/default_tasks/mnist.js +1 -1
  28. package/dist/default_tasks/simple_face.js +1 -0
  29. package/dist/default_tasks/titanic.js +1 -0
  30. package/dist/default_tasks/wikitext.js +1 -0
  31. package/dist/index.d.ts +0 -2
  32. package/dist/serialization/coder.d.ts +4 -0
  33. package/dist/serialization/coder.js +51 -0
  34. package/dist/serialization/index.d.ts +2 -0
  35. package/dist/serialization/index.js +1 -0
  36. package/dist/serialization/model.d.ts +1 -2
  37. package/dist/serialization/model.js +9 -24
  38. package/dist/serialization/weights.d.ts +2 -3
  39. package/dist/serialization/weights.js +15 -26
  40. package/dist/task/task_handler.d.ts +5 -5
  41. package/dist/task/task_handler.js +21 -15
  42. package/dist/task/training_information.d.ts +1 -2
  43. package/dist/task/training_information.js +6 -8
  44. package/dist/training/disco.d.ts +4 -1
  45. package/dist/training/trainer.js +1 -1
  46. package/dist/utils/event_emitter.d.ts +3 -3
  47. package/dist/utils/event_emitter.js +10 -9
  48. package/package.json +2 -3
@@ -1,6 +1,8 @@
1
- import axios from 'axios';
1
+ import createDebug from "debug";
2
2
  import { serialization } from '../index.js';
3
3
  import { EventEmitter } from '../utils/event_emitter.js';
4
+ import { type } from "./messages.js";
5
+ const debug = createDebug("discojs:client");
4
6
  /**
5
7
  * Main, abstract, class representing a Disco client in a network, which handles
6
8
  * communication with other nodes, be it peers or a server.
@@ -9,18 +11,25 @@ export class Client extends EventEmitter {
9
11
  url;
10
12
  task;
11
13
  aggregator;
12
- /**
13
- * Own ID provided by the network's server.
14
- */
14
+ // Own ID provided by the network's server.
15
15
  _ownId;
16
+ // The network's server.
17
+ _server;
18
+ // The aggregator's result produced after aggregation.
19
+ aggregationResult;
16
20
  /**
17
- * The network's server.
21
+ * When the server notifies clients to pause and wait until more
22
+ * participants join, we rely on this promise to wait
23
+ * until the server signals that the training can resume
18
24
  */
19
- _server;
25
+ promiseForMoreParticipants = undefined;
20
26
  /**
21
- * The aggregator's result produced after aggregation.
27
+ * When the server notifies the client that they can resume training
28
+ * after waiting for more participants, we want to be able to display what
29
+ * we were doing before waiting (training locally or updating our model).
30
+ * We use this attribute to store the status to rollback to when we stop waiting
22
31
  */
23
- aggregationResult;
32
+ previousStatus;
24
33
  constructor(url, // The network server's URL to connect to
25
34
  task, // The client's corresponding task
26
35
  aggregator) {
@@ -41,6 +50,94 @@ export class Client extends EventEmitter {
41
50
  * Handles the disconnection process of the client from any sort of network server.
42
51
  */
43
52
  async disconnect() { }
53
+ /**
54
+ * Emits the round status specified. It also stores the status emitted such that
55
+ * if the server tells the client to wait for more participants, it can display
56
+ * the waiting status and once enough participants join, it can display the previous status again
57
+ */
58
+ saveAndEmit(status) {
59
+ this.previousStatus = status;
60
+ this.emit("status", status);
61
+ }
62
+ /**
63
+ * For both federated and decentralized clients, we listen to the server to tell
64
+ * us whether there are enough participants to train. If not, we pause until further notice.
65
+ * When a client connects to the server, the server answers with the session information (id,
66
+ * number of participants) and whether there are enough participants.
67
+ * When there are the server sends a new EnoughParticipant message to update the client.
68
+ *
69
+ * `setMessageInversionFlag` is used to address the following scenario:
70
+ * 1. Client 1 connect to the server
71
+ * 2. Server answers with message A containing "not enough participants"
72
+ * 3. Before A arrives a new client joins. There are enough participants now.
73
+ * 4. Server updates client 1 with message B saying "there are enough participants"
74
+ * 5. Due to network and message sizes message B can arrive before A.
75
+ * i.e. "there are enough participants" arrives before "not enough participants"
76
+ * ending up with client 1 thinking it needs to wait for more participants.
77
+ *
78
+ * To keep track of this message inversion, `setMessageInversionFlag`
79
+ * tells us whether a message inversion occurred (by setting a boolean to true)
80
+ *
81
+ * @param setMessageInversionFlag function flagging whether a message inversion occurred
82
+ * between a NewNodeInfo message and an EnoughParticipant message.
83
+ */
84
+ setupServerCallbacks(setMessageInversionFlag) {
85
+ // Setup an event callback if the server signals that we should
86
+ // wait for more participants
87
+ this.server.on(type.WaitingForMoreParticipants, () => {
88
+ if (this.promiseForMoreParticipants !== undefined)
89
+ throw new Error("Server sent multiple WaitingForMoreParticipants messages");
90
+ debug(`[${shortenId(this.ownId)}] received WaitingForMoreParticipants message from server`);
91
+ // Display the waiting status right away
92
+ this.emit("status", "not enough participants");
93
+ // Upon receiving a WaitingForMoreParticipants message,
94
+ // the client will await for this promise to resolve before sending its
95
+ // local weight update
96
+ this.promiseForMoreParticipants = this.createPromiseForMoreParticipants();
97
+ });
98
+ // As an example assume we need at least 2 participants to train,
99
+ // When two participants join almost at the same time, the server
100
+ // sends a NewNodeInfo with waitForMoreParticipants=true to the first participant
101
+ // and directly follows with an EnoughParticipants message when the 2nd participant joins
102
+ // However, the EnoughParticipants can arrive before the NewNodeInfo (which can be much bigger)
103
+ // so we check whether we received the EnoughParticipants before being assigned a node ID
104
+ this.server.once(type.EnoughParticipants, () => {
105
+ if (this._ownId === undefined) {
106
+ debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
107
+ setMessageInversionFlag();
108
+ }
109
+ });
110
+ }
111
+ /**
112
+ * Method called when the server notifies the client that there aren't enough
113
+ * participants (anymore) to start/continue training
114
+ * The method creates a promise that will resolve once the server notifies
115
+ * the client that the training can resume via a subsequent EnoughParticipants message
116
+ * @returns a promise which resolves when enough participants joined the session
117
+ */
118
+ async createPromiseForMoreParticipants() {
119
+ return new Promise((resolve) => {
120
+ // "once" is important because we can't resolve the same promise multiple times
121
+ this.server.once(type.EnoughParticipants, () => {
122
+ debug(`[${shortenId(this.ownId)}] received EnoughParticipants message from server`);
123
+ // Emit the last status emitted before waiting if defined
124
+ if (this.previousStatus !== undefined)
125
+ this.emit("status", this.previousStatus);
126
+ resolve();
127
+ });
128
+ });
129
+ }
130
+ async waitForParticipantsIfNeeded() {
131
+ // we check if we are waiting for more participants before sending our weight update
132
+ if (this.waitingForMoreParticipants) {
133
+ // wait for the promise to resolve, which takes as long as it takes for new participants to join
134
+ debug(`[${shortenId(this.ownId)}] is awaiting the promise for more participants`);
135
+ this.emit("status", "not enough participants");
136
+ await this.promiseForMoreParticipants;
137
+ // Make sure to set the promise back to undefined once resolved
138
+ this.promiseForMoreParticipants = undefined;
139
+ }
140
+ }
44
141
  /**
45
142
  * Fetches the latest model available on the network's server, for the adequate task.
46
143
  * @returns The latest model
@@ -51,15 +148,9 @@ export class Client extends EventEmitter {
51
148
  url.pathname += '/';
52
149
  }
53
150
  url.pathname += `tasks/${this.task.id}/model.json`;
54
- const response = await axios.get(url.href, { responseType: 'arraybuffer' });
55
- return await serialization.model.decode(new Uint8Array(response.data));
56
- }
57
- // Number of contributors to a collaborative session
58
- // If decentralized, it should be the number of peers
59
- // If federated, it should the number of participants excluding the server
60
- // If local it should be 1
61
- get nbOfParticipants() {
62
- return this.aggregator.nodes.size; // overriden by the federated client
151
+ const response = await fetch(url);
152
+ const encoded = new Uint8Array(await response.arrayBuffer());
153
+ return await serialization.model.decode(encoded);
63
154
  }
64
155
  get ownId() {
65
156
  if (this._ownId === undefined) {
@@ -73,4 +164,14 @@ export class Client extends EventEmitter {
73
164
  }
74
165
  return this._server;
75
166
  }
167
+ /**
168
+ * Whether the client should wait until more
169
+ * participants join the session, i.e. a promise has been created
170
+ */
171
+ get waitingForMoreParticipants() {
172
+ return this.promiseForMoreParticipants !== undefined;
173
+ }
174
+ }
175
+ export function shortenId(id) {
176
+ return id.slice(0, 4);
76
177
  }
@@ -1,5 +1,5 @@
1
1
  import type { Model, WeightsContainer } from "../../index.js";
2
- import { Client } from '../index.js';
2
+ import { Client } from '../client.js';
3
3
  /**
4
4
  * Represents a decentralized client in a network of peers. Peers coordinate each other with the
5
5
  * help of the network's server, yet only exchange payloads between each other. Communication
@@ -7,11 +7,8 @@ import { Client } from '../index.js';
7
7
  * WebRTC for Node.js.
8
8
  */
9
9
  export declare class DecentralizedClient extends Client {
10
- /**
11
- * The pool of peers to communicate with during the current training round.
12
- */
13
- private pool?;
14
- private connections?;
10
+ #private;
11
+ getNbOfParticipants(): number;
15
12
  private get isDisconnected();
16
13
  /**
17
14
  * Public method called by disco.ts when starting training. This method sends
@@ -22,12 +19,6 @@ export declare class DecentralizedClient extends Client {
22
19
  * peers network information.
23
20
  */
24
21
  connect(): Promise<Model>;
25
- /**
26
- * Create a WebSocket connection with the server
27
- * The client then waits for the server to forward it other client's network information.
28
- * Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
29
- */
30
- private connectServer;
31
22
  disconnect(): Promise<void>;
32
23
  /**
33
24
  * At the beginning of a round, each peer tells the server it is ready to proceed
@@ -38,6 +29,13 @@ export declare class DecentralizedClient extends Client {
38
29
  *
39
30
  */
40
31
  onRoundBeginCommunication(): Promise<void>;
32
+ onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
33
+ /**
34
+ * Signal to the server that we are ready to exchange weights.
35
+ * Once enough peers are ready, the server sends the list of peers for this round
36
+ * and the peers can establish peer-to-peer connections with each other.
37
+ */
38
+ private establishPeerConnections;
41
39
  /**
42
40
  * At each communication rounds, awaits peers contributions and add them to the client's aggregator.
43
41
  * This method is used as callback by getPeers when connecting to the rounds' peers
@@ -45,5 +43,5 @@ export declare class DecentralizedClient extends Client {
45
43
  * @param round
46
44
  */
47
45
  private receivePayloads;
48
- onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
46
+ private exchangeWeightUpdates;
49
47
  }
@@ -1,7 +1,7 @@
1
1
  import createDebug from "debug";
2
2
  import { Map, Set } from 'immutable';
3
3
  import { serialization } from "../../index.js";
4
- import { Client } from '../index.js';
4
+ import { Client, shortenId } from '../client.js';
5
5
  import { type } from '../messages.js';
6
6
  import { timeout } from '../utils.js';
7
7
  import { WebSocketServer, waitMessage, waitMessageWithTimeout } from '../event_connection.js';
@@ -18,8 +18,12 @@ export class DecentralizedClient extends Client {
18
18
  /**
19
19
  * The pool of peers to communicate with during the current training round.
20
20
  */
21
- pool;
22
- connections;
21
+ #pool;
22
+ #connections;
23
+ getNbOfParticipants() {
24
+ const nbOfParticipants = this.aggregator.nodes.size;
25
+ return nbOfParticipants === 0 ? 1 : nbOfParticipants;
26
+ }
23
27
  // Used to handle timeouts and promise resolving after calling disconnect
24
28
  get isDisconnected() {
25
29
  return this._server === undefined;
@@ -46,42 +50,48 @@ export class DecentralizedClient extends Client {
46
50
  throw new Error(`unknown protocol: ${this.url.protocol}`);
47
51
  }
48
52
  serverURL.pathname += `decentralized/${this.task.id}`;
49
- this._server = await this.connectServer(serverURL);
53
+ // Create a WebSocket connection with the server
54
+ // The client then waits for the server to forward it other client's network information.
55
+ // Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
56
+ this._server = await WebSocketServer.connect(serverURL, messages.isMessageFromServer, messages.isMessageToServer);
57
+ this.server.on(type.SignalForPeer, (event) => {
58
+ if (this.#pool === undefined)
59
+ throw new Error('received signal but peer pool is undefined');
60
+ // Create a WebRTC connection with the peer
61
+ this.#pool.signal(event.peer, event.signal);
62
+ });
63
+ // c.f. setupServerCallbacks doc for explanation
64
+ let receivedEnoughParticipants = false;
65
+ this.setupServerCallbacks(() => receivedEnoughParticipants = true);
50
66
  const msg = {
51
67
  type: type.ClientConnected
52
68
  };
53
69
  this.server.send(msg);
54
- const { id } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
55
- debug(`[${id}] assigned id generated by server`);
70
+ const { id, waitForMoreParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
71
+ // This should come right after receiving the message to make sure
72
+ // we don't miss a subsequent message from the server
73
+ // We check if the server is telling us to wait for more participants
74
+ // and we also check if a EnoughParticipant message ended up arriving
75
+ // before the NewNodeInfo
76
+ if (waitForMoreParticipants && !receivedEnoughParticipants) {
77
+ // Create a promise that resolves when enough participants join
78
+ // The client will await this promise before sending its local weight update
79
+ this.promiseForMoreParticipants = this.createPromiseForMoreParticipants();
80
+ }
81
+ debug(`[${shortenId(id)}] assigned id generated by server`);
56
82
  if (this._ownId !== undefined) {
57
83
  throw new Error('received id from server but was already received');
58
84
  }
59
85
  this._ownId = id;
60
- this.pool = new PeerPool(id);
86
+ this.#pool = new PeerPool(id);
61
87
  return model;
62
88
  }
63
- /**
64
- * Create a WebSocket connection with the server
65
- * The client then waits for the server to forward it other client's network information.
66
- * Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
67
- */
68
- async connectServer(url) {
69
- const server = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer);
70
- server.on(type.SignalForPeer, (event) => {
71
- if (this.pool === undefined) {
72
- throw new Error('received signal but peer pool is undefined');
73
- }
74
- // Create a WebRTC connection with the peer
75
- this.pool.signal(event.peer, event.signal);
76
- });
77
- return server;
78
- }
79
89
  async disconnect() {
80
90
  // Disconnect from peers
81
- await this.pool?.shutdown();
82
- this.pool = undefined;
83
- if (this.connections !== undefined) {
84
- const peers = this.connections.keySeq().toSet();
91
+ await this.#pool?.shutdown();
92
+ this.#pool = undefined;
93
+ if (this.#connections !== undefined) {
94
+ const peers = this.#connections.keySeq().toSet();
85
95
  this.aggregator.setNodes(this.aggregator.nodes.subtract(peers));
86
96
  }
87
97
  // Disconnect from server
@@ -99,13 +109,41 @@ export class DecentralizedClient extends Client {
99
109
  *
100
110
  */
101
111
  async onRoundBeginCommunication() {
112
+ // Notify the server we want to join the next round so that the server
113
+ // waits for us to be ready before sending the list of peers for the round
114
+ this.server.send({ type: type.JoinRound });
115
+ // Store the promise for the current round's aggregation result.
116
+ // We will await for it to resolve at the end of the round when exchanging weight updates.
117
+ this.aggregationResult = this.aggregator.getPromiseForAggregation();
118
+ this.saveAndEmit("local training");
119
+ return Promise.resolve();
120
+ }
121
+ async onRoundEndCommunication(weights) {
122
+ if (this.aggregationResult === undefined) {
123
+ throw new TypeError('aggregation result promise is undefined');
124
+ }
125
+ // Save the status in case participants leave and we switch to waiting for more participants
126
+ // Once enough new participants join we can display the previous status again
127
+ this.saveAndEmit("connecting to peers");
128
+ // First we check if we are waiting for more participants before sending our weight update
129
+ await this.waitForParticipantsIfNeeded();
130
+ // Create peer-to-peer connections with all peers for the round
131
+ await this.establishPeerConnections();
132
+ // Exchange weight updates with peers and return aggregated weights
133
+ return await this.exchangeWeightUpdates(weights);
134
+ }
135
+ /**
136
+ * Signal to the server that we are ready to exchange weights.
137
+ * Once enough peers are ready, the server sends the list of peers for this round
138
+ * and the peers can establish peer-to-peer connections with each other.
139
+ */
140
+ async establishPeerConnections() {
102
141
  if (this.server === undefined) {
103
142
  throw new Error("peer's server is undefined, make sure to call `client.connect()` first");
104
143
  }
105
- if (this.pool === undefined) {
144
+ if (this.#pool === undefined) {
106
145
  throw new Error('peer pool is undefined, make sure to call `client.connect()` first');
107
146
  }
108
- this.emit("status", "Retrieving peers' information");
109
147
  // Reset peers list at each round of training to make sure client works with an updated peers
110
148
  // list, maintained by the server. Adds any received weights to the aggregator.
111
149
  // Tell the server we are ready for the next round
@@ -113,33 +151,29 @@ export class DecentralizedClient extends Client {
113
151
  this.server.send(readyMessage);
114
152
  // Wait for the server to answer with the list of peers for the round
115
153
  try {
116
- debug(`[${this.ownId}] is waiting for peer list for round ${this.aggregator.round}`);
117
- const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound, undefined, "Timeout waiting for the round's peer list");
154
+ debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`);
155
+ const receivedMessage = await waitMessage(this.server, type.PeersForRound);
118
156
  const peers = Set(receivedMessage.peers);
119
157
  if (this.ownId !== undefined && peers.has(this.ownId)) {
120
158
  throw new Error('received peer list contains our own id');
121
159
  }
122
160
  // Store the list of peers for the current round including ourselves
123
161
  this.aggregator.setNodes(peers.add(this.ownId));
162
+ this.aggregator.setRound(receivedMessage.aggregationRound); // the server gives us the round number
124
163
  // Initiate peer to peer connections with each peer
125
164
  // When connected, create a promise waiting for each peer's round contribution
126
- const connections = await this.pool.getPeers(peers, this.server,
127
- // Init receipt of peers weights
128
- // this awaits the peer's weight update and adds it to
129
- // our aggregator upon reception
130
- (conn) => { this.receivePayloads(conn, this.aggregator.round); });
131
- debug(`[${this.ownId}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
132
- this.connections = connections;
165
+ const connections = await this.#pool.getPeers(peers, this.server,
166
+ // Init receipt of peers weights. this awaits the peer's
167
+ // weight update and adds it to our aggregator upon reception
168
+ (conn) => this.receivePayloads(conn));
169
+ debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
170
+ this.#connections = connections;
133
171
  }
134
172
  catch (e) {
135
- debug(`Error for [${this.ownId}] while beginning round: %o`, e);
173
+ debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e);
136
174
  this.aggregator.setNodes(Set(this.ownId));
137
- this.connections = Map();
175
+ this.#connections = Map();
138
176
  }
139
- // Store the promise for the current round's aggregation result.
140
- // We will await for it to resolve at the end of the round when exchanging weight updates.
141
- this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
142
- this.emit("status", "Training the model on the data you connected");
143
177
  }
144
178
  /**
145
179
  * At each communication rounds, awaits peers contributions and add them to the client's aggregator.
@@ -147,66 +181,71 @@ export class DecentralizedClient extends Client {
147
181
  * @param connections
148
182
  * @param round
149
183
  */
150
- receivePayloads(connections, round) {
184
+ receivePayloads(connections) {
151
185
  connections.forEach(async (connection, peerId) => {
152
- let currentCommunicationRounds = 0;
153
186
  debug(`waiting for peer ${peerId}`);
154
- do {
187
+ for (let r = 0; r < this.aggregator.communicationRounds; r++) {
155
188
  try {
156
189
  const message = await waitMessageWithTimeout(connection, type.Payload, 60_000, "Timeout waiting for a contribution from peer " + peerId);
157
190
  const decoded = serialization.weights.decode(message.payload);
158
- if (!this.aggregator.add(peerId, decoded, round, message.round)) {
159
- debug(`[${this.ownId}] failed to add contribution from peer ${peerId}`);
191
+ if (!this.aggregator.isValidContribution(peerId, message.aggregationRound)) {
192
+ debug(`[${shortenId(this.ownId)}] failed to add contribution from peer ${shortenId(peerId)}`);
193
+ }
194
+ else {
195
+ debug(`[${shortenId(this.ownId)}] received payload from peer ${shortenId(peerId)}` +
196
+ ` for round (%d, %d)`, message.aggregationRound, message.communicationRound);
197
+ this.aggregator.once("aggregation", () => debug(`[${shortenId(this.ownId)}] aggregated the model` +
198
+ ` for round (%d, %d)`, message.aggregationRound, message.communicationRound));
199
+ this.aggregator.add(peerId, decoded, message.aggregationRound, message.communicationRound);
160
200
  }
161
201
  }
162
202
  catch (e) {
163
203
  if (this.isDisconnected)
164
204
  return;
165
- debug(`Error for [${this.ownId}] while receiving payloads: %o`, e);
205
+ debug(`Error for [${shortenId(this.ownId)}] while receiving payloads: %o`, e);
166
206
  }
167
- } while (++currentCommunicationRounds < this.aggregator.communicationRounds);
207
+ }
168
208
  });
169
209
  }
170
- async onRoundEndCommunication(weights) {
210
+ async exchangeWeightUpdates(weights) {
171
211
  if (this.aggregationResult === undefined) {
172
212
  throw new TypeError('aggregation result promise is undefined');
173
213
  }
174
- this.emit("status", "Updating the model with other participants' models");
214
+ this.saveAndEmit("updating model");
175
215
  // Perform the required communication rounds. Each communication round consists in sending our local payload,
176
216
  // followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator.
177
217
  // A communication round's payload is the aggregation result of the previous communication round. The first
178
218
  // communication round simply sends our training result, i.e. model weights updates. This scheme allows for
179
219
  // the aggregator to define any complex multi-round aggregation mechanism.
180
220
  let result = weights;
181
- for (let r = 0; r < this.aggregator.communicationRounds; r++) {
221
+ for (let communicationRound = 0; communicationRound < this.aggregator.communicationRounds; communicationRound++) {
222
+ const connections = this.#connections;
223
+ if (connections === undefined)
224
+ throw new Error("peer's connections is undefined");
182
225
  // Generate our payloads for this communication round and send them to all ready connected peers
183
- if (this.connections !== undefined) {
184
- const payloads = this.aggregator.makePayloads(result);
185
- try {
186
- await Promise.all(payloads.map(async (payload, id) => {
187
- if (id === this.ownId) {
188
- this.aggregator.add(this.ownId, payload, this.aggregator.round, r);
189
- }
190
- else {
191
- const peer = this.connections?.get(id);
192
- if (peer !== undefined) {
193
- const encoded = await serialization.weights.encode(payload);
194
- const msg = {
195
- type: type.Payload,
196
- peer: id,
197
- round: r,
198
- payload: encoded
199
- };
200
- peer.send(msg);
201
- debug(`[${this.ownId}] send weight update to peer ${msg.peer}: %O`, msg);
202
- }
203
- }
204
- }));
226
+ const payloads = this.aggregator.makePayloads(result);
227
+ payloads.forEach(async (payload, id) => {
228
+ // add our own contribution to the aggregator
229
+ if (id === this.ownId) {
230
+ this.aggregator.add(this.ownId, payload, this.aggregator.round, communicationRound);
231
+ return;
205
232
  }
206
- catch (cause) {
207
- throw new Error('error while sending weights', { cause });
233
+ // Send our payload to each peer
234
+ const peer = connections.get(id);
235
+ if (peer !== undefined) {
236
+ const encoded = await serialization.weights.encode(payload);
237
+ const msg = {
238
+ type: type.Payload,
239
+ peer: id,
240
+ aggregationRound: this.aggregator.round,
241
+ communicationRound,
242
+ payload: encoded
243
+ };
244
+ peer.send(msg);
245
+ debug(`[${shortenId(this.ownId)}] send weight update to peer ${shortenId(msg.peer)}` +
246
+ ` for round (%d, %d)`, this.aggregator.round, communicationRound);
208
247
  }
209
- }
248
+ });
210
249
  // Wait for aggregation before proceeding to the next communication round.
211
250
  // The current result will be used as payload for the eventual next communication round.
212
251
  try {
@@ -219,17 +258,15 @@ export class DecentralizedClient extends Client {
219
258
  if (this.isDisconnected) {
220
259
  return weights;
221
260
  }
222
- debug(`[${this.ownId}] while waiting for aggregation: %o`, e);
261
+ debug(`[${shortenId(this.ownId)}] while waiting for aggregation: %o`, e);
223
262
  break;
224
263
  }
225
264
  // There is at least one communication round remaining
226
- if (r < this.aggregator.communicationRounds - 1) {
265
+ if (communicationRound < this.aggregator.communicationRounds - 1) {
227
266
  // Reuse the aggregation result
228
- this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
267
+ this.aggregationResult = this.aggregator.getPromiseForAggregation();
229
268
  }
230
269
  }
231
- // Reset the peers list for the next round
232
- this.aggregator.resetNodes();
233
270
  return await this.aggregationResult;
234
271
  }
235
272
  }
@@ -1,7 +1,8 @@
1
- import { weights } from '../../serialization/index.js';
1
+ import { serialization } from "../../index.js";
2
2
  import { type SignalData } from './peer.js';
3
3
  import { type NodeID } from '../types.js';
4
- import { type, type ClientConnected } from '../messages.js';
4
+ import { type } from '../messages.js';
5
+ import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js';
5
6
  export interface NewDecentralizedNodeInfo {
6
7
  type: type.NewDecentralizedNodeInfo;
7
8
  id: NodeID;
@@ -12,21 +13,26 @@ export interface SignalForPeer {
12
13
  peer: NodeID;
13
14
  signal: SignalData;
14
15
  }
16
+ export interface JoinRound {
17
+ type: type.JoinRound;
18
+ }
15
19
  export interface PeerIsReady {
16
20
  type: type.PeerIsReady;
17
21
  }
18
22
  export interface PeersForRound {
19
23
  type: type.PeersForRound;
20
24
  peers: NodeID[];
25
+ aggregationRound: number;
21
26
  }
22
27
  export interface Payload {
23
28
  type: type.Payload;
24
29
  peer: NodeID;
25
- round: number;
26
- payload: weights.Encoded;
30
+ aggregationRound: number;
31
+ communicationRound: number;
32
+ payload: serialization.Encoded;
27
33
  }
28
- export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | PeersForRound;
29
- export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady;
34
+ export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | PeersForRound | WaitingForMoreParticipants | EnoughParticipants;
35
+ export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | JoinRound;
30
36
  export type PeerMessage = Payload;
31
37
  export declare function isMessageFromServer(o: unknown): o is MessageFromServer;
32
38
  export declare function isMessageToServer(o: unknown): o is MessageToServer;
@@ -1,10 +1,9 @@
1
- import { weights } from '../../serialization/index.js';
1
+ import { serialization } from "../../index.js";
2
2
  import { isNodeID } from '../types.js';
3
3
  import { type, hasMessageType } from '../messages.js';
4
4
  export function isMessageFromServer(o) {
5
- if (!hasMessageType(o)) {
5
+ if (!hasMessageType(o))
6
6
  return false;
7
- }
8
7
  switch (o.type) {
9
8
  case type.NewDecentralizedNodeInfo:
10
9
  return 'id' in o && isNodeID(o.id) &&
@@ -15,32 +14,34 @@ export function isMessageFromServer(o) {
15
14
  'signal' in o; // TODO check signal content?
16
15
  case type.PeersForRound:
17
16
  return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID);
17
+ case type.WaitingForMoreParticipants:
18
+ case type.EnoughParticipants:
19
+ return true;
18
20
  }
19
21
  return false;
20
22
  }
21
23
  export function isMessageToServer(o) {
22
- if (!hasMessageType(o)) {
24
+ if (!hasMessageType(o))
23
25
  return false;
24
- }
25
26
  switch (o.type) {
26
27
  case type.ClientConnected:
27
28
  return true;
28
29
  case type.SignalForPeer:
29
30
  return 'peer' in o && isNodeID(o.peer) &&
30
31
  'signal' in o; // TODO check signal content?
32
+ case type.JoinRound:
31
33
  case type.PeerIsReady:
32
34
  return true;
33
35
  }
34
36
  return false;
35
37
  }
36
38
  export function isPeerMessage(o) {
37
- if (!hasMessageType(o)) {
39
+ if (!hasMessageType(o))
38
40
  return false;
39
- }
40
41
  switch (o.type) {
41
42
  case type.Payload:
42
43
  return ('peer' in o && isNodeID(o.peer) &&
43
- 'payload' in o && weights.isEncoded(o.payload));
44
+ 'payload' in o && serialization.isEncoded(o.payload));
44
45
  }
45
46
  return false;
46
47
  }
@@ -1,6 +1,6 @@
1
1
  import createDebug from "debug";
2
2
  import WebSocket from "isomorphic-ws";
3
- import msgpack from "msgpack-lite";
3
+ import * as msgpack from "@msgpack/msgpack";
4
4
  import * as decentralizedMessages from './decentralized/messages.js';
5
5
  import { type } from './messages.js';
6
6
  import { timeout } from './utils.js';
@@ -57,7 +57,7 @@ export class PeerConnection extends EventEmitter {
57
57
  if (!decentralizedMessages.isPeerMessage(msg)) {
58
58
  throw new Error(`can't send this type of message: ${JSON.stringify(msg)}`);
59
59
  }
60
- this.peer.send(msgpack.encode(msg));
60
+ this.peer.send(Buffer.from(msgpack.encode(msg)));
61
61
  }
62
62
  async disconnect() {
63
63
  await this.peer.destroy();