@epfml/discojs 2.2.2-p20240703101552.0 → 3.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. package/dist/aggregator/base.d.ts +9 -48
  2. package/dist/aggregator/base.js +8 -69
  3. package/dist/aggregator/get.d.ts +23 -11
  4. package/dist/aggregator/get.js +40 -23
  5. package/dist/aggregator/index.d.ts +1 -1
  6. package/dist/aggregator/index.js +1 -1
  7. package/dist/aggregator/mean.d.ts +25 -6
  8. package/dist/aggregator/mean.js +62 -17
  9. package/dist/aggregator/secure.d.ts +2 -2
  10. package/dist/aggregator/secure.js +4 -7
  11. package/dist/client/base.d.ts +3 -3
  12. package/dist/client/base.js +6 -8
  13. package/dist/client/decentralized/base.d.ts +27 -10
  14. package/dist/client/decentralized/base.js +123 -86
  15. package/dist/client/decentralized/peer.js +7 -12
  16. package/dist/client/decentralized/peer_pool.js +6 -2
  17. package/dist/client/event_connection.d.ts +1 -1
  18. package/dist/client/event_connection.js +3 -3
  19. package/dist/client/federated/base.d.ts +5 -21
  20. package/dist/client/federated/base.js +38 -61
  21. package/dist/client/federated/messages.d.ts +2 -10
  22. package/dist/client/federated/messages.js +0 -1
  23. package/dist/client/index.d.ts +1 -1
  24. package/dist/client/index.js +1 -1
  25. package/dist/client/local.d.ts +3 -1
  26. package/dist/client/local.js +4 -1
  27. package/dist/client/messages.d.ts +1 -2
  28. package/dist/client/messages.js +8 -3
  29. package/dist/client/utils.d.ts +4 -2
  30. package/dist/client/utils.js +18 -3
  31. package/dist/dataset/data/data.d.ts +1 -1
  32. package/dist/dataset/data/data.js +13 -2
  33. package/dist/dataset/data/preprocessing/image_preprocessing.js +6 -4
  34. package/dist/default_tasks/cifar10.js +1 -2
  35. package/dist/default_tasks/lus_covid.js +0 -5
  36. package/dist/default_tasks/mnist.js +15 -14
  37. package/dist/default_tasks/simple_face.js +0 -2
  38. package/dist/default_tasks/titanic.js +2 -4
  39. package/dist/default_tasks/wikitext.js +7 -1
  40. package/dist/index.d.ts +0 -1
  41. package/dist/index.js +0 -1
  42. package/dist/models/gpt/config.js +1 -1
  43. package/dist/privacy.d.ts +8 -10
  44. package/dist/privacy.js +25 -40
  45. package/dist/task/task_handler.js +10 -2
  46. package/dist/task/training_information.d.ts +7 -4
  47. package/dist/task/training_information.js +25 -6
  48. package/dist/training/disco.d.ts +30 -28
  49. package/dist/training/disco.js +75 -73
  50. package/dist/training/index.d.ts +1 -1
  51. package/dist/training/index.js +1 -0
  52. package/dist/training/trainer.d.ts +16 -0
  53. package/dist/training/trainer.js +72 -0
  54. package/dist/types.d.ts +0 -2
  55. package/dist/weights/weights_container.d.ts +0 -5
  56. package/dist/weights/weights_container.js +0 -7
  57. package/package.json +1 -1
  58. package/dist/async_informant.d.ts +0 -15
  59. package/dist/async_informant.js +0 -42
  60. package/dist/training/trainer/distributed_trainer.d.ts +0 -20
  61. package/dist/training/trainer/distributed_trainer.js +0 -41
  62. package/dist/training/trainer/local_trainer.d.ts +0 -12
  63. package/dist/training/trainer/local_trainer.js +0 -24
  64. package/dist/training/trainer/trainer.d.ts +0 -32
  65. package/dist/training/trainer/trainer.js +0 -61
  66. package/dist/training/trainer/trainer_builder.d.ts +0 -23
  67. package/dist/training/trainer/trainer_builder.js +0 -47
@@ -1,7 +1,5 @@
1
- import { type WeightsContainer } from '../../index.js';
1
+ import type { WeightsContainer } from "../../index.js";
2
2
  import { Client } from '../index.js';
3
- import { type PeerConnection } from '../event_connection.js';
4
- import * as messages from './messages.js';
5
3
  /**
6
4
  * Represents a decentralized client in a network of peers. Peers coordinate each other with the
7
5
  * help of the network's server, yet only exchange payloads between each other. Communication
@@ -14,19 +12,38 @@ export declare class Base extends Client {
14
12
  */
15
13
  private pool?;
16
14
  private connections?;
15
+ private get isDisconnected();
17
16
  /**
18
- * Send message to server that this client is ready for the next training round.
17
+ * Public method called by disco.ts when starting training. This method sends
18
+ * a message to the server asking to join the task and be assigned a client ID.
19
+ *
20
+ * The peer also establishes a WebSocket connection with the server to then
21
+ * create peer-to-peer WebRTC connections with peers. The server is used to exchange
22
+ * peers network information.
19
23
  */
20
- private waitForPeers;
21
- protected sendMessagetoPeer(peer: PeerConnection, msg: messages.PeerMessage): void;
24
+ connect(): Promise<void>;
22
25
  /**
23
- * Creation of the WebSocket for the server, connection of client to that WebSocket,
24
- * deals with message reception from the decentralized client's perspective (messages received by client).
26
+ * Create a WebSocket connection with the server
27
+ * The client then waits for the server to forward it other client's network information.
28
+ * Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
25
29
  */
26
30
  private connectServer;
27
- connect(): Promise<void>;
28
31
  disconnect(): Promise<void>;
32
+ /**
33
+ * At the beginning of a round, each peer tells the server it is ready to proceed
34
+ * The server answers with the list of all peers connected for the round
35
+ * Given the list, the peers then create peer-to-peer connections with each other.
36
+ * When connected, one peer creates a promise for every other peer's weight update
37
+ * and waits for it to resolve.
38
+ *
39
+ */
29
40
  onRoundBeginCommunication(_: WeightsContainer, round: number): Promise<void>;
30
- onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<void>;
41
+ /**
42
+ * At each communication rounds, awaits peers contributions and add them to the client's aggregator.
43
+ * This method is used as callback by getPeers when connecting to the rounds' peers
44
+ * @param connections
45
+ * @param round
46
+ */
31
47
  private receivePayloads;
48
+ onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
32
49
  }
@@ -1,5 +1,5 @@
1
1
  import { Map, Set } from 'immutable';
2
- import { serialization } from '../../index.js';
2
+ import { serialization } from "../../index.js";
3
3
  import { Client } from '../index.js';
4
4
  import { type } from '../messages.js';
5
5
  import { timeout } from '../utils.js';
@@ -18,63 +18,18 @@ export class Base extends Client {
18
18
  */
19
19
  pool;
20
20
  connections;
21
- /**
22
- * Send message to server that this client is ready for the next training round.
23
- */
24
- async waitForPeers(round) {
25
- console.info(`[${this.ownId}] is ready for round`, round);
26
- // Broadcast our readiness
27
- const readyMessage = { type: type.PeerIsReady };
28
- if (this.server === undefined) {
29
- throw new Error('server undefined, could not connect peers');
30
- }
31
- this.server.send(readyMessage);
32
- // Wait for peers to be connected before sending any update information
33
- try {
34
- const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound);
35
- if (this.nodes.size > 0) {
36
- throw new Error('got new peer list from server but was already received for this round');
37
- }
38
- const peers = Set(receivedMessage.peers);
39
- console.info(`[${this.ownId}] received peers for round:`, peers.toJS());
40
- if (this.ownId !== undefined && peers.has(this.ownId)) {
41
- throw new Error('received peer list contains our own id');
42
- }
43
- this.aggregator.setNodes(peers.add(this.ownId));
44
- if (this.pool === undefined) {
45
- throw new Error('waiting for peers but peer pool is undefined');
46
- }
47
- const connections = await this.pool.getPeers(peers, this.server,
48
- // Init receipt of peers weights
49
- (conn) => { this.receivePayloads(conn, round); });
50
- console.info(`[${this.ownId}] received peers for round ${round}:`, connections.keySeq().toJS());
51
- return connections;
52
- }
53
- catch (e) {
54
- console.error(e);
55
- this.aggregator.setNodes(Set(this.ownId));
56
- return Map();
57
- }
58
- }
59
- sendMessagetoPeer(peer, msg) {
60
- console.info(`[${this.ownId}] send message to peer`, msg.peer, msg);
61
- peer.send(msg);
21
+ // Used to handle timeouts and promise resolving after calling disconnect
22
+ get isDisconnected() {
23
+ return this._server === undefined;
62
24
  }
63
25
  /**
64
- * Creation of the WebSocket for the server, connection of client to that WebSocket,
65
- * deals with message reception from the decentralized client's perspective (messages received by client).
26
+ * Public method called by disco.ts when starting training. This method sends
27
+ * a message to the server asking to join the task and be assigned a client ID.
28
+ *
29
+ * The peer also establishes a WebSocket connection with the server to then
30
+ * create peer-to-peer WebRTC connections with peers. The server is used to exchange
31
+ * peers network information.
66
32
  */
67
- async connectServer(url) {
68
- const server = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer);
69
- server.on(type.SignalForPeer, (event) => {
70
- console.info(`[${this.ownId}] received signal from`, event.peer);
71
- if (this.pool === undefined) {
72
- throw new Error('received signal but peer pool is undefined');
73
- }
74
- this.pool.signal(event.peer, event.signal);
75
- });
76
- return server;
77
- }
78
33
  async connect() {
79
34
  const serverURL = new URL('', this.url.href);
80
35
  switch (this.url.protocol) {
@@ -94,13 +49,29 @@ export class Base extends Client {
94
49
  };
95
50
  this.server.send(msg);
96
51
  const peerIdMsg = await waitMessage(this.server, type.AssignNodeID);
97
- console.info(`[${peerIdMsg.id}] assigned id generated by server`);
52
+ console.log(`[${peerIdMsg.id}] assigned id generated by server`);
98
53
  if (this._ownId !== undefined) {
99
54
  throw new Error('received id from server but was already received');
100
55
  }
101
56
  this._ownId = peerIdMsg.id;
102
57
  this.pool = new PeerPool(peerIdMsg.id);
103
58
  }
59
+ /**
60
+ * Create a WebSocket connection with the server
61
+ * The client then waits for the server to forward it other client's network information.
62
+ * Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
63
+ */
64
+ async connectServer(url) {
65
+ const server = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer);
66
+ server.on(type.SignalForPeer, (event) => {
67
+ if (this.pool === undefined) {
68
+ throw new Error('received signal but peer pool is undefined');
69
+ }
70
+ // Create a WebRTC connection with the peer
71
+ this.pool.signal(event.peer, event.signal);
72
+ });
73
+ return server;
74
+ }
104
75
  async disconnect() {
105
76
  // Disconnect from peers
106
77
  await this.pool?.shutdown();
@@ -115,20 +86,92 @@ export class Base extends Client {
115
86
  this._ownId = undefined;
116
87
  return Promise.resolve();
117
88
  }
89
+ /**
90
+ * At the beginning of a round, each peer tells the server it is ready to proceed
91
+ * The server answers with the list of all peers connected for the round
92
+ * Given the list, the peers then create peer-to-peer connections with each other.
93
+ * When connected, one peer creates a promise for every other peer's weight update
94
+ * and waits for it to resolve.
95
+ *
96
+ */
118
97
  async onRoundBeginCommunication(_, round) {
98
+ if (this.server === undefined) {
99
+ throw new Error("peer's server is undefined, make sure to call `client.connect()` first");
100
+ }
101
+ if (this.pool === undefined) {
102
+ throw new Error('peer pool is undefined, make sure to call `client.connect()` first');
103
+ }
119
104
  // Reset peers list at each round of training to make sure client works with an updated peers
120
105
  // list, maintained by the server. Adds any received weights to the aggregator.
121
- this.connections = await this.waitForPeers(round);
106
+ // this.connections = await this.waitForPeers(round)
107
+ // Tell the server we are ready for the next round
108
+ const readyMessage = { type: type.PeerIsReady };
109
+ this.server.send(readyMessage);
110
+ // Wait for the server to answer with the list of peers for the round
111
+ try {
112
+ const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound, undefined, "Timeout waiting for the round's peer list");
113
+ const peers = Set(receivedMessage.peers);
114
+ if (this.ownId !== undefined && peers.has(this.ownId)) {
115
+ throw new Error('received peer list contains our own id');
116
+ }
117
+ // Store the list of peers for the current round including ourselves
118
+ this.aggregator.setNodes(peers.add(this.ownId));
119
+ // Initiate peer to peer connections with each peer
120
+ // When connected, create a promise waiting for each peer's round contribution
121
+ const connections = await this.pool.getPeers(peers, this.server,
122
+ // Init receipt of peers weights
123
+ // this awaits the peer's weight update and adds it to
124
+ // our aggregator upon reception
125
+ (conn) => { this.receivePayloads(conn, round); });
126
+ console.log(`[${this.ownId}] received peers for round ${round}:`, connections.keySeq().toJS());
127
+ this.connections = connections;
128
+ }
129
+ catch (e) {
130
+ console.error(e);
131
+ this.aggregator.setNodes(Set(this.ownId));
132
+ this.connections = Map();
133
+ }
122
134
  // Store the promise for the current round's aggregation result.
123
- this.aggregationResult = this.aggregator.receiveResult();
135
+ // We will await for it to resolve at the end of the round when exchanging weight updates.
136
+ this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
137
+ }
138
+ /**
139
+ * At each communication rounds, awaits peers contributions and add them to the client's aggregator.
140
+ * This method is used as callback by getPeers when connecting to the rounds' peers
141
+ * @param connections
142
+ * @param round
143
+ */
144
+ receivePayloads(connections, round) {
145
+ connections.forEach(async (connection, peerId) => {
146
+ let currentCommunicationRounds = 0;
147
+ console.log(`waiting for peer ${peerId}`);
148
+ do {
149
+ try {
150
+ const message = await waitMessageWithTimeout(connection, type.Payload, 60_000, "Timeout waiting for a contribution from peer " + peerId);
151
+ const decoded = serialization.weights.decode(message.payload);
152
+ if (!this.aggregator.add(peerId, decoded, round, message.round)) {
153
+ console.warn(`[${this.ownId}] Failed to add contribution from peer ${peerId}`);
154
+ }
155
+ }
156
+ catch (e) {
157
+ if (this.isDisconnected) {
158
+ return;
159
+ }
160
+ console.error(e instanceof Error ? e.message : e);
161
+ }
162
+ } while (++currentCommunicationRounds < this.aggregator.communicationRounds);
163
+ });
124
164
  }
125
165
  async onRoundEndCommunication(weights, round) {
126
- let result = weights;
166
+ if (this.aggregationResult === undefined) {
167
+ throw new TypeError('aggregation result promise is undefined');
168
+ }
127
169
  // Perform the required communication rounds. Each communication round consists in sending our local payload,
128
170
  // followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator.
129
171
  // A communication round's payload is the aggregation result of the previous communication round. The first
130
172
  // communication round simply sends our training result, i.e. model weights updates. This scheme allows for
131
173
  // the aggregator to define any complex multi-round aggregation mechanism.
174
+ let result = weights;
132
175
  for (let r = 0; r < this.aggregator.communicationRounds; r++) {
133
176
  // Generate our payloads for this communication round and send them to all ready connected peers
134
177
  if (this.connections !== undefined) {
@@ -139,15 +182,17 @@ export class Base extends Client {
139
182
  this.aggregator.add(this.ownId, payload, round, r);
140
183
  }
141
184
  else {
142
- const connection = this.connections?.get(id);
143
- if (connection !== undefined) {
185
+ const peer = this.connections?.get(id);
186
+ if (peer !== undefined) {
144
187
  const encoded = await serialization.weights.encode(payload);
145
- this.sendMessagetoPeer(connection, {
188
+ const msg = {
146
189
  type: type.Payload,
147
190
  peer: id,
148
191
  round: r,
149
192
  payload: encoded
150
- });
193
+ };
194
+ peer.send(msg);
195
+ console.log(`[${this.ownId}] send weight update to peer`, msg.peer, msg);
151
196
  }
152
197
  }
153
198
  }));
@@ -156,37 +201,29 @@ export class Base extends Client {
156
201
  throw new Error('error while sending weights');
157
202
  }
158
203
  }
159
- if (this.aggregationResult === undefined) {
160
- throw new TypeError('aggregation result promise is undefined');
161
- }
162
204
  // Wait for aggregation before proceeding to the next communication round.
163
205
  // The current result will be used as payload for the eventual next communication round.
164
- result = await Promise.race([this.aggregationResult, timeout()]);
206
+ try {
207
+ result = await Promise.race([
208
+ this.aggregationResult,
209
+ timeout(undefined, "Timeout waiting on the aggregation result promise to resolve")
210
+ ]);
211
+ }
212
+ catch (e) {
213
+ if (this.isDisconnected) {
214
+ return weights;
215
+ }
216
+ console.error(e);
217
+ break;
218
+ }
165
219
  // There is at least one communication round remaining
166
220
  if (r < this.aggregator.communicationRounds - 1) {
167
221
  // Reuse the aggregation result
168
- this.aggregationResult = this.aggregator.receiveResult();
222
+ this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
169
223
  }
170
224
  }
171
225
  // Reset the peers list for the next round
172
226
  this.aggregator.resetNodes();
173
- }
174
- receivePayloads(connections, round) {
175
- console.info(`[${this.ownId}] Accepting new contributions for round ${round}`);
176
- connections.forEach(async (connection, peerId) => {
177
- let receivedPayloads = 0;
178
- do {
179
- try {
180
- const message = await waitMessageWithTimeout(connection, type.Payload);
181
- const decoded = serialization.weights.decode(message.payload);
182
- if (!this.aggregator.add(peerId, decoded, round, message.round)) {
183
- console.warn(`[${this.ownId}] Failed to add contribution from peer ${peerId}`);
184
- }
185
- }
186
- catch (e) {
187
- console.warn(e instanceof Error ? e.message : e);
188
- }
189
- } while (++receivedPayloads < this.aggregator.communicationRounds);
190
- });
227
+ return await this.aggregationResult;
191
228
  }
192
229
  }
@@ -58,13 +58,7 @@ export class Peer {
58
58
  if (this.bufferSize === undefined) {
59
59
  throw new Error('chunk without known buffer size');
60
60
  }
61
- // in the perfect world of bug-free implementations
62
- // we would return this.bufferSize
63
- // sadly, we are not there yet
64
- //
65
- // based on MDN, taking 16K seems to be a pretty safe
66
- // and widely supported buffer size
67
- return 16 * (1 << 10);
61
+ return this.bufferSize;
68
62
  }
69
63
  chunk(b) {
70
64
  const messageID = this.sendCounter;
@@ -79,7 +73,8 @@ export class Peer {
79
73
  }
80
74
  const totalChunkCount = 1 + tail.count();
81
75
  if (totalChunkCount > 0xFF) {
82
- throw new Error('too big message to even chunk it');
76
+ throw new Error(`The payload is too big: ${totalChunkCount * this.maxChunkSize} bytes > 255,` +
77
+ ' consider reducing the model size or increasing the chunk size');
83
78
  }
84
79
  const firstChunk = Buffer.alloc((b.length > this.maxChunkSize - FIRST_HEADER_SIZE)
85
80
  ? this.maxChunkSize
@@ -106,7 +101,7 @@ export class Peer {
106
101
  });
107
102
  }
108
103
  signal(signal) {
109
- // extract max buffer size
104
+ // extract max buffer size from the signal
110
105
  if (signal.type === 'offer' || signal.type === 'answer') {
111
106
  if (signal.sdp === undefined) {
112
107
  throw new Error('signal answer|offer without session description');
@@ -138,8 +133,8 @@ export class Peer {
138
133
  if (!Buffer.isBuffer(data) || data.length < HEADER_SIZE) {
139
134
  throw new Error('received invalid message type');
140
135
  }
141
- const messageID = data.readUint16BE();
142
- const chunkID = data.readUint8(2);
136
+ const messageID = data.readUInt16BE(); //readUint16BE (case sensitive) fails at runtime
137
+ const chunkID = data.readUInt8(2); // same for readUint8
143
138
  const received = this.receiving.get(messageID, {
144
139
  total: undefined,
145
140
  chunks: Map()
@@ -161,7 +156,7 @@ export class Peer {
161
156
  if (total !== undefined) {
162
157
  throw new Error('first header received twice');
163
158
  }
164
- const readTotal = data.readUint8(3);
159
+ const readTotal = data.readUInt8(3);
165
160
  total = readTotal;
166
161
  chunk = Buffer.alloc(data.length - FIRST_HEADER_SIZE);
167
162
  data.copy(chunk, 0, FIRST_HEADER_SIZE);
@@ -9,8 +9,12 @@ export class PeerPool {
9
9
  this.id = id;
10
10
  }
11
11
  async shutdown() {
12
- console.info(`[${this.id}] shutdown their peers`);
13
- await Promise.all(this.peers.valueSeq().map((peer) => peer.disconnect()));
12
+ console.info(`[${this.id}] is shutting down all its connections`);
13
+ // Add a timeout o.w. the promise hangs forever if the other peer is already disconnected
14
+ await Promise.race([
15
+ Promise.all(this.peers.valueSeq().map((peer) => peer.disconnect())),
16
+ new Promise((res, _) => setTimeout(res, 1000)) // Wait for other peers to finish
17
+ ]);
14
18
  this.peers = Map();
15
19
  }
16
20
  signal(peerId, signal) {
@@ -9,7 +9,7 @@ export interface EventConnection {
9
9
  disconnect: () => Promise<void>;
10
10
  }
11
11
  export declare function waitMessage<T extends type>(connection: EventConnection, type: T): Promise<NarrowMessage<T>>;
12
- export declare function waitMessageWithTimeout<T extends type>(connection: EventConnection, type: T, timeoutMs?: number): Promise<NarrowMessage<T>>;
12
+ export declare function waitMessageWithTimeout<T extends type>(connection: EventConnection, type: T, timeoutMs?: number, errorMsg?: string): Promise<NarrowMessage<T>>;
13
13
  export declare class PeerConnection extends EventEmitter<{
14
14
  [K in type]: NarrowMessage<K>;
15
15
  }> implements EventConnection {
@@ -12,8 +12,8 @@ export async function waitMessage(connection, type) {
12
12
  });
13
13
  });
14
14
  }
15
- export async function waitMessageWithTimeout(connection, type, timeoutMs) {
16
- return await Promise.race([waitMessage(connection, type), timeout(timeoutMs)]);
15
+ export async function waitMessageWithTimeout(connection, type, timeoutMs, errorMsg = 'timeout') {
16
+ return await Promise.race([waitMessage(connection, type), timeout(timeoutMs, errorMsg)]);
17
17
  }
18
18
  export class PeerConnection extends EventEmitter {
19
19
  _ownId;
@@ -41,7 +41,7 @@ export class PeerConnection extends EventEmitter {
41
41
  }
42
42
  this.emit(msg.type, msg);
43
43
  });
44
- this.peer.on('close', () => { console.warn('peer', this.peer.id, 'closed connection'); });
44
+ this.peer.on('close', () => { console.warn('From', this._ownId, ': peer', this.peer.id, 'closed connection'); });
45
45
  await new Promise((resolve) => {
46
46
  this.peer.on('connect', resolve);
47
47
  });
@@ -1,22 +1,18 @@
1
- import { Map } from "immutable";
2
- import { type MetadataKey, type MetadataValue, type WeightsContainer } from "../../index.js";
3
- import { type NodeID } from "../types.js";
1
+ import { type WeightsContainer } from "../../index.js";
4
2
  import { Base as Client } from "../base.js";
5
3
  /**
6
4
  * Client class that communicates with a centralized, federated server, when training
7
5
  * a specific task in the federated setting.
8
6
  */
9
7
  export declare class Base extends Client {
8
+ #private;
10
9
  /**
11
10
  * Arbitrary node id assigned to the federated server which we are communicating with.
12
11
  * Indeed, the server acts as a node within the network. In the federated setting described
13
12
  * by this client class, the server is the only node which we are communicating with.
14
13
  */
15
14
  static readonly SERVER_NODE_ID = "federated-server-node-id";
16
- /**
17
- * Map of metadata values for each node id.
18
- */
19
- private metadataMap?;
15
+ get nbOfParticipants(): number;
20
16
  /**
21
17
  * Opens a new WebSocket connection with the server and listens to new messages over the channel
22
18
  */
@@ -31,24 +27,12 @@ export declare class Base extends Client {
31
27
  * Disconnection process when user quits the task.
32
28
  */
33
29
  disconnect(): Promise<void>;
30
+ onRoundBeginCommunication(): Promise<void>;
31
+ onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
34
32
  /**
35
33
  * Send a message containing our local weight updates to the federated server.
36
34
  * And waits for the server to reply with the most recent aggregated weights
37
35
  * @param payload The weight updates to send
38
36
  */
39
37
  private sendPayloadAndReceiveResult;
40
- /**
41
- * Waits for the server's result for its current (most recent) round and add it to our aggregator.
42
- * Updates the aggregator's round if it's behind the server's.
43
- */
44
- private receiveResult;
45
- /**
46
- * Fetch the metadata values maintained by the federated server, for a given metadata key.
47
- * The values are indexed by node id.
48
- * @param key The metadata key
49
- * @returns The map of node id to metadata value
50
- */
51
- receiveMetadataMap(key: MetadataKey): Promise<Map<NodeID, MetadataValue> | undefined>;
52
- onRoundBeginCommunication(): Promise<void>;
53
- onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<void>;
54
38
  }
@@ -1,4 +1,3 @@
1
- import { Map } from "immutable";
2
1
  import { serialization, } from "../../index.js";
3
2
  import { Base as Client } from "../base.js";
4
3
  import { type } from "../messages.js";
@@ -15,10 +14,13 @@ export class Base extends Client {
15
14
  * by this client class, the server is the only node which we are communicating with.
16
15
  */
17
16
  static SERVER_NODE_ID = "federated-server-node-id";
18
- /**
19
- * Map of metadata values for each node id.
20
- */
21
- metadataMap;
17
+ // Total number of other federated contributors, including this client, excluding the server
18
+ // E.g., if 3 users are training a federated model, nbOfParticipants is 3
19
+ #nbOfParticipants = 1;
20
+ // the number of participants excluding the server
21
+ get nbOfParticipants() {
22
+ return this.#nbOfParticipants;
23
+ }
22
24
  /**
23
25
  * Opens a new WebSocket connection with the server and listens to new messages over the channel
24
26
  */
@@ -64,6 +66,32 @@ export class Base extends Client {
64
66
  this.aggregator.setNodes(this.aggregator.nodes.delete(Base.SERVER_NODE_ID));
65
67
  return Promise.resolve();
66
68
  }
69
+ onRoundBeginCommunication() {
70
+ // Prepare the result promise for the incoming round
71
+ this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
72
+ return Promise.resolve();
73
+ }
74
+ async onRoundEndCommunication(weights, round) {
75
+ // NB: For now, we suppose a fully-federated setting.
76
+ if (this.aggregationResult === undefined) {
77
+ throw new Error("local aggregation result was not set");
78
+ }
79
+ // Send our local contribution to the server
80
+ // and receive the most recent weights as an answer to our contribution
81
+ const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first());
82
+ if (serverResult !== undefined &&
83
+ this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) {
84
+ // Regular case: the server sends us its aggregation result which will serve our
85
+ // own aggregation result.
86
+ }
87
+ else {
88
+ // Unexpected case: for some reason, the server result is stale.
89
+ // We proceed to the next round without its result.
90
+ console.info(`[${this.ownId}] Server result is either stale or not received`);
91
+ this.aggregator.nextRound();
92
+ }
93
+ return await this.aggregationResult;
94
+ }
67
95
  /**
68
96
  * Send a message containing our local weight updates to the federated server.
69
97
  * And waits for the server to reply with the most recent aggregated weights
@@ -76,17 +104,13 @@ export class Base extends Client {
76
104
  round: this.aggregator.round,
77
105
  };
78
106
  this.server.send(msg);
79
- // It is important than the client immediately awaits the server result or it may miss it
80
- return await this.receiveResult();
81
- }
82
- /**
83
- * Waits for the server's result for its current (most recent) round and add it to our aggregator.
84
- * Updates the aggregator's round if it's behind the server's.
85
- */
86
- async receiveResult() {
107
+ // Waits for the server's result for its current (most recent) round and add it to our aggregator.
108
+ // Updates the aggregator's round if it's behind the server's.
87
109
  try {
88
- const { payload, round } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload);
110
+ // It is important than the client immediately awaits the server result or it may miss it
111
+ const { payload, round, nbOfParticipants } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload);
89
112
  const serverRound = round;
113
+ this.#nbOfParticipants = nbOfParticipants; // Save the current participants
90
114
  // Store the server result only if it is not stale
91
115
  if (this.aggregator.round <= round) {
92
116
  const serverResult = serialization.weights.decode(payload);
@@ -101,51 +125,4 @@ export class Base extends Client {
101
125
  console.error(e);
102
126
  }
103
127
  }
104
- /**
105
- * Fetch the metadata values maintained by the federated server, for a given metadata key.
106
- * The values are indexed by node id.
107
- * @param key The metadata key
108
- * @returns The map of node id to metadata value
109
- */
110
- async receiveMetadataMap(key) {
111
- this.metadataMap = undefined;
112
- const msg = {
113
- type: type.ReceiveServerMetadata,
114
- taskId: this.task.id,
115
- nodeId: this.ownId,
116
- round: this.aggregator.round,
117
- key,
118
- };
119
- this.server.send(msg);
120
- const received = await waitMessageWithTimeout(this.server, type.ReceiveServerMetadata);
121
- if (received.metadataMap !== undefined) {
122
- this.metadataMap = Map(received.metadataMap.filter(([_, v]) => v !== undefined));
123
- }
124
- return this.metadataMap;
125
- }
126
- onRoundBeginCommunication() {
127
- // Prepare the result promise for the incoming round
128
- this.aggregationResult = this.aggregator.receiveResult();
129
- return Promise.resolve();
130
- }
131
- async onRoundEndCommunication(weights, round) {
132
- // NB: For now, we suppose a fully-federated setting.
133
- if (this.aggregationResult === undefined) {
134
- throw new Error("local aggregation result was not set");
135
- }
136
- // Send our local contribution to the server
137
- // and receive the most recent weights as an answer to our contribution
138
- const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first());
139
- if (serverResult !== undefined &&
140
- this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) {
141
- // Regular case: the server sends us its aggregation result which will serve our
142
- // own aggregation result.
143
- }
144
- else {
145
- // Unexpected case: for some reason, the server result is stale.
146
- // We proceed to the next round without its result.
147
- console.info(`[${this.ownId}] Server result is either stale or not received`);
148
- this.aggregator.nextRound();
149
- }
150
- }
151
128
  }
@@ -1,7 +1,6 @@
1
- import { type client, type MetadataKey, type MetadataValue } from '../../index.js';
2
1
  import { type weights } from '../../serialization/index.js';
3
2
  import { type, type AssignNodeID, type ClientConnected } from '../messages.js';
4
- export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | ReceiveServerMetadata | AssignNodeID;
3
+ export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | AssignNodeID;
5
4
  export interface SendPayload {
6
5
  type: type.SendPayload;
7
6
  payload: weights.Encoded;
@@ -11,13 +10,6 @@ export interface ReceiveServerPayload {
11
10
  type: type.ReceiveServerPayload;
12
11
  payload: weights.Encoded;
13
12
  round: number;
14
- }
15
- export interface ReceiveServerMetadata {
16
- type: type.ReceiveServerMetadata;
17
- nodeId: client.NodeID;
18
- taskId: string;
19
- round: number;
20
- key: MetadataKey;
21
- metadataMap?: Array<[client.NodeID, MetadataValue | undefined]>;
13
+ nbOfParticipants: number;
22
14
  }
23
15
  export declare function isMessageFederated(raw: unknown): raw is MessageFederated;
@@ -7,7 +7,6 @@ export function isMessageFederated(raw) {
7
7
  case type.ClientConnected:
8
8
  case type.SendPayload:
9
9
  case type.ReceiveServerPayload:
10
- case type.ReceiveServerMetadata:
11
10
  case type.AssignNodeID:
12
11
  return true;
13
12
  }
@@ -4,5 +4,5 @@ export * as aggregator from '../aggregator/index.js';
4
4
  export * as decentralized from './decentralized/index.js';
5
5
  export * as federated from './federated/index.js';
6
6
  export * as messages from './messages.js';
7
- export * as utils from './utils.js';
7
+ export { getClient, timeout } from './utils.js';
8
8
  export { Local } from './local.js';