@epfml/discojs 3.0.1-p20250331133703.0 → 3.0.1-p20250402090722.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,9 @@ import { EventEmitter } from '../utils/event_emitter.js';
9
9
  */
10
10
  export declare abstract class Client extends EventEmitter<{
11
11
  'status': RoundStatus;
12
+ 'participants': number;
12
13
  }> {
14
+ #private;
13
15
  readonly url: URL;
14
16
  readonly task: Task<DataType>;
15
17
  readonly aggregator: Aggregator;
@@ -22,13 +24,6 @@ export declare abstract class Client extends EventEmitter<{
22
24
  * until the server signals that the training can resume
23
25
  */
24
26
  protected promiseForMoreParticipants: Promise<void> | undefined;
25
- /**
26
- * When the server notifies the client that they can resume training
27
- * after waiting for more participants, we want to be able to display what
28
- * we were doing before waiting (training locally or updating our model).
29
- * We use this attribute to store the status to rollback to when we stop waiting
30
- */
31
- private previousStatus;
32
27
  constructor(url: URL, // The network server's URL to connect to
33
28
  task: Task<DataType>, // The client's corresponding task
34
29
  aggregator: Aggregator);
@@ -101,7 +96,12 @@ export declare abstract class Client extends EventEmitter<{
101
96
  * If federated, it should the number of participants excluding the server
102
97
  * If local it should be 1
103
98
  */
104
- abstract getNbOfParticipants(): number;
99
+ get nbOfParticipants(): number;
100
+ /**
101
+ * Setter for the number of participants
102
+ * It emits the number of participants to the client
103
+ */
104
+ set nbOfParticipants(nbOfParticipants: number);
105
105
  get ownId(): NodeID;
106
106
  get server(): EventConnection;
107
107
  /**
@@ -29,7 +29,9 @@ export class Client extends EventEmitter {
29
29
  * we were doing before waiting (training locally or updating our model).
30
30
  * We use this attribute to store the status to rollback to when we stop waiting
31
31
  */
32
- previousStatus;
32
+ #previousStatus;
33
+ // Current number of participants including this client in the training session
34
+ #nbOfParticipants = 1;
33
35
  constructor(url, // The network server's URL to connect to
34
36
  task, // The client's corresponding task
35
37
  aggregator) {
@@ -56,7 +58,7 @@ export class Client extends EventEmitter {
56
58
  * the waiting status and once enough participants join, it can display the previous status again
57
59
  */
58
60
  saveAndEmit(status) {
59
- this.previousStatus = status;
61
+ this.#previousStatus = status;
60
62
  this.emit("status", status);
61
63
  }
62
64
  /**
@@ -84,12 +86,13 @@ export class Client extends EventEmitter {
84
86
  setupServerCallbacks(setMessageInversionFlag) {
85
87
  // Setup an event callback if the server signals that we should
86
88
  // wait for more participants
87
- this.server.on(type.WaitingForMoreParticipants, () => {
89
+ this.server.on(type.WaitingForMoreParticipants, (event) => {
88
90
  if (this.promiseForMoreParticipants !== undefined)
89
91
  throw new Error("Server sent multiple WaitingForMoreParticipants messages");
90
92
  debug(`[${shortenId(this.ownId)}] received WaitingForMoreParticipants message from server`);
91
93
  // Display the waiting status right away
92
94
  this.emit("status", "not enough participants");
95
+ this.nbOfParticipants = event.nbOfParticipants; // emits the `participants` event
93
96
  // Upon receiving a WaitingForMoreParticipants message,
94
97
  // the client will await for this promise to resolve before sending its
95
98
  // local weight update
@@ -101,10 +104,10 @@ export class Client extends EventEmitter {
101
104
  // and directly follows with an EnoughParticipants message when the 2nd participant joins
102
105
  // However, the EnoughParticipants can arrive before the NewNodeInfo (which can be much bigger)
103
106
  // so we check whether we received the EnoughParticipants before being assigned a node ID
104
- this.server.once(type.EnoughParticipants, () => {
107
+ this.server.once(type.EnoughParticipants, (event) => {
105
108
  if (this._ownId === undefined) {
106
- debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
107
109
  setMessageInversionFlag();
110
+ this.nbOfParticipants = event.nbOfParticipants;
108
111
  }
109
112
  });
110
113
  }
@@ -118,11 +121,12 @@ export class Client extends EventEmitter {
118
121
  async createPromiseForMoreParticipants() {
119
122
  return new Promise((resolve) => {
120
123
  // "once" is important because we can't resolve the same promise multiple times
121
- this.server.once(type.EnoughParticipants, () => {
124
+ this.server.once(type.EnoughParticipants, (event) => {
122
125
  debug(`[${shortenId(this.ownId)}] received EnoughParticipants message from server`);
123
126
  // Emit the last status emitted before waiting if defined
124
- if (this.previousStatus !== undefined)
125
- this.emit("status", this.previousStatus);
127
+ if (this.#previousStatus !== undefined)
128
+ this.emit("status", this.#previousStatus);
129
+ this.nbOfParticipants = event.nbOfParticipants;
126
130
  resolve();
127
131
  });
128
132
  });
@@ -154,6 +158,23 @@ export class Client extends EventEmitter {
154
158
  const encoded = new Uint8Array(await response.arrayBuffer());
155
159
  return await serialization.model.decode(encoded);
156
160
  }
161
+ /**
162
+ * Number of contributors to a collaborative session
163
+ * If decentralized, it should be the number of peers
164
+ * If federated, it should the number of participants excluding the server
165
+ * If local it should be 1
166
+ */
167
+ get nbOfParticipants() {
168
+ return this.#nbOfParticipants;
169
+ }
170
+ /**
171
+ * Setter for the number of participants
172
+ * It emits the number of participants to the client
173
+ */
174
+ set nbOfParticipants(nbOfParticipants) {
175
+ this.#nbOfParticipants = nbOfParticipants;
176
+ this.emit("participants", nbOfParticipants);
177
+ }
157
178
  get ownId() {
158
179
  if (this._ownId === undefined) {
159
180
  throw new Error('the node is not connected');
@@ -8,8 +8,8 @@ import { Client } from '../client.js';
8
8
  */
9
9
  export declare class DecentralizedClient extends Client {
10
10
  #private;
11
- getNbOfParticipants(): number;
12
11
  private get isDisconnected();
12
+ private setAggregatorNodes;
13
13
  /**
14
14
  * Public method called by disco.ts when starting training. This method sends
15
15
  * a message to the server asking to join the task and be assigned a client ID.
@@ -20,14 +20,15 @@ export class DecentralizedClient extends Client {
20
20
  */
21
21
  #pool;
22
22
  #connections;
23
- getNbOfParticipants() {
24
- const nbOfParticipants = this.aggregator.nodes.size;
25
- return nbOfParticipants === 0 ? 1 : nbOfParticipants;
26
- }
27
23
  // Used to handle timeouts and promise resolving after calling disconnect
28
24
  get isDisconnected() {
29
25
  return this._server === undefined;
30
26
  }
27
+ setAggregatorNodes(nodes) {
28
+ this.aggregator.setNodes(nodes);
29
+ // Emits the `participants` event
30
+ this.nbOfParticipants = this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size;
31
+ }
31
32
  /**
32
33
  * Public method called by disco.ts when starting training. This method sends
33
34
  * a message to the server asking to join the task and be assigned a client ID.
@@ -67,7 +68,8 @@ export class DecentralizedClient extends Client {
67
68
  type: type.ClientConnected
68
69
  };
69
70
  this.server.send(msg);
70
- const { id, waitForMoreParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
71
+ const { id, waitForMoreParticipants, nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
72
+ this.nbOfParticipants = nbOfParticipants;
71
73
  // This should come right after receiving the message to make sure
72
74
  // we don't miss a subsequent message from the server
73
75
  // We check if the server is telling us to wait for more participants
@@ -92,7 +94,7 @@ export class DecentralizedClient extends Client {
92
94
  this.#pool = undefined;
93
95
  if (this.#connections !== undefined) {
94
96
  const peers = this.#connections.keySeq().toSet();
95
- this.aggregator.setNodes(this.aggregator.nodes.subtract(peers));
97
+ this.setAggregatorNodes(this.aggregator.nodes.subtract(peers));
96
98
  }
97
99
  // Disconnect from server
98
100
  await this.server?.disconnect();
@@ -158,7 +160,7 @@ export class DecentralizedClient extends Client {
158
160
  throw new Error('received peer list contains our own id');
159
161
  }
160
162
  // Store the list of peers for the current round including ourselves
161
- this.aggregator.setNodes(peers.add(this.ownId));
163
+ this.setAggregatorNodes(peers.add(this.ownId));
162
164
  this.aggregator.setRound(receivedMessage.aggregationRound); // the server gives us the round number
163
165
  // Initiate peer to peer connections with each peer
164
166
  // When connected, create a promise waiting for each peer's round contribution
@@ -171,7 +173,7 @@ export class DecentralizedClient extends Client {
171
173
  }
172
174
  catch (e) {
173
175
  debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e);
174
- this.aggregator.setNodes(Set(this.ownId));
176
+ this.setAggregatorNodes(Set(this.ownId));
175
177
  this.#connections = Map();
176
178
  }
177
179
  }
@@ -7,6 +7,7 @@ export interface NewDecentralizedNodeInfo {
7
7
  type: type.NewDecentralizedNodeInfo;
8
8
  id: NodeID;
9
9
  waitForMoreParticipants: boolean;
10
+ nbOfParticipants: number;
10
11
  }
11
12
  export interface SignalForPeer {
12
13
  type: type.SignalForPeer;
@@ -5,8 +5,6 @@ import { Client } from "../client.js";
5
5
  * a specific task in the federated setting.
6
6
  */
7
7
  export declare class FederatedClient extends Client {
8
- #private;
9
- getNbOfParticipants(): number;
10
8
  /**
11
9
  * Initializes the connection to the server, gets our node ID
12
10
  * as well as the latest training information: latest global model, current round and
@@ -16,13 +16,6 @@ const SERVER_NODE_ID = "federated-server-node-id";
16
16
  * a specific task in the federated setting.
17
17
  */
18
18
  export class FederatedClient extends Client {
19
- // Total number of other federated contributors, including this client, excluding the server
20
- // E.g., if 3 users are training a federated model, nbOfParticipants is 3
21
- #nbOfParticipants = 1;
22
- // the number of participants excluding the server
23
- getNbOfParticipants() {
24
- return this.#nbOfParticipants;
25
- }
26
19
  /**
27
20
  * Initializes the connection to the server, gets our node ID
28
21
  * as well as the latest training information: latest global model, current round and
@@ -70,7 +63,7 @@ export class FederatedClient extends Client {
70
63
  this._ownId = id;
71
64
  debug(`[${shortenId(id)}] joined session at round ${round} `);
72
65
  this.aggregator.setRound(round);
73
- this.#nbOfParticipants = nbOfParticipants;
66
+ this.nbOfParticipants = nbOfParticipants;
74
67
  // Upon connecting, the server answers with a boolean
75
68
  // which indicates whether there are enough participants or not
76
69
  debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants);
@@ -129,7 +122,7 @@ export class FederatedClient extends Client {
129
122
  this.server.send(msg);
130
123
  debug(`[${shortenId(this.ownId)}] is waiting for server update for round ${this.aggregator.round + 1}`);
131
124
  const { payload: payloadFromServer, round: serverRound, nbOfParticipants } = await waitMessage(this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update
132
- this.#nbOfParticipants = nbOfParticipants; // Save the current participants
125
+ this.nbOfParticipants = nbOfParticipants; // Save the current participants
133
126
  const serverResult = serialization.weights.decode(payloadFromServer);
134
127
  this.aggregator.setRound(serverRound);
135
128
  return serverResult;
@@ -5,7 +5,6 @@ import { Client } from "./client.js";
5
5
  * with anyone. Thus LocalClient doesn't do anything during communication
6
6
  */
7
7
  export declare class LocalClient extends Client {
8
- getNbOfParticipants(): number;
9
8
  onRoundBeginCommunication(): Promise<void>;
10
9
  onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
11
10
  }
@@ -4,9 +4,6 @@ import { Client } from "./client.js";
4
4
  * with anyone. Thus LocalClient doesn't do anything during communication
5
5
  */
6
6
  export class LocalClient extends Client {
7
- getNbOfParticipants() {
8
- return 1;
9
- }
10
7
  onRoundBeginCommunication() {
11
8
  return Promise.resolve();
12
9
  }
@@ -19,9 +19,11 @@ export interface ClientConnected {
19
19
  }
20
20
  export interface EnoughParticipants {
21
21
  type: type.EnoughParticipants;
22
+ nbOfParticipants: number;
22
23
  }
23
24
  export interface WaitingForMoreParticipants {
24
25
  type: type.WaitingForMoreParticipants;
26
+ nbOfParticipants: number;
25
27
  }
26
28
  export type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
27
29
  export type NarrowMessage<D> = Extract<Message, {
@@ -26,6 +26,7 @@ export type RoundStatus = 'not enough participants' | // Server notification to
26
26
  */
27
27
  export declare class Disco<D extends DataType> extends EventEmitter<{
28
28
  status: RoundStatus;
29
+ participants: number;
29
30
  }> {
30
31
  #private;
31
32
  readonly trainer: Trainer<D>;
@@ -53,6 +53,7 @@ export class Disco extends EventEmitter {
53
53
  this.trainer = new Trainer(task, client);
54
54
  // Simply propagate the training status events emitted by the client
55
55
  this.#client.on("status", (status) => this.emit("status", status));
56
+ this.#client.on("participants", (nbParticipants) => this.emit("participants", nbParticipants));
56
57
  }
57
58
  /** Train on dataset, yielding logs of every round. */
58
59
  async *trainByRound(dataset) {
@@ -62,7 +62,7 @@ export class Trainer {
62
62
  }
63
63
  return {
64
64
  epochs: epochsLogs,
65
- participants: this.#client.getNbOfParticipants(),
65
+ participants: this.#client.nbOfParticipants,
66
66
  };
67
67
  }
68
68
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "3.0.1-p20250331133703.0",
3
+ "version": "3.0.1-p20250402090722.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",