@epfml/discojs 3.0.1-p20240902100041.0 → 3.0.1-p20240904094219.0

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
Files changed (62)
  1. package/dist/aggregator/base.d.ts +16 -2
  2. package/dist/aggregator/base.js +25 -3
  3. package/dist/aggregator/mean.d.ts +1 -0
  4. package/dist/aggregator/mean.js +11 -6
  5. package/dist/aggregator/secure.js +1 -1
  6. package/dist/client/{base.d.ts → client.d.ts} +13 -30
  7. package/dist/client/{base.js → client.js} +10 -20
  8. package/dist/client/decentralized/{base.d.ts → decentralized_client.d.ts} +5 -5
  9. package/dist/client/decentralized/{base.js → decentralized_client.js} +20 -16
  10. package/dist/client/decentralized/index.d.ts +1 -1
  11. package/dist/client/decentralized/index.js +1 -1
  12. package/dist/client/decentralized/messages.d.ts +7 -2
  13. package/dist/client/decentralized/messages.js +4 -2
  14. package/dist/client/event_connection.js +2 -2
  15. package/dist/client/federated/federated_client.d.ts +44 -0
  16. package/dist/client/federated/federated_client.js +210 -0
  17. package/dist/client/federated/index.d.ts +1 -1
  18. package/dist/client/federated/index.js +1 -1
  19. package/dist/client/federated/messages.d.ts +17 -2
  20. package/dist/client/federated/messages.js +3 -1
  21. package/dist/client/index.d.ts +2 -2
  22. package/dist/client/index.js +2 -2
  23. package/dist/client/local_client.d.ts +10 -0
  24. package/dist/client/local_client.js +14 -0
  25. package/dist/client/messages.d.ts +6 -8
  26. package/dist/client/messages.js +23 -7
  27. package/dist/client/utils.js +1 -1
  28. package/dist/default_tasks/cifar10.js +1 -2
  29. package/dist/default_tasks/lus_covid.js +1 -1
  30. package/dist/default_tasks/mnist.js +1 -2
  31. package/dist/default_tasks/simple_face.js +2 -2
  32. package/dist/default_tasks/titanic.js +2 -2
  33. package/dist/default_tasks/wikitext.js +1 -1
  34. package/dist/index.d.ts +4 -2
  35. package/dist/index.js +1 -1
  36. package/dist/logging/logger.d.ts +1 -1
  37. package/dist/serialization/model.js +18 -9
  38. package/dist/task/index.d.ts +0 -1
  39. package/dist/task/index.js +0 -1
  40. package/dist/task/task.d.ts +0 -2
  41. package/dist/task/task.js +2 -4
  42. package/dist/task/training_information.d.ts +1 -2
  43. package/dist/task/training_information.js +3 -5
  44. package/dist/training/disco.d.ts +14 -16
  45. package/dist/training/disco.js +22 -46
  46. package/dist/training/index.d.ts +1 -1
  47. package/dist/training/trainer.d.ts +3 -2
  48. package/dist/training/trainer.js +12 -5
  49. package/dist/utils/event_emitter.js +1 -3
  50. package/package.json +1 -1
  51. package/dist/client/federated/base.d.ts +0 -38
  52. package/dist/client/federated/base.js +0 -130
  53. package/dist/client/local.d.ts +0 -5
  54. package/dist/client/local.js +0 -6
  55. package/dist/memory/base.d.ts +0 -111
  56. package/dist/memory/base.js +0 -9
  57. package/dist/memory/empty.d.ts +0 -20
  58. package/dist/memory/empty.js +0 -43
  59. package/dist/memory/index.d.ts +0 -2
  60. package/dist/memory/index.js +0 -2
  61. package/dist/task/digest.d.ts +0 -5
  62. package/dist/task/digest.js +0 -14
@@ -58,13 +58,22 @@ export declare abstract class Base<T> extends EventEmitter<{
58
58
  communicationRounds?: number);
59
59
  /**
60
60
  * Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
61
+ * The aggregation round is increased whenever a new global model is obtained and local models are updated.
62
+ * Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation
63
+ * which requires multiple steps to obtain a global model)
61
64
  * The contribution will be aggregated during the next aggregation step.
62
65
  * @param nodeId The node's id
63
66
  * @param contribution The node's contribution
64
- * @param round For which aggregation round the contribution was made
65
- * @param communicationRound For which communication round the contribution was made
67
+ * @param round aggregation round of the contribution was made
68
+ * @param communicationRound communication round the contribution was made within the aggregation round
69
+ * @returns boolean, true if the contribution has been successfully taken into account or False if it has been rejected
66
70
  */
67
71
  abstract add(nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean;
72
+ /**
73
+ * Evaluates whether a given participant contribution can be used in the current aggregation round
74
+ * the boolean returned by `this.add` is obtained via `this.isValidContribution`
75
+ */
76
+ isValidContribution(nodeId: client.NodeID, round: number): boolean;
68
77
  /**
69
78
  * Performs an aggregation step over the received node contributions.
70
79
  * Must store the aggregation's result in the aggregator's result promise.
@@ -91,6 +100,11 @@ export declare abstract class Base<T> extends EventEmitter<{
91
100
  * @returns True is the node wasn't already in the list of nodes, False if already included
92
101
  */
93
102
  registerNode(nodeId: client.NodeID): boolean;
103
+ /**
104
+ * Remove a node's id from the set of active nodes.
105
+ * @param nodeId The node to be removed
106
+ */
107
+ removeNode(nodeId: client.NodeID): void;
94
108
  /**
95
109
  * Overwrites the current set of active nodes with the given one. A node represents
96
110
  * an active neighbor peer/client within the network, whom we are communicating with
@@ -60,6 +60,21 @@ export class Base extends EventEmitter {
60
60
  // and communication rounds.
61
61
  this.on('aggregation', () => this.nextRound());
62
62
  }
63
+ /**
64
+ * Evaluates whether a given participant contribution can be used in the current aggregation round
65
+ * the boolean returned by `this.add` is obtained via `this.isValidContribution`
66
+ */
67
+ isValidContribution(nodeId, round) {
68
+ if (!this.nodes.has(nodeId)) {
69
+ debug("Contribution rejected because node id is not registered");
70
+ return false;
71
+ }
72
+ if (!this.isWithinRoundCutoff(round)) {
73
+ debug(`Contribution rejected because round ${round} is not within round cutoff`);
74
+ return false;
75
+ }
76
+ return true;
77
+ }
63
78
  /**
64
79
  * Returns whether the given round is recent enough, dependent on the
65
80
  * aggregator's round cutoff.
@@ -77,16 +92,16 @@ export class Base extends EventEmitter {
77
92
  log(step, from) {
78
93
  switch (step) {
79
94
  case AggregationStep.ADD:
80
- debug(`adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
95
+ debug(`Adding contribution from node ${from ?? '"unknown"'} for aggregation round ${this.round} and communication round ${this.communicationRound}`);
81
96
  break;
82
97
  case AggregationStep.UPDATE:
83
98
  if (from === undefined) {
84
99
  return;
85
100
  }
86
- debug(`updating contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
101
+ debug(`Updating contribution from node ${from} for aggregation round ${this.round} and communication round ${this.communicationRound}`);
87
102
  break;
88
103
  case AggregationStep.AGGREGATE:
89
- debug(`buffer full, aggregating weights for round (${this.communicationRound}, ${this.round})`);
104
+ debug(`Buffer is full. Aggregating weights for round aggregation round ${this.round} and communication round ${this.communicationRound}`);
90
105
  break;
91
106
  default: {
92
107
  const _ = step;
@@ -108,6 +123,13 @@ export class Base extends EventEmitter {
108
123
  }
109
124
  return false;
110
125
  }
126
+ /**
127
+ * Remove a node's id from the set of active nodes.
128
+ * @param nodeId The node to be removed
129
+ */
130
+ removeNode(nodeId) {
131
+ this._nodes = this._nodes.delete(nodeId);
132
+ }
111
133
  /**
112
134
  * Overwrites the current set of active nodes with the given one. A node represents
113
135
  * an active neighbor peer/client within the network, whom we are communicating with
@@ -30,6 +30,7 @@ export declare class MeanAggregator extends Aggregator<WeightsContainer> {
30
30
  constructor(roundCutoff?: number, threshold?: number, thresholdType?: ThresholdType);
31
31
  /** Checks whether the contributions buffer is full. */
32
32
  isFull(): boolean;
33
+ set minNbOfParticipants(minNbOfParticipants: number);
33
34
  add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
34
35
  aggregate(): void;
35
36
  makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
@@ -9,6 +9,7 @@ const debug = createDebug("discojs:aggregator:mean");
9
9
  export class MeanAggregator extends Aggregator {
10
10
  #threshold;
11
11
  #thresholdType;
12
+ #minNbOfParticipants;
12
13
  /**
13
14
  * Create a mean aggregator that averages all weight updates received when a specified threshold is met.
14
15
  * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that
@@ -65,21 +66,24 @@ export class MeanAggregator extends Aggregator {
65
66
  }
66
67
  /** Checks whether the contributions buffer is full. */
67
68
  isFull() {
69
+ // Make sure that we are over the minimum number of participants
70
+ // if specified
71
+ if (this.#minNbOfParticipants !== undefined &&
72
+ this.nodes.size < this.#minNbOfParticipants)
73
+ return false;
68
74
  const thresholdValue = this.#thresholdType == 'relative'
69
75
  ? this.#threshold * this.nodes.size
70
76
  : this.#threshold;
71
77
  return (this.contributions.get(0)?.size ?? 0) >= thresholdValue;
72
78
  }
79
+ set minNbOfParticipants(minNbOfParticipants) {
80
+ this.#minNbOfParticipants = minNbOfParticipants;
81
+ }
73
82
  add(nodeId, contribution, round, currentContributions = 0) {
74
83
  if (currentContributions !== 0)
75
84
  throw new Error("only a single communication round");
76
- if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) {
77
- if (!this.nodes.has(nodeId))
78
- debug(`contribution rejected because node ${nodeId} is not registered`);
79
- if (!this.isWithinRoundCutoff(round))
80
- debug(`contribution rejected because round ${round} is not within cutoff`);
85
+ if (!this.isValidContribution(nodeId, round))
81
86
  return false;
82
- }
83
87
  this.log(this.contributions.hasIn([0, nodeId])
84
88
  ? AggregationStep.UPDATE
85
89
  : AggregationStep.ADD, nodeId);
@@ -94,6 +98,7 @@ export class MeanAggregator extends Aggregator {
94
98
  throw new Error("aggregating without any contribution");
95
99
  this.log(AggregationStep.AGGREGATE);
96
100
  const result = aggregation.avg(currentContributions.values());
101
+ // Emitting the event runs the superclass' callback to increment the round
97
102
  this.emit('aggregation', result);
98
103
  }
99
104
  makePayloads(weights) {
@@ -48,7 +48,7 @@ export class SecureAggregator extends Aggregator {
48
48
  default:
49
49
  throw new Error("requires communication round to be 0 or 1");
50
50
  }
51
- if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
51
+ if (!this.isValidContribution(nodeId, round))
52
52
  return false;
53
53
  this.log(this.contributions.hasIn([communicationRound, nodeId])
54
54
  ? AggregationStep.UPDATE
@@ -1,23 +1,17 @@
1
- import type { Model, Task, WeightsContainer } from '../index.js';
1
+ import type { Model, Task, WeightsContainer, RoundStatus } from '../index.js';
2
2
  import type { NodeID } from './types.js';
3
3
  import type { EventConnection } from './event_connection.js';
4
4
  import type { Aggregator } from '../aggregator/index.js';
5
+ import { EventEmitter } from '../utils/event_emitter.js';
5
6
  /**
6
7
  * Main, abstract, class representing a Disco client in a network, which handles
7
8
  * communication with other nodes, be it peers or a server.
8
9
  */
9
- export declare abstract class Base {
10
- /**
11
- * The network server's URL to connect to.
12
- */
10
+ export declare abstract class Client extends EventEmitter<{
11
+ 'status': RoundStatus;
12
+ }> {
13
13
  readonly url: URL;
14
- /**
15
- * The client's corresponding task.
16
- */
17
14
  readonly task: Task;
18
- /**
19
- * The client's aggregator.
20
- */
21
15
  readonly aggregator: Aggregator;
22
16
  /**
23
17
  * Own ID provided by the network's server.
@@ -31,23 +25,15 @@ export declare abstract class Base {
31
25
  * The aggregator's result produced after aggregation.
32
26
  */
33
27
  protected aggregationResult?: Promise<WeightsContainer>;
34
- constructor(
35
- /**
36
- * The network server's URL to connect to.
37
- */
38
- url: URL,
39
- /**
40
- * The client's corresponding task.
41
- */
42
- task: Task,
43
- /**
44
- * The client's aggregator.
45
- */
28
+ constructor(url: URL, // The network server's URL to connect to
29
+ task: Task, // The client's corresponding task
46
30
  aggregator: Aggregator);
47
31
  /**
48
32
  * Handles the connection process from the client to any sort of network server.
33
+ * This method is overriden by the federated and decentralized clients
34
+ * By default, it fetches and returns the server's base model
49
35
  */
50
- connect(): Promise<void>;
36
+ connect(): Promise<Model>;
51
37
  /**
52
38
  * Handles the disconnection process of the client from any sort of network server.
53
39
  */
@@ -59,17 +45,14 @@ export declare abstract class Base {
59
45
  getLatestModel(): Promise<Model>;
60
46
  /**
61
47
  * Communication callback called at the beginning of every training round.
62
- * @param _weights The most recent local weight updates
63
- * @param _round The current training round
64
48
  */
65
- onRoundBeginCommunication(_weights: WeightsContainer, _round: number): Promise<void>;
49
+ abstract onRoundBeginCommunication(): Promise<void>;
66
50
  /**
67
51
  * Communication callback called the end of every training round.
68
- * @param _weights The most recent local weight updates
69
- * @param _round The current training round
52
+ * @param weights The local weight update resulting for the current local training round
70
53
  * @returns aggregated weights or the local weights upon error
71
54
  */
72
- abstract onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<WeightsContainer>;
55
+ abstract onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
73
56
  get nbOfParticipants(): number;
74
57
  get ownId(): NodeID;
75
58
  get server(): EventConnection;
@@ -1,10 +1,11 @@
1
1
  import axios from 'axios';
2
2
  import { serialization } from '../index.js';
3
+ import { EventEmitter } from '../utils/event_emitter.js';
3
4
  /**
4
5
  * Main, abstract, class representing a Disco client in a network, which handles
5
6
  * communication with other nodes, be it peers or a server.
6
7
  */
7
- export class Base {
8
+ export class Client extends EventEmitter {
8
9
  url;
9
10
  task;
10
11
  aggregator;
@@ -20,27 +21,22 @@ export class Base {
20
21
  * The aggregator's result produced after aggregation.
21
22
  */
22
23
  aggregationResult;
23
- constructor(
24
- /**
25
- * The network server's URL to connect to.
26
- */
27
- url,
28
- /**
29
- * The client's corresponding task.
30
- */
31
- task,
32
- /**
33
- * The client's aggregator.
34
- */
24
+ constructor(url, // The network server's URL to connect to
25
+ task, // The client's corresponding task
35
26
  aggregator) {
27
+ super();
36
28
  this.url = url;
37
29
  this.task = task;
38
30
  this.aggregator = aggregator;
39
31
  }
40
32
  /**
41
33
  * Handles the connection process from the client to any sort of network server.
34
+ * This method is overriden by the federated and decentralized clients
35
+ * By default, it fetches and returns the server's base model
42
36
  */
43
- async connect() { }
37
+ async connect() {
38
+ return this.getLatestModel();
39
+ }
44
40
  /**
45
41
  * Handles the disconnection process of the client from any sort of network server.
46
42
  */
@@ -58,12 +54,6 @@ export class Base {
58
54
  const response = await axios.get(url.href, { responseType: 'arraybuffer' });
59
55
  return await serialization.model.decode(new Uint8Array(response.data));
60
56
  }
61
- /**
62
- * Communication callback called at the beginning of every training round.
63
- * @param _weights The most recent local weight updates
64
- * @param _round The current training round
65
- */
66
- async onRoundBeginCommunication(_weights, _round) { }
67
57
  // Number of contributors to a collaborative session
68
58
  // If decentralized, it should be the number of peers
69
59
  // If federated, it should the number of participants excluding the server
@@ -1,4 +1,4 @@
1
- import type { WeightsContainer } from "../../index.js";
1
+ import type { Model, WeightsContainer } from "../../index.js";
2
2
  import { Client } from '../index.js';
3
3
  /**
4
4
  * Represents a decentralized client in a network of peers. Peers coordinate each other with the
@@ -6,7 +6,7 @@ import { Client } from '../index.js';
6
6
  * with the server is based off regular WebSockets, whereas peer-to-peer communication uses
7
7
  * WebRTC for Node.js.
8
8
  */
9
- export declare class Base extends Client {
9
+ export declare class DecentralizedClient extends Client {
10
10
  /**
11
11
  * The pool of peers to communicate with during the current training round.
12
12
  */
@@ -21,7 +21,7 @@ export declare class Base extends Client {
21
21
  * create peer-to-peer WebRTC connections with peers. The server is used to exchange
22
22
  * peers network information.
23
23
  */
24
- connect(): Promise<void>;
24
+ connect(): Promise<Model>;
25
25
  /**
26
26
  * Create a WebSocket connection with the server
27
27
  * The client then waits for the server to forward it other client's network information.
@@ -37,7 +37,7 @@ export declare class Base extends Client {
37
37
  * and waits for it to resolve.
38
38
  *
39
39
  */
40
- onRoundBeginCommunication(_: WeightsContainer, round: number): Promise<void>;
40
+ onRoundBeginCommunication(): Promise<void>;
41
41
  /**
42
42
  * At each communication rounds, awaits peers contributions and add them to the client's aggregator.
43
43
  * This method is used as callback by getPeers when connecting to the rounds' peers
@@ -45,5 +45,5 @@ export declare class Base extends Client {
45
45
  * @param round
46
46
  */
47
47
  private receivePayloads;
48
- onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
48
+ onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
49
49
  }
@@ -14,7 +14,7 @@ const debug = createDebug("discojs:client:decentralized");
14
14
  * with the server is based off regular WebSockets, whereas peer-to-peer communication uses
15
15
  * WebRTC for Node.js.
16
16
  */
17
- export class Base extends Client {
17
+ export class DecentralizedClient extends Client {
18
18
  /**
19
19
  * The pool of peers to communicate with during the current training round.
20
20
  */
@@ -33,6 +33,7 @@ export class Base extends Client {
33
33
  * peers network information.
34
34
  */
35
35
  async connect() {
36
+ const model = await super.connect(); // Get the server base model
36
37
  const serverURL = new URL('', this.url.href);
37
38
  switch (this.url.protocol) {
38
39
  case 'http:':
@@ -44,19 +45,20 @@ export class Base extends Client {
44
45
  default:
45
46
  throw new Error(`unknown protocol: ${this.url.protocol}`);
46
47
  }
47
- serverURL.pathname += `deai/${this.task.id}`;
48
+ serverURL.pathname += `decentralized/${this.task.id}`;
48
49
  this._server = await this.connectServer(serverURL);
49
50
  const msg = {
50
51
  type: type.ClientConnected
51
52
  };
52
53
  this.server.send(msg);
53
- const peerIdMsg = await waitMessage(this.server, type.AssignNodeID);
54
- debug(`[${peerIdMsg.id}] assigned id generated by server`);
54
+ const { id } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
55
+ debug(`[${id}] assigned id generated by server`);
55
56
  if (this._ownId !== undefined) {
56
57
  throw new Error('received id from server but was already received');
57
58
  }
58
- this._ownId = peerIdMsg.id;
59
- this.pool = new PeerPool(peerIdMsg.id);
59
+ this._ownId = id;
60
+ this.pool = new PeerPool(id);
61
+ return model;
60
62
  }
61
63
  /**
62
64
  * Create a WebSocket connection with the server
@@ -96,21 +98,22 @@ export class Base extends Client {
96
98
  * and waits for it to resolve.
97
99
  *
98
100
  */
99
- async onRoundBeginCommunication(_, round) {
101
+ async onRoundBeginCommunication() {
100
102
  if (this.server === undefined) {
101
103
  throw new Error("peer's server is undefined, make sure to call `client.connect()` first");
102
104
  }
103
105
  if (this.pool === undefined) {
104
106
  throw new Error('peer pool is undefined, make sure to call `client.connect()` first');
105
107
  }
108
+ this.emit("status", "Retrieving peers' information");
106
109
  // Reset peers list at each round of training to make sure client works with an updated peers
107
110
  // list, maintained by the server. Adds any received weights to the aggregator.
108
- // this.connections = await this.waitForPeers(round)
109
111
  // Tell the server we are ready for the next round
110
112
  const readyMessage = { type: type.PeerIsReady };
111
113
  this.server.send(readyMessage);
112
114
  // Wait for the server to answer with the list of peers for the round
113
115
  try {
116
+ debug(`[${this.ownId}] is waiting for peer list for round ${this.aggregator.round}`);
114
117
  const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound, undefined, "Timeout waiting for the round's peer list");
115
118
  const peers = Set(receivedMessage.peers);
116
119
  if (this.ownId !== undefined && peers.has(this.ownId)) {
@@ -124,18 +127,19 @@ export class Base extends Client {
124
127
  // Init receipt of peers weights
125
128
  // this awaits the peer's weight update and adds it to
126
129
  // our aggregator upon reception
127
- (conn) => { this.receivePayloads(conn, round); });
128
- debug(`[${this.ownId}] received peers for round ${round}: %o`, connections.keySeq().toJS());
130
+ (conn) => { this.receivePayloads(conn, this.aggregator.round); });
131
+ debug(`[${this.ownId}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
129
132
  this.connections = connections;
130
133
  }
131
134
  catch (e) {
132
- debug(`[${this.ownId}] while beginning round: %o`, e);
135
+ debug(`Error for [${this.ownId}] while beginning round: %o`, e);
133
136
  this.aggregator.setNodes(Set(this.ownId));
134
137
  this.connections = Map();
135
138
  }
136
139
  // Store the promise for the current round's aggregation result.
137
140
  // We will await for it to resolve at the end of the round when exchanging weight updates.
138
141
  this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
142
+ this.emit("status", "Training the model on the data you connected");
139
143
  }
140
144
  /**
141
145
  * At each communication rounds, awaits peers contributions and add them to the client's aggregator.
@@ -156,18 +160,18 @@ export class Base extends Client {
156
160
  }
157
161
  }
158
162
  catch (e) {
159
- if (this.isDisconnected) {
163
+ if (this.isDisconnected)
160
164
  return;
161
- }
162
- debug(`[${this.ownId}] while receiving payloads: %o`, e);
165
+ debug(`Error for [${this.ownId}] while receiving payloads: %o`, e);
163
166
  }
164
167
  } while (++currentCommunicationRounds < this.aggregator.communicationRounds);
165
168
  });
166
169
  }
167
- async onRoundEndCommunication(weights, round) {
170
+ async onRoundEndCommunication(weights) {
168
171
  if (this.aggregationResult === undefined) {
169
172
  throw new TypeError('aggregation result promise is undefined');
170
173
  }
174
+ this.emit("status", "Updating the model with other participants' models");
171
175
  // Perform the required communication rounds. Each communication round consists in sending our local payload,
172
176
  // followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator.
173
177
  // A communication round's payload is the aggregation result of the previous communication round. The first
@@ -181,7 +185,7 @@ export class Base extends Client {
181
185
  try {
182
186
  await Promise.all(payloads.map(async (payload, id) => {
183
187
  if (id === this.ownId) {
184
- this.aggregator.add(this.ownId, payload, round, r);
188
+ this.aggregator.add(this.ownId, payload, this.aggregator.round, r);
185
189
  }
186
190
  else {
187
191
  const peer = this.connections?.get(id);
@@ -1,2 +1,2 @@
1
- export { Base as DecentralizedClient } from './base.js';
1
+ export { DecentralizedClient } from './decentralized_client.js';
2
2
  export * as messages from './messages.js';
@@ -1,2 +1,2 @@
1
- export { Base as DecentralizedClient } from './base.js';
1
+ export { DecentralizedClient } from './decentralized_client.js';
2
2
  export * as messages from './messages.js';
@@ -1,7 +1,12 @@
1
1
  import { weights } from '../../serialization/index.js';
2
2
  import { type SignalData } from './peer.js';
3
3
  import { type NodeID } from '../types.js';
4
- import { type, type ClientConnected, type AssignNodeID } from '../messages.js';
4
+ import { type, type ClientConnected } from '../messages.js';
5
+ export interface NewDecentralizedNodeInfo {
6
+ type: type.NewDecentralizedNodeInfo;
7
+ id: NodeID;
8
+ waitForMoreParticipants: boolean;
9
+ }
5
10
  export interface SignalForPeer {
6
11
  type: type.SignalForPeer;
7
12
  peer: NodeID;
@@ -20,7 +25,7 @@ export interface Payload {
20
25
  round: number;
21
26
  payload: weights.Encoded;
22
27
  }
23
- export type MessageFromServer = AssignNodeID | SignalForPeer | PeersForRound;
28
+ export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | PeersForRound;
24
29
  export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady;
25
30
  export type PeerMessage = Payload;
26
31
  export declare function isMessageFromServer(o: unknown): o is MessageFromServer;
@@ -6,8 +6,10 @@ export function isMessageFromServer(o) {
6
6
  return false;
7
7
  }
8
8
  switch (o.type) {
9
- case type.AssignNodeID:
10
- return 'id' in o && isNodeID(o.id);
9
+ case type.NewDecentralizedNodeInfo:
10
+ return 'id' in o && isNodeID(o.id) &&
11
+ 'waitForMoreParticipants' in o &&
12
+ typeof o.waitForMoreParticipants === 'boolean';
11
13
  case type.SignalForPeer:
12
14
  return 'peer' in o && isNodeID(o.peer) &&
13
15
  'signal' in o; // TODO check signal content?
@@ -95,8 +95,8 @@ export class WebSocketServer extends EventEmitter {
95
95
  }
96
96
  disconnect() {
97
97
  return new Promise((resolve, reject) => {
98
- this.socket.once('close', resolve);
99
- this.socket.once('error', reject);
98
+ this.socket.onclose = () => resolve();
99
+ this.socket.onerror = (e) => reject(e.message);
100
100
  this.socket.close();
101
101
  });
102
102
  }
@@ -0,0 +1,44 @@
1
+ import type { Model, WeightsContainer } from "../../index.js";
2
+ import { Client } from "../client.js";
3
+ /**
4
+ * Client class that communicates with a centralized, federated server, when training
5
+ * a specific task in the federated setting.
6
+ */
7
+ export declare class FederatedClient extends Client {
8
+ #private;
9
+ get nbOfParticipants(): number;
10
+ /**
11
+ * Opens a new WebSocket connection with the server and listens to new messages over the channel
12
+ */
13
+ private connectServer;
14
+ /**
15
+ * Initializes the connection to the server, gets our node ID
16
+ * as well as the latest training information: latest global model, current round and
17
+ * whether we are waiting for more participants.
18
+ */
19
+ connect(): Promise<Model>;
20
+ /**
21
+ * Method called when the server notifies the client that there aren't enough
22
+ * participants (anymore) to start/continue training
23
+ * The method creates a promise that will resolve once the server notifies
24
+ * the client that the training can resume via a subsequent EnoughParticipants message
25
+ * @returns a promise which resolves when enough participants joined the session
26
+ */
27
+ private waitForMoreParticipants;
28
+ /**
29
+ * Disconnection process when user quits the task.
30
+ */
31
+ disconnect(): Promise<void>;
32
+ onRoundBeginCommunication(): Promise<void>;
33
+ /**
34
+ * Send the local weight update to the server and waits (indefinitely) for the server global update
35
+ *
36
+ * If the waitingForMoreParticipants flag is set, we first wait (also indefinitely) until the
37
+ * server notifies us that the training can resume.
38
+ *
39
+ // NB: For now, we suppose a fully-federated setting.
40
+ * @param weights Local weights sent to the server at the end of the local training round
41
+ * @returns the new global weights sent by the server
42
+ */
43
+ onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
44
+ }