@epfml/discojs 3.0.1-p20241001093123.0 → 3.0.1-p20241014092014.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
  2. package/dist/aggregator/{base.js → aggregator.js} +48 -36
  3. package/dist/aggregator/get.d.ts +2 -2
  4. package/dist/aggregator/get.js +4 -4
  5. package/dist/aggregator/index.d.ts +1 -4
  6. package/dist/aggregator/index.js +1 -1
  7. package/dist/aggregator/mean.d.ts +4 -4
  8. package/dist/aggregator/mean.js +5 -15
  9. package/dist/aggregator/secure.d.ts +4 -4
  10. package/dist/aggregator/secure.js +7 -17
  11. package/dist/client/client.d.ts +71 -17
  12. package/dist/client/client.js +115 -14
  13. package/dist/client/decentralized/decentralized_client.d.ts +11 -13
  14. package/dist/client/decentralized/decentralized_client.js +121 -84
  15. package/dist/client/decentralized/messages.d.ts +10 -4
  16. package/dist/client/decentralized/messages.js +7 -6
  17. package/dist/client/federated/federated_client.d.ts +1 -13
  18. package/dist/client/federated/federated_client.js +15 -94
  19. package/dist/client/federated/messages.d.ts +2 -7
  20. package/dist/client/local_client.d.ts +1 -0
  21. package/dist/client/local_client.js +3 -0
  22. package/dist/client/messages.d.ts +14 -7
  23. package/dist/client/messages.js +13 -11
  24. package/dist/default_tasks/cifar10.js +1 -1
  25. package/dist/default_tasks/lus_covid.js +1 -0
  26. package/dist/default_tasks/mnist.js +1 -1
  27. package/dist/default_tasks/simple_face.js +1 -0
  28. package/dist/default_tasks/titanic.js +1 -0
  29. package/dist/default_tasks/wikitext.js +1 -0
  30. package/dist/task/training_information.d.ts +1 -2
  31. package/dist/task/training_information.js +6 -8
  32. package/dist/training/disco.d.ts +4 -1
  33. package/dist/training/trainer.js +1 -1
  34. package/dist/utils/event_emitter.d.ts +3 -3
  35. package/dist/utils/event_emitter.js +10 -9
  36. package/package.json +1 -1
package/dist/client/federated/federated_client.js CHANGED
@@ -1,6 +1,6 @@
1
1
  import createDebug from "debug";
2
2
  import { serialization } from "../../index.js";
3
- import { Client } from "../client.js";
3
+ import { Client, shortenId } from "../client.js";
4
4
  import { type } from "../messages.js";
5
5
  import { waitMessage, waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
6
6
  import * as messages from "./messages.js";
@@ -19,38 +19,10 @@ export class FederatedClient extends Client {
19
19
  // Total number of other federated contributors, including this client, excluding the server
20
20
  // E.g., if 3 users are training a federated model, nbOfParticipants is 3
21
21
  #nbOfParticipants = 1;
22
- /**
23
- * When the server notifies clients to pause and wait until more
24
- * participants join, we rely on this promise to wait
25
- * until the server signals that the training can resume
26
- */
27
- #promiseForMoreParticipants = undefined;
28
- /**
29
- * When the server notifies the client that they can resume training
30
- * after waiting for more participants, we want to be able to display what
31
- * we were doing before waiting (training locally or updating our model).
32
- * We use this attribute to store the status to rollback to when we stop waiting
33
- */
34
- #previousStatus = undefined;
35
- /**
36
- * Whether the client should wait until more
37
- * participants join the session, i.e. a promise has been created
38
- */
39
- get #waitingForMoreParticipants() {
40
- return this.#promiseForMoreParticipants !== undefined;
41
- }
42
22
  // the number of participants excluding the server
43
- get nbOfParticipants() {
23
+ getNbOfParticipants() {
44
24
  return this.#nbOfParticipants;
45
25
  }
46
- /**
47
- * Opens a new WebSocket connection with the server and listens to new messages over the channel
48
- */
49
- async connectServer(url) {
50
- const server = await WebSocketServer.connect(url, messages.isMessageFederated, // can only receive federated message types from the server
51
- messages.isMessageFederated);
52
- return server;
53
- }
54
26
  /**
55
27
  * Initializes the connection to the server, gets our node ID
56
28
  * as well as the latest training information: latest global model, current round and
@@ -70,31 +42,12 @@ export class FederatedClient extends Client {
70
42
  throw new Error(`unknown protocol: ${this.url.protocol}`);
71
43
  }
72
44
  serverURL.pathname += `federated/${this.task.id}`;
73
- this._server = await this.connectServer(serverURL);
74
- // Setup an event callback if the server signals that we should
75
- // wait for more participants
76
- this.server.on(type.WaitingForMoreParticipants, () => {
77
- debug(`[${id.slice(0, 4)}] received WaitingForMoreParticipants message from server`);
78
- // Display the waiting status right away
79
- this.emit("status", "Waiting for more participants");
80
- // Upon receiving a WaitingForMoreParticipants message,
81
- // the client will await for this promise to resolve before sending its
82
- // local weight update
83
- this.#promiseForMoreParticipants = this.waitForMoreParticipants();
84
- });
85
- // As an example assume we need at least 2 participants to train,
86
- // When two participants join almost at the same time, the server
87
- // sends a NewFederatedNodeInfo with waitForMoreParticipants=true to the first participant
88
- // and directly follows with an EnoughParticipants message when the 2nd participant joins
89
- // However, the EnoughParticipants can arrive before the NewFederatedNodeInfo (which is much bigger)
90
- // so we check whether we received the EnoughParticipants before being assigned a node ID
45
+ // Opens a new WebSocket connection with the server and listens to new messages over the channel
46
+ this._server = await WebSocketServer.connect(serverURL, messages.isMessageFederated, // can only receive federated message types from the server
47
+ messages.isMessageFederated);
48
+ // c.f. setupServerCallbacks doc for explanation
91
49
  let receivedEnoughParticipants = false;
92
- this.server.once(type.EnoughParticipants, () => {
93
- if (this._ownId === undefined) {
94
- debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
95
- receivedEnoughParticipants = true;
96
- }
97
- });
50
+ this.setupServerCallbacks(() => receivedEnoughParticipants = true);
98
51
  this.aggregator.registerNode(SERVER_NODE_ID);
99
52
  const msg = {
100
53
  type: type.ClientConnected,
@@ -109,40 +62,21 @@ export class FederatedClient extends Client {
109
62
  if (waitForMoreParticipants && !receivedEnoughParticipants) {
110
63
  // Create a promise that resolves when enough participants join
111
64
  // The client will await this promise before sending its local weight update
112
- this.#promiseForMoreParticipants = this.waitForMoreParticipants();
65
+ this.promiseForMoreParticipants = this.createPromiseForMoreParticipants();
113
66
  }
114
67
  if (this._ownId !== undefined) {
115
68
  throw new Error('received id from server but was already received');
116
69
  }
117
70
  this._ownId = id;
118
- debug(`[${id.slice(0, 4)}] joined session at round ${round} `);
71
+ debug(`[${shortenId(id)}] joined session at round ${round} `);
119
72
  this.aggregator.setRound(round);
120
73
  this.#nbOfParticipants = nbOfParticipants;
121
74
  // Upon connecting, the server answers with a boolean
122
75
  // which indicates whether there are enough participants or not
123
- debug(`[${this.ownId.slice(0, 4)}] upon connecting, wait for participant flag %o`, this.#waitingForMoreParticipants);
76
+ debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants);
124
77
  model.weights = serialization.weights.decode(payload);
125
78
  return model;
126
79
  }
127
- /**
128
- * Method called when the server notifies the client that there aren't enough
129
- * participants (anymore) to start/continue training
130
- * The method creates a promise that will resolve once the server notifies
131
- * the client that the training can resume via a subsequent EnoughParticipants message
132
- * @returns a promise which resolves when enough participants joined the session
133
- */
134
- async waitForMoreParticipants() {
135
- return new Promise((resolve) => {
136
- // "once" is important because we can't resolve the same promise multiple times
137
- this.server.once(type.EnoughParticipants, () => {
138
- debug(`[${this.ownId.slice(0, 4)}] received EnoughParticipants message from server`);
139
- // Emit the last status emitted before waiting if defined
140
- if (this.#previousStatus !== undefined)
141
- this.emit("status", this.#previousStatus);
142
- resolve();
143
- });
144
- });
145
- }
146
80
  /**
147
81
  * Disconnection process when user quits the task.
148
82
  */
@@ -155,10 +89,7 @@ export class FederatedClient extends Client {
155
89
  onRoundBeginCommunication() {
156
90
  // Prepare the result promise for the incoming round
157
91
  this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
158
- // Save the status in case participants leave and we switch to waiting for more participants
159
- // Once enough new participants join we can display the previous status again
160
- this.#previousStatus = "Training the model on the data you connected";
161
- this.emit("status", this.#previousStatus);
92
+ this.saveAndEmit("local training");
162
93
  return Promise.resolve();
163
94
  }
164
95
  /**
@@ -176,18 +107,8 @@ export class FederatedClient extends Client {
176
107
  throw new Error("local aggregation result was not set");
177
108
  }
178
109
  // First we check if we are waiting for more participants before sending our weight update
179
- if (this.#waitingForMoreParticipants) {
180
- // wait for the promise to resolve, which takes as long as it takes for new participants to join
181
- debug(`[${this.ownId.slice(0, 4)}] is awaiting the promise for more participants`);
182
- this.emit("status", "Waiting for more participants");
183
- await this.#promiseForMoreParticipants;
184
- // Make sure to set the promise back to undefined once resolved
185
- this.#promiseForMoreParticipants = undefined;
186
- }
187
- // Save the status in case participants leave and we switch to waiting for more participants
188
- // Once enough new participants join we can display the previous status again
189
- this.#previousStatus = "Updating the model with other participants' models";
190
- this.emit("status", this.#previousStatus);
110
+ await this.waitForParticipantsIfNeeded();
111
+ this.saveAndEmit("updating model");
191
112
  // Send our local contribution to the server
192
113
  // and receive the server global update for this round as an answer to our contribution
193
114
  const payloadToServer = this.aggregator.makePayloads(weights).first();
@@ -198,9 +119,9 @@ export class FederatedClient extends Client {
198
119
  };
199
120
  // Need to await the resulting global model right after sending our local contribution
200
121
  // to make sure we don't miss it
201
- debug(`[${this.ownId.slice(0, 4)}] sent its local update to the server for round ${this.aggregator.round}`);
122
+ debug(`[${shortenId(this.ownId)}] sent its local update to the server for round ${this.aggregator.round}`);
202
123
  this.server.send(msg);
203
- debug(`[${this.ownId.slice(0, 4)}] is waiting for server update for round ${this.aggregator.round + 1}`);
124
+ debug(`[${shortenId(this.ownId)}] is waiting for server update for round ${this.aggregator.round + 1}`);
204
125
  const { payload: payloadFromServer, round: serverRound, nbOfParticipants } = await waitMessage(this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update
205
126
  this.#nbOfParticipants = nbOfParticipants; // Save the current participants
206
127
  const serverResult = serialization.weights.decode(payloadFromServer);
package/dist/client/federated/messages.d.ts CHANGED
@@ -1,6 +1,7 @@
1
1
  import { type weights } from '../../serialization/index.js';
2
2
  import { type NodeID } from '..//types.js';
3
- import { type, type ClientConnected } from '../messages.js';
3
+ import { type } from '../messages.js';
4
+ import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js';
4
5
  export type MessageFederated = ClientConnected | NewFederatedNodeInfo | SendPayload | ReceiveServerPayload | WaitingForMoreParticipants | EnoughParticipants;
5
6
  export interface NewFederatedNodeInfo {
6
7
  type: type.NewFederatedNodeInfo;
@@ -21,10 +22,4 @@ export interface ReceiveServerPayload {
21
22
  round: number;
22
23
  nbOfParticipants: number;
23
24
  }
24
- export interface EnoughParticipants {
25
- type: type.EnoughParticipants;
26
- }
27
- export interface WaitingForMoreParticipants {
28
- type: type.WaitingForMoreParticipants;
29
- }
30
25
  export declare function isMessageFederated(raw: unknown): raw is MessageFederated;
package/dist/client/local_client.d.ts CHANGED
@@ -5,6 +5,7 @@ import { Client } from "./client.js";
5
5
  * with anyone. Thus LocalClient doesn't do anything during communication
6
6
  */
7
7
  export declare class LocalClient extends Client {
8
+ getNbOfParticipants(): number;
8
9
  onRoundBeginCommunication(): Promise<void>;
9
10
  onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
10
11
  }
package/dist/client/local_client.js CHANGED
@@ -4,6 +4,9 @@ import { Client } from "./client.js";
4
4
  * with anyone. Thus LocalClient doesn't do anything during communication
5
5
  */
6
6
  export class LocalClient extends Client {
7
+ getNbOfParticipants() {
8
+ return 1;
9
+ }
7
10
  onRoundBeginCommunication() {
8
11
  return Promise.resolve();
9
12
  }
package/dist/client/messages.d.ts CHANGED
@@ -3,19 +3,26 @@ import type * as federated from './federated/messages.js';
3
3
  export declare enum type {
4
4
  ClientConnected = 0,
5
5
  NewDecentralizedNodeInfo = 1,
6
- SignalForPeer = 2,
6
+ JoinRound = 2,
7
7
  PeerIsReady = 3,
8
8
  PeersForRound = 4,
9
- Payload = 5,
10
- NewFederatedNodeInfo = 6,
11
- WaitingForMoreParticipants = 7,
12
- EnoughParticipants = 8,
13
- SendPayload = 9,
14
- ReceiveServerPayload = 10
9
+ SignalForPeer = 5,
10
+ Payload = 6,
11
+ NewFederatedNodeInfo = 7,
12
+ WaitingForMoreParticipants = 8,
13
+ EnoughParticipants = 9,
14
+ SendPayload = 10,
15
+ ReceiveServerPayload = 11
15
16
  }
16
17
  export interface ClientConnected {
17
18
  type: type.ClientConnected;
18
19
  }
20
+ export interface EnoughParticipants {
21
+ type: type.EnoughParticipants;
22
+ }
23
+ export interface WaitingForMoreParticipants {
24
+ type: type.WaitingForMoreParticipants;
25
+ }
19
26
  export type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
20
27
  export type NarrowMessage<D> = Extract<Message, {
21
28
  type: D;
package/dist/client/messages.js CHANGED
@@ -9,34 +9,36 @@ export var type;
9
9
  // answers with its peer id and also tells the client whether we are waiting
10
10
  // for more participants before starting training
11
11
  type[type["NewDecentralizedNodeInfo"] = 1] = "NewDecentralizedNodeInfo";
12
- // Message forwarded by the server from a client to another client
13
- // to establish a peer-to-peer (WebRTC) connection
14
- type[type["SignalForPeer"] = 2] = "SignalForPeer";
12
+ // Message sent by peers to the server to signal they want to
13
+ // join the next round
14
+ type[type["JoinRound"] = 2] = "JoinRound";
15
15
  // Message sent by nodes to server signaling they are ready to
16
16
  // start the next round
17
17
  type[type["PeerIsReady"] = 3] = "PeerIsReady";
18
18
  // Sent by the server to participating peers containing the list
19
19
  // of peers for the round
20
20
  type[type["PeersForRound"] = 4] = "PeersForRound";
21
+ // Message forwarded by the server from a client to another client
22
+ // to establish a peer-to-peer (WebRTC) connection
23
+ type[type["SignalForPeer"] = 5] = "SignalForPeer";
21
24
  // The weight update
22
- type[type["Payload"] = 5] = "Payload";
25
+ type[type["Payload"] = 6] = "Payload";
23
26
  /* Federated */
24
27
  // The server answers the ClientConnected message with the necessary information
25
28
  // to start training: node id, latest model global weights, current round etc
26
- type[type["NewFederatedNodeInfo"] = 6] = "NewFederatedNodeInfo";
29
+ type[type["NewFederatedNodeInfo"] = 7] = "NewFederatedNodeInfo";
27
30
  // Message sent by server to notify clients that there are not enough
28
31
  // participants to continue training
29
- type[type["WaitingForMoreParticipants"] = 7] = "WaitingForMoreParticipants";
32
+ type[type["WaitingForMoreParticipants"] = 8] = "WaitingForMoreParticipants";
30
33
  // Message sent by server to notify clients that there are now enough
31
34
  // participants to start training collaboratively
32
- type[type["EnoughParticipants"] = 8] = "EnoughParticipants";
33
- type[type["SendPayload"] = 9] = "SendPayload";
34
- type[type["ReceiveServerPayload"] = 10] = "ReceiveServerPayload";
35
+ type[type["EnoughParticipants"] = 9] = "EnoughParticipants";
36
+ type[type["SendPayload"] = 10] = "SendPayload";
37
+ type[type["ReceiveServerPayload"] = 11] = "ReceiveServerPayload";
35
38
  })(type || (type = {}));
36
39
  export function hasMessageType(raw) {
37
- if (typeof raw !== 'object' || raw === null) {
40
+ if (typeof raw !== 'object' || raw === null)
38
41
  return false;
39
- }
40
42
  const o = raw;
41
43
  if (!('type' in o && typeof o.type === 'number' && o.type in type)) {
42
44
  return false;
package/dist/default_tasks/cifar10.js CHANGED
@@ -29,8 +29,8 @@ export const cifar10 = {
29
29
  IMAGE_W: 224,
30
30
  LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
31
31
  scheme: 'decentralized',
32
+ aggregationStrategy: 'mean',
32
33
  privacy: { clippingRadius: 20, noiseScale: 1 },
33
- decentralizedSecure: true,
34
34
  minNbOfParticipants: 3,
35
35
  maxShareValue: 100,
36
36
  tensorBackend: 'tfjs'
package/dist/default_tasks/lus_covid.js CHANGED
@@ -28,6 +28,7 @@ export const lusCovid = {
28
28
  LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
29
29
  dataType: 'image',
30
30
  scheme: 'federated',
31
+ aggregationStrategy: 'mean',
31
32
  minNbOfParticipants: 2,
32
33
  tensorBackend: 'tfjs'
33
34
  }
package/dist/default_tasks/mnist.js CHANGED
@@ -29,7 +29,7 @@ export const mnist = {
29
29
  preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
30
30
  LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
31
31
  scheme: 'decentralized',
32
- decentralizedSecure: true,
32
+ aggregationStrategy: 'secure',
33
33
  minNbOfParticipants: 3,
34
34
  maxShareValue: 100,
35
35
  tensorBackend: 'tfjs'
package/dist/default_tasks/simple_face.js CHANGED
@@ -28,6 +28,7 @@ export const simpleFace = {
28
28
  IMAGE_W: 200,
29
29
  LABEL_LIST: ['child', 'adult'],
30
30
  scheme: 'federated',
31
+ aggregationStrategy: 'mean',
31
32
  minNbOfParticipants: 2,
32
33
  tensorBackend: 'tfjs'
33
34
  }
package/dist/default_tasks/titanic.js CHANGED
@@ -62,6 +62,7 @@ export const titanic = {
62
62
  'Survived'
63
63
  ],
64
64
  scheme: 'federated',
65
+ aggregationStrategy: 'mean',
65
66
  minNbOfParticipants: 2,
66
67
  tensorBackend: 'tfjs'
67
68
  }
package/dist/default_tasks/wikitext.js CHANGED
@@ -25,6 +25,7 @@ export const wikitext = {
25
25
  dataType: 'text',
26
26
  preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
27
27
  scheme: 'federated',
28
+ aggregationStrategy: 'mean',
28
29
  minNbOfParticipants: 2,
29
30
  epochs: 6,
30
31
  // Unused by wikitext because data already comes split
package/dist/task/training_information.d.ts CHANGED
@@ -18,10 +18,9 @@ export interface TrainingInformation {
18
18
  LABEL_LIST?: string[];
19
19
  scheme: 'decentralized' | 'federated' | 'local';
20
20
  privacy?: Privacy;
21
- decentralizedSecure?: boolean;
22
21
  maxShareValue?: number;
23
22
  minNbOfParticipants: number;
24
- aggregator?: 'mean' | 'secure';
23
+ aggregationStrategy?: 'mean' | 'secure';
25
24
  tokenizer?: string | PreTrainedTokenizer;
26
25
  maxSequenceLength?: number;
27
26
  tensorBackend: 'tfjs' | 'gpt';
package/dist/task/training_information.js CHANGED
@@ -24,7 +24,7 @@ export function isTrainingInformation(raw) {
24
24
  if (typeof raw !== 'object' || raw === null) {
25
25
  return false;
26
26
  }
27
- const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
27
+ const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregationStrategy, batchSize, dataType, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
28
28
  if (typeof dataType !== 'string' ||
29
29
  typeof epochs !== 'number' ||
30
30
  typeof batchSize !== 'number' ||
@@ -33,8 +33,7 @@ export function isTrainingInformation(raw) {
33
33
  typeof minNbOfParticipants !== 'number' ||
34
34
  (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
35
35
  (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
36
- (aggregator !== undefined && typeof aggregator !== 'string') ||
37
- (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
36
+ (aggregationStrategy !== undefined && typeof aggregationStrategy !== 'string') ||
38
37
  (privacy !== undefined && !isPrivacy(privacy)) ||
39
38
  (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
40
39
  (IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
@@ -45,8 +44,8 @@ export function isTrainingInformation(raw) {
45
44
  (preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
46
45
  return false;
47
46
  }
48
- if (aggregator !== undefined) {
49
- switch (aggregator) {
47
+ if (aggregationStrategy !== undefined) {
48
+ switch (aggregationStrategy) {
50
49
  case 'mean': break;
51
50
  case 'secure': break;
52
51
  default: return false;
@@ -58,7 +57,7 @@ export function isTrainingInformation(raw) {
58
57
  case 'text': break;
59
58
  default: return false;
60
59
  }
61
- // interdepences on data type
60
+ // interdependencies on data type
62
61
  if (dataType === 'image') {
63
62
  if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
64
63
  return false;
@@ -87,10 +86,9 @@ export function isTrainingInformation(raw) {
87
86
  IMAGE_W,
88
87
  IMAGE_H,
89
88
  LABEL_LIST,
90
- aggregator,
89
+ aggregationStrategy,
91
90
  batchSize,
92
91
  dataType,
93
- decentralizedSecure,
94
92
  privacy,
95
93
  epochs,
96
94
  inputColumns,
package/dist/training/disco.d.ts CHANGED
@@ -7,7 +7,10 @@ interface DiscoConfig {
7
7
  scheme: TrainingInformation["scheme"];
8
8
  logger: Logger;
9
9
  }
10
- export type RoundStatus = "Waiting for more participants" | "Retrieving peers' information" | "Updating the model with other participants' models" | "Training the model on the data you connected";
10
+ export type RoundStatus = 'not enough participants' | // Server notification to wait for more participants
11
+ 'updating model' | // fetching/aggregating local updates into a global model
12
+ 'local training' | // Training the model locally
13
+ 'connecting to peers';
11
14
  /**
12
15
  * Top-level class handling distributed training from a client's perspective. It is meant to be
13
16
  * a convenient object providing a reduced yet complete API that wraps model training and
package/dist/training/trainer.js CHANGED
@@ -62,7 +62,7 @@ export class Trainer {
62
62
  }
63
63
  return {
64
64
  epochs: epochsLogs,
65
- participants: this.#client.nbOfParticipants,
65
+ participants: this.#client.getNbOfParticipants(),
66
66
  };
67
67
  }
68
68
  }
package/dist/utils/event_emitter.d.ts CHANGED
@@ -1,11 +1,11 @@
1
- type Listener<T> = (_: T) => void;
1
+ type Listener<T> = (_: T) => void | Promise<void>;
2
2
  /**
3
3
  * Call handlers on given events
4
4
  *
5
5
  * @typeParam I object/mapping from event name to emitted value type
6
6
  */
7
7
  export declare class EventEmitter<I extends Record<string, unknown>> {
8
- private listeners;
8
+ #private;
9
9
  /**
10
10
  * @param initialListeners object/mapping of event name to listener, as if using `on` on created instance
11
11
  */
@@ -13,7 +13,7 @@ export declare class EventEmitter<I extends Record<string, unknown>> {
13
13
  [E in keyof I]?: Listener<I[E]>;
14
14
  });
15
15
  /**
16
- * Register listener to call on event
16
+ * Register listener to call on event.
17
17
  *
18
18
  * @param event event name to listen to
19
19
  * @param listener handler to call
package/dist/utils/event_emitter.js CHANGED
@@ -6,7 +6,8 @@ import { List } from 'immutable';
6
6
  * @typeParam I object/mapping from event name to emitted value type
7
7
  */
8
8
  export class EventEmitter {
9
- listeners = {};
9
+ // List of callbacks to run per event
10
+ #listeners = {};
10
11
  /**
11
12
  * @param initialListeners object/mapping of event name to listener, as if using `on` on created instance
12
13
  */
@@ -19,14 +20,14 @@ export class EventEmitter {
19
20
  }
20
21
  }
21
22
  /**
22
- * Register listener to call on event
23
+ * Register listener to call on event.
23
24
  *
24
25
  * @param event event name to listen to
25
26
  * @param listener handler to call
26
27
  */
27
28
  on(event, listener) {
28
- const eventListeners = this.listeners[event] ?? List();
29
- this.listeners[event] = eventListeners.push([false, listener]);
29
+ const eventListeners = this.#listeners[event] ?? List();
30
+ this.#listeners[event] = eventListeners.push([false, listener]);
30
31
  }
31
32
  /**
32
33
  * Register listener to call once on next event
@@ -35,8 +36,8 @@ export class EventEmitter {
35
36
  * @param listener handler to call next time
36
37
  */
37
38
  once(event, listener) {
38
- const eventListeners = this.listeners[event] ?? List();
39
- this.listeners[event] = eventListeners.push([true, listener]);
39
+ const eventListeners = this.#listeners[event] ?? List();
40
+ this.#listeners[event] = eventListeners.push([true, listener]);
40
41
  }
41
42
  /**
42
43
  * Send value to registered listeners of event name
@@ -45,9 +46,9 @@ export class EventEmitter {
45
46
  * @param value what to call listeners with
46
47
  */
47
48
  emit(event, value) {
48
- const eventListeners = this.listeners[event] ?? List();
49
- this.listeners[event] = eventListeners.filterNot(([once]) => once);
50
- eventListeners.forEach(([_, listener]) => { listener(value); });
49
+ const eventListeners = this.#listeners[event] ?? List();
50
+ this.#listeners[event] = eventListeners.filterNot(([once]) => once);
51
+ eventListeners.forEach(async ([_, listener]) => { await listener(value); });
51
52
  }
52
53
  }
53
54
  /** `EventEmitter` for all events */
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "3.0.1-p20241001093123.0",
3
+ "version": "3.0.1-p20241014092014.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",