@epfml/discojs 3.0.1-p20241007204240.0 → 3.0.1-p20241014092014.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
  2. package/dist/aggregator/{base.js → aggregator.js} +48 -36
  3. package/dist/aggregator/get.d.ts +2 -2
  4. package/dist/aggregator/get.js +4 -4
  5. package/dist/aggregator/index.d.ts +1 -4
  6. package/dist/aggregator/index.js +1 -1
  7. package/dist/aggregator/mean.d.ts +4 -4
  8. package/dist/aggregator/mean.js +5 -15
  9. package/dist/aggregator/secure.d.ts +4 -4
  10. package/dist/aggregator/secure.js +7 -17
  11. package/dist/client/client.d.ts +71 -17
  12. package/dist/client/client.js +115 -14
  13. package/dist/client/decentralized/decentralized_client.d.ts +11 -13
  14. package/dist/client/decentralized/decentralized_client.js +121 -84
  15. package/dist/client/decentralized/messages.d.ts +10 -4
  16. package/dist/client/decentralized/messages.js +7 -6
  17. package/dist/client/federated/federated_client.d.ts +1 -13
  18. package/dist/client/federated/federated_client.js +15 -94
  19. package/dist/client/federated/messages.d.ts +2 -7
  20. package/dist/client/local_client.d.ts +1 -0
  21. package/dist/client/local_client.js +3 -0
  22. package/dist/client/messages.d.ts +14 -7
  23. package/dist/client/messages.js +13 -11
  24. package/dist/default_tasks/cifar10.js +1 -1
  25. package/dist/default_tasks/lus_covid.js +1 -0
  26. package/dist/default_tasks/mnist.js +1 -1
  27. package/dist/default_tasks/simple_face.js +1 -0
  28. package/dist/default_tasks/titanic.js +1 -0
  29. package/dist/default_tasks/wikitext.js +1 -0
  30. package/dist/task/training_information.d.ts +1 -2
  31. package/dist/task/training_information.js +6 -8
  32. package/dist/training/disco.d.ts +4 -1
  33. package/dist/training/trainer.js +1 -1
  34. package/dist/utils/event_emitter.d.ts +3 -3
  35. package/dist/utils/event_emitter.js +10 -9
  36. package/package.json +1 -1
@@ -1,5 +1,5 @@
1
1
  import { Map, Set } from 'immutable';
2
- import type { client } from '../index.js';
2
+ import type { client, WeightsContainer } from '../index.js';
3
3
  import { EventEmitter } from '../utils/event_emitter.js';
4
4
  export declare enum AggregationStep {
5
5
  ADD = 0,
@@ -10,11 +10,11 @@ export declare enum AggregationStep {
10
10
  * Main, abstract, aggregator class whose role is to buffer contributions and to produce
11
11
  * a result based off their aggregation, whenever some defined condition is met.
12
12
  *
13
- * Emits an event whenever an aggregation step is performed.
14
- * Users wait for this event to fetch the aggregation result.
13
+ * Emits an event carrying the round's aggregated weights whenever an aggregation step is performed.
14
+ * Users subscribe to this event to get the aggregation result.
15
15
  */
16
- export declare abstract class Base<T> extends EventEmitter<{
17
- 'aggregation': T;
16
+ export declare abstract class Aggregator extends EventEmitter<{
17
+ 'aggregation': WeightsContainer;
18
18
  }> {
19
19
  /**
20
20
  * The round cut-off for contributions.
@@ -34,7 +34,7 @@ export declare abstract class Base<T> extends EventEmitter<{
34
34
  * It defines the effective aggregation group, which is possibly a subset
35
35
  * of all active nodes, depending on the aggregation scheme.
36
36
  */
37
- protected contributions: Map<number, Map<client.NodeID, T>>;
37
+ protected contributions: Map<number, Map<client.NodeID, WeightsContainer>>;
38
38
  /**
39
39
  * The current aggregation round, used for assessing whether a node contribution is recent enough
40
40
  * or not.
@@ -56,36 +56,45 @@ export declare abstract class Base<T> extends EventEmitter<{
56
56
  * The number of communication rounds occurring during any given aggregation round.
57
57
  */
58
58
  communicationRounds?: number);
59
+ /**
60
+ * Convenience method to subscribe to the 'aggregation' event.
61
+ * Awaiting this promise returns the aggregated weights for the current round.
62
+ *
63
+ * @returns a promise for the aggregated weights
64
+ */
65
+ getPromiseForAggregation(): Promise<WeightsContainer>;
59
66
  /**
60
67
  * Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
61
68
  * The aggregation round is increased whenever a new global model is obtained and local models are updated.
62
69
  * Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation
63
- * which requires multiple steps to obtain a global model)
64
- * The contribution will be aggregated during the next aggregation step.
70
+ * which requires multiple steps to obtain a global model)
71
+ * The contribution is aggregated during the next aggregation step.
72
+ *
65
73
  * @param nodeId The node's id
66
74
  * @param contribution The node's contribution
67
- * @param round aggregation round of the contribution was made
68
- * @param communicationRound communication round the contribution was made within the aggregation round
69
- * @returns boolean, true if the contribution has been successfully taken into account or False if it has been rejected
70
75
  */
71
- abstract add(nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean;
76
+ add(nodeId: client.NodeID, contribution: WeightsContainer, aggregationRound: number, communicationRound?: number): void;
77
+ protected abstract _add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound?: number): void;
72
78
  /**
73
79
  * Evaluates whether a given participant contribution can be used in the current aggregation round
74
80
  * the boolean returned by `this.add` is obtained via `this.isValidContribution`
81
+ *
82
+ * @param nodeId the node id of the contribution to be added
83
+ * @param round the aggregation round of the contribution to be added
75
84
  */
76
85
  isValidContribution(nodeId: client.NodeID, round: number): boolean;
77
86
  /**
78
87
  * Performs an aggregation step over the received node contributions.
79
88
  * Must store the aggregation's result in the aggregator's result promise.
80
89
  */
81
- abstract aggregate(): void;
90
+ protected abstract aggregate(): WeightsContainer;
82
91
  /**
83
92
  * Returns whether the given round is recent enough, dependent on the
84
93
  * aggregator's round cutoff.
85
94
  * @param round The round
86
95
  * @returns True if the round is recent enough, false otherwise
87
96
  */
88
- isWithinRoundCutoff(round: number): boolean;
97
+ private isWithinRoundCutoff;
89
98
  /**
90
99
  * Logs useful messages during the various aggregation steps.
91
100
  * @param step The aggregation step
@@ -112,28 +121,17 @@ export declare abstract class Base<T> extends EventEmitter<{
112
121
  * @param nodeIds The new set of nodes
113
122
  */
114
123
  setNodes(nodeIds: Set<client.NodeID>): void;
115
- /**
116
- * Empties the current set of "nodes". Usually called at the end of an aggregation round,
117
- * if the set of nodes is meant to change or to be actualized.
118
- */
119
- resetNodes(): void;
120
124
  /**
121
125
  * Sets the aggregator's round number. To be used whenever the aggregator is out of sync
122
126
  * with the network's round.
123
127
  * @param round The new round
124
128
  */
125
129
  setRound(round: number): void;
126
- /**
127
- * Updates the aggregator's state to proceed to the next communication round.
128
- * If all communication rounds were performed, proceeds to the next aggregation round
129
- * and empties the collection of stored contributions.
130
- */
131
- nextRound(): void;
132
130
  /**
133
131
  * Constructs the payloads sent to other nodes as contribution.
134
132
  * @param base Object from which the payload is computed
135
133
  */
136
- abstract makePayloads(base: T): Map<client.NodeID, T>;
134
+ abstract makePayloads(base: WeightsContainer): Map<client.NodeID, WeightsContainer>;
137
135
  abstract isFull(): boolean;
138
136
  /**
139
137
  * The set of node ids, representing our neighbors within the network.
@@ -143,11 +141,6 @@ export declare abstract class Base<T> extends EventEmitter<{
143
141
  * The aggregation round.
144
142
  */
145
143
  get round(): number;
146
- /**
147
- * The aggregator's current size, defined by its number of contributions. The size is bounded by
148
- * the amount of all active nodes times the number of communication rounds.
149
- */
150
- get size(): number;
151
144
  /**
152
145
  * The current communication round.
153
146
  */
@@ -12,10 +12,10 @@ export var AggregationStep;
12
12
  * Main, abstract, aggregator class whose role is to buffer contributions and to produce
13
13
  * a result based off their aggregation, whenever some defined condition is met.
14
14
  *
15
- * Emits an event whenever an aggregation step is performed.
16
- * Users wait for this event to fetch the aggregation result.
15
+ * Emits an event carrying the round's aggregated weights whenever an aggregation step is performed.
16
+ * Users subscribe to this event to get the aggregation result.
17
17
  */
18
- export class Base extends EventEmitter {
18
+ export class Aggregator extends EventEmitter {
19
19
  roundCutoff;
20
20
  communicationRounds;
21
21
  /**
@@ -28,7 +28,7 @@ export class Base extends EventEmitter {
28
28
  * It defines the effective aggregation group, which is possibly a subset
29
29
  * of all active nodes, depending on the aggregation scheme.
30
30
  */
31
- // communication round -> NodeID -> T
31
+ // communication round -> NodeID -> WeightsContainer
32
32
  contributions;
33
33
  /**
34
34
  * The current aggregation round, used for assessing whether a node contribution is recent enough
@@ -56,13 +56,54 @@ export class Base extends EventEmitter {
56
56
  this.communicationRounds = communicationRounds;
57
57
  this.contributions = Map();
58
58
  this._nodes = Set();
59
- // On every aggregation, update the object's state to match the current aggregation
60
- // and communication rounds.
61
- this.on('aggregation', () => this.nextRound());
59
+ }
60
+ /**
61
+ * Convenience method to subscribe to the 'aggregation' event.
62
+ * Awaiting this promise returns the aggregated weights for the current round.
63
+ *
64
+ * @returns a promise for the aggregated weights
65
+ */
66
+ getPromiseForAggregation() {
67
+ return new Promise((resolve) => this.once('aggregation', resolve));
68
+ }
69
+ /**
70
+ * Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
71
+ * The aggregation round is increased whenever a new global model is obtained and local models are updated.
72
+ * Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation
73
+ * which requires multiple steps to obtain a global model)
74
+ * The contribution is aggregated during the next aggregation step.
75
+ *
76
+ * @param nodeId The node's id
77
+ * @param contribution The node's contribution
78
+ */
79
+ add(nodeId, contribution, aggregationRound, communicationRound) {
80
+ if (!this.isValidContribution(nodeId, aggregationRound))
81
+ throw new Error("Tried adding an invalid contribution. Handle this case before calling add.");
82
+ // call the abstract method _add, implemented by subclasses
83
+ this._add(nodeId, contribution, communicationRound);
84
+ // If the aggregator has enough contributions then aggregate the weights
85
+ // and emit the 'aggregation' event
86
+ if (this.isFull()) {
87
+ const aggregatedWeights = this.aggregate();
88
+ // On each aggregation, increment the communication round
89
+ // If all communication rounds were performed, proceed to the next aggregation round
90
+ // and empty the past contributions.
91
+ this._communicationRound++;
92
+ if (this.communicationRound === this.communicationRounds) {
93
+ this._communicationRound = 0;
94
+ this._round++;
95
+ this.contributions = Map();
96
+ }
97
+ // Emitting the 'aggregation' communicates the weights to subscribers
98
+ this.emit('aggregation', aggregatedWeights);
99
+ }
62
100
  }
63
101
  /**
64
102
  * Evaluates whether a given participant contribution can be used in the current aggregation round
65
103
  * the boolean returned by `this.add` is obtained via `this.isValidContribution`
104
+ *
105
+ * @param nodeId the node id of the contribution to be added
106
+ * @param round the aggregation round of the contribution to be added
66
107
  */
67
108
  isValidContribution(nodeId, round) {
68
109
  if (!this.nodes.has(nodeId)) {
@@ -139,13 +180,6 @@ export class Base extends EventEmitter {
139
180
  setNodes(nodeIds) {
140
181
  this._nodes = nodeIds;
141
182
  }
142
- /**
143
- * Empties the current set of "nodes". Usually called at the end of an aggregation round,
144
- * if the set of nodes is meant to change or to be actualized.
145
- */
146
- resetNodes() {
147
- this._nodes = Set();
148
- }
149
183
  /**
150
184
  * Sets the aggregator's round number. To be used whenever the aggregator is out of sync
151
185
  * with the network's round.
@@ -156,18 +190,6 @@ export class Base extends EventEmitter {
156
190
  this._round = round;
157
191
  }
158
192
  }
159
- /**
160
- * Updates the aggregator's state to proceed to the next communication round.
161
- * If all communication rounds were performed, proceeds to the next aggregation round
162
- * and empties the collection of stored contributions.
163
- */
164
- nextRound() {
165
- if (++this._communicationRound === this.communicationRounds) {
166
- this._communicationRound = 0;
167
- this._round++;
168
- this.contributions = Map();
169
- }
170
- }
171
193
  /**
172
194
  * The set of node ids, representing our neighbors within the network.
173
195
  */
@@ -180,16 +202,6 @@ export class Base extends EventEmitter {
180
202
  get round() {
181
203
  return this._round;
182
204
  }
183
- /**
184
- * The aggregator's current size, defined by its number of contributions. The size is bounded by
185
- * the amount of all active nodes times the number of communication rounds.
186
- */
187
- get size() {
188
- return this.contributions
189
- .valueSeq()
190
- .map((m) => m.size)
191
- .reduce((totalSize, size) => totalSize + size) ?? 0;
192
- }
193
205
  /**
194
206
  * The current communication round.
195
207
  */
@@ -9,9 +9,9 @@ type AggregatorOptions = Partial<{
9
9
  /**
10
10
  * Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
11
11
  * Here is the ordered list of parameters used to define the aggregator and its default behavior:
12
- * task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
12
+ * task.trainingInformation.aggregationStrategy > options.scheme > task.trainingInformation.scheme
13
13
  *
14
- * If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
14
+ * If `task.trainingInformation.aggregationStrategy` is defined, we initialize the chosen aggregator with `options` parameter values.
15
15
  * Otherwise, we default to a MeanAggregator for both training schemes.
16
16
  *
17
17
  * For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
@@ -2,9 +2,9 @@ import { aggregator } from '../index.js';
2
2
  /**
3
3
  * Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
4
4
  * Here is the ordered list of parameters used to define the aggregator and its default behavior:
5
- * task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
5
+ * task.trainingInformation.aggregationStrategy > options.scheme > task.trainingInformation.scheme
6
6
  *
7
- * If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
7
+ * If `task.trainingInformation.aggregationStrategy` is defined, we initialize the chosen aggregator with `options` parameter values.
8
8
  * Otherwise, we default to a MeanAggregator for both training schemes.
9
9
  *
10
10
  * For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
@@ -18,9 +18,9 @@ import { aggregator } from '../index.js';
18
18
  * @returns The aggregator
19
19
  */
20
20
  export function getAggregator(task, options = {}) {
21
- const aggregatorType = task.trainingInformation.aggregator ?? 'mean';
21
+ const aggregationStrategy = task.trainingInformation.aggregationStrategy ?? 'mean';
22
22
  const scheme = options.scheme ?? task.trainingInformation.scheme;
23
- switch (aggregatorType) {
23
+ switch (aggregationStrategy) {
24
24
  case 'mean':
25
25
  if (scheme === 'decentralized') {
26
26
  // If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
@@ -1,7 +1,4 @@
1
- import type { WeightsContainer } from '../weights/index.js';
2
- import type { Base } from './base.js';
3
- export { Base as AggregatorBase, AggregationStep } from './base.js';
1
+ export { Aggregator, AggregationStep } from './aggregator.js';
4
2
  export { MeanAggregator } from './mean.js';
5
3
  export { SecureAggregator } from './secure.js';
6
4
  export { getAggregator } from './get.js';
7
- export type Aggregator = Base<WeightsContainer>;
@@ -1,4 +1,4 @@
1
- export { Base as AggregatorBase, AggregationStep } from './base.js';
1
+ export { Aggregator, AggregationStep } from './aggregator.js';
2
2
  export { MeanAggregator } from './mean.js';
3
3
  export { SecureAggregator } from './secure.js';
4
4
  export { getAggregator } from './get.js';
@@ -1,12 +1,12 @@
1
1
  import type { Map } from "immutable";
2
- import { Base as Aggregator } from "./base.js";
2
+ import { Aggregator } from "./aggregator.js";
3
3
  import type { WeightsContainer, client } from "../index.js";
4
4
  type ThresholdType = 'relative' | 'absolute';
5
5
  /**
6
6
  * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
7
7
  *
8
8
  */
9
- export declare class MeanAggregator extends Aggregator<WeightsContainer> {
9
+ export declare class MeanAggregator extends Aggregator {
10
10
  #private;
11
11
  /**
12
12
  * Create a mean aggregator that averages all weight updates received when a specified threshold is met.
@@ -31,8 +31,8 @@ export declare class MeanAggregator extends Aggregator<WeightsContainer> {
31
31
  /** Checks whether the contributions buffer is full. */
32
32
  isFull(): boolean;
33
33
  set minNbOfParticipants(minNbOfParticipants: number);
34
- add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
35
- aggregate(): void;
34
+ _add(nodeId: client.NodeID, contribution: WeightsContainer): void;
35
+ aggregate(): WeightsContainer;
36
36
  makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
37
37
  }
38
38
  export {};
@@ -1,5 +1,5 @@
1
1
  import createDebug from "debug";
2
- import { AggregationStep, Base as Aggregator } from "./base.js";
2
+ import { AggregationStep, Aggregator } from "./aggregator.js";
3
3
  import { aggregation } from "../index.js";
4
4
  const debug = createDebug("discojs:aggregator:mean");
5
5
  /**
@@ -54,7 +54,7 @@ export class MeanAggregator extends Aggregator {
54
54
  else {
55
55
  // Print a warning regarding the default behavior when thresholdType is not specified
56
56
  if (thresholdType === undefined) {
57
- // TODO enforce validity by splitting features instead of warning
57
+ // TODO enforce validity by splitting the different threshold types into separate classes instead of warning
58
58
  debug("[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " +
59
59
  "To instead wait for a single contribution, set thresholdType = 'absolute'");
60
60
  this.#thresholdType = 'relative';
@@ -79,18 +79,9 @@ export class MeanAggregator extends Aggregator {
79
79
  set minNbOfParticipants(minNbOfParticipants) {
80
80
  this.#minNbOfParticipants = minNbOfParticipants;
81
81
  }
82
- add(nodeId, contribution, round, currentContributions = 0) {
83
- if (currentContributions !== 0)
84
- throw new Error("only a single communication round");
85
- if (!this.isValidContribution(nodeId, round))
86
- return false;
87
- this.log(this.contributions.hasIn([0, nodeId])
88
- ? AggregationStep.UPDATE
89
- : AggregationStep.ADD, nodeId);
82
+ _add(nodeId, contribution) {
83
+ this.log(this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, nodeId);
90
84
  this.contributions = this.contributions.setIn([0, nodeId], contribution);
91
- if (this.isFull())
92
- this.aggregate();
93
- return true;
94
85
  }
95
86
  aggregate() {
96
87
  const currentContributions = this.contributions.get(0);
@@ -98,8 +89,7 @@ export class MeanAggregator extends Aggregator {
98
89
  throw new Error("aggregating without any contribution");
99
90
  this.log(AggregationStep.AGGREGATE);
100
91
  const result = aggregation.avg(currentContributions.values());
101
- // Emitting the event runs the superclass' callback to increment the round
102
- this.emit('aggregation', result);
92
+ return result;
103
93
  }
104
94
  makePayloads(weights) {
105
95
  // Communicate our local weights to every other node, be it a peer or a server
@@ -1,5 +1,5 @@
1
1
  import { Map, List } from "immutable";
2
- import { Base as Aggregator } from "./base.js";
2
+ import { Aggregator } from "./aggregator.js";
3
3
  import type { WeightsContainer, client } from "../index.js";
4
4
  /**
5
5
  * Aggregator implementing secure multi-party computation for decentralized learning.
@@ -8,11 +8,11 @@ import type { WeightsContainer, client } from "../index.js";
8
8
  * - then, they sum their received shares and communicate the result.
9
9
  * Finally, nodes are able to average the received partial sums to establish the aggregation result.
10
10
  */
11
- export declare class SecureAggregator extends Aggregator<WeightsContainer> {
11
+ export declare class SecureAggregator extends Aggregator {
12
12
  private readonly maxShareValue;
13
13
  constructor(maxShareValue?: number);
14
- aggregate(): void;
15
- add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound?: number): boolean;
14
+ aggregate(): WeightsContainer;
15
+ _add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound: number): void;
16
16
  isFull(): boolean;
17
17
  makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
18
18
  /** Generate N additive shares that aggregate to the secret weights array, where N is the number of peers. */
@@ -1,6 +1,6 @@
1
1
  import { Map, List, Range } from "immutable";
2
2
  import * as tf from "@tensorflow/tfjs";
3
- import { AggregationStep, Base as Aggregator } from "./base.js";
3
+ import { AggregationStep, Aggregator } from "./aggregator.js";
4
4
  import { aggregation } from "../index.js";
5
5
  /**
6
6
  * Aggregator implementing secure multi-party computation for decentralized learning.
@@ -23,24 +23,20 @@ export class SecureAggregator extends Aggregator {
23
23
  const currentContributions = this.contributions.get(0);
24
24
  if (currentContributions === undefined)
25
25
  throw new Error("aggregating without any contribution");
26
- const result = aggregation.sum(currentContributions.values());
27
- this.emit('aggregation', result);
28
- break;
26
+ return aggregation.sum(currentContributions.values());
29
27
  }
30
28
  // Average the received partial sums
31
29
  case 1: {
32
30
  const currentContributions = this.contributions.get(1);
33
31
  if (currentContributions === undefined)
34
32
  throw new Error("aggregating without any contribution");
35
- const result = aggregation.avg(currentContributions.values());
36
- this.emit('aggregation', result);
37
- break;
33
+ return aggregation.avg(currentContributions.values());
38
34
  }
39
35
  default:
40
36
  throw new Error("communication round is out of bounds");
41
37
  }
42
38
  }
43
- add(nodeId, contribution, round, communicationRound) {
39
+ _add(nodeId, contribution, communicationRound) {
44
40
  switch (communicationRound) {
45
41
  case 0:
46
42
  case 1:
@@ -48,15 +44,9 @@ export class SecureAggregator extends Aggregator {
48
44
  default:
49
45
  throw new Error("requires communication round to be 0 or 1");
50
46
  }
51
- if (!this.isValidContribution(nodeId, round))
52
- return false;
53
- this.log(this.contributions.hasIn([communicationRound, nodeId])
54
- ? AggregationStep.UPDATE
55
- : AggregationStep.ADD, nodeId);
47
+ this.log(this.contributions.hasIn([communicationRound, nodeId]) ?
48
+ AggregationStep.UPDATE : AggregationStep.ADD, nodeId.slice(0, 4));
56
49
  this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
57
- if (this.isFull())
58
- this.aggregate();
59
- return true;
60
50
  }
61
51
  isFull() {
62
52
  return ((this.contributions.get(this.communicationRound)?.size ?? 0) ===
@@ -66,7 +56,7 @@ export class SecureAggregator extends Aggregator {
66
56
  switch (this.communicationRound) {
67
57
  case 0: {
68
58
  const shares = this.generateAllShares(weights);
69
- // Abitrarily assign our shares to the available nodes
59
+ // Arbitrarily assign our shares to the available nodes
70
60
  return Map(List(this.nodes).zip(shares));
71
61
  }
72
62
  // Send our partial sum to every other nodes
@@ -13,21 +13,35 @@ export declare abstract class Client extends EventEmitter<{
13
13
  readonly url: URL;
14
14
  readonly task: Task;
15
15
  readonly aggregator: Aggregator;
16
- /**
17
- * Own ID provided by the network's server.
18
- */
19
16
  protected _ownId?: NodeID;
17
+ protected _server?: EventConnection;
18
+ protected aggregationResult?: Promise<WeightsContainer>;
20
19
  /**
21
- * The network's server.
20
+ * When the server notifies clients to pause and wait until more
21
+ * participants join, we rely on this promise to wait
22
+ * until the server signals that the training can resume
22
23
  */
23
- protected _server?: EventConnection;
24
+ protected promiseForMoreParticipants: Promise<void> | undefined;
24
25
  /**
25
- * The aggregator's result produced after aggregation.
26
+ * When the server notifies the client that they can resume training
27
+ * after waiting for more participants, we want to be able to display what
28
+ * we were doing before waiting (training locally or updating our model).
29
+ * We use this attribute to store the status to rollback to when we stop waiting
26
30
  */
27
- protected aggregationResult?: Promise<WeightsContainer>;
31
+ private previousStatus;
28
32
  constructor(url: URL, // The network server's URL to connect to
29
33
  task: Task, // The client's corresponding task
30
34
  aggregator: Aggregator);
35
+ /**
36
+ * Communication callback called at the beginning of every training round.
37
+ */
38
+ abstract onRoundBeginCommunication(): Promise<void>;
39
+ /**
40
+ * Communication callback called at the end of every training round.
41
+ * @param weights The local weight update resulting from the current local training round
42
+ * @returns aggregated weights or the local weights upon error
43
+ */
44
+ abstract onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
31
45
  /**
32
46
  * Handles the connection process from the client to any sort of network server.
33
47
  * This method is overridden by the federated and decentralized clients
@@ -39,21 +53,61 @@ export declare abstract class Client extends EventEmitter<{
39
53
  */
40
54
  disconnect(): Promise<void>;
41
55
  /**
42
- * Fetches the latest model available on the network's server, for the adequate task.
43
- * @returns The latest model
56
+ * Emits the round status specified. It also stores the status emitted such that
57
+ * if the server tells the client to wait for more participants, it can display
58
+ * the waiting status and once enough participants join, it can display the previous status again
44
59
  */
45
- getLatestModel(): Promise<Model>;
60
+ protected saveAndEmit(status: RoundStatus): void;
46
61
  /**
47
- * Communication callback called at the beginning of every training round.
62
+ * For both federated and decentralized clients, we listen to the server to tell
63
+ * us whether there are enough participants to train. If not, we pause until further notice.
64
+ * When a client connects to the server, the server answers with the session information (id,
65
+ * number of participants) and whether there are enough participants.
66
+ * Once there are enough participants, the server sends a new EnoughParticipant message to update the client.
67
+ *
68
+ * `setMessageInversionFlag` is used to address the following scenario:
69
+ * 1. Client 1 connects to the server
70
+ * 2. Server answers with message A containing "not enough participants"
71
+ * 3. Before A arrives a new client joins. There are enough participants now.
72
+ * 4. Server updates client 1 with message B saying "there are enough participants"
73
+ * 5. Due to network and message sizes message B can arrive before A.
74
+ * i.e. "there are enough participants" arrives before "not enough participants"
75
+ * ending up with client 1 thinking it needs to wait for more participants.
76
+ *
77
+ * To keep track of this message inversion, `setMessageInversionFlag`
78
+ * tells us whether a message inversion occurred (by setting a boolean to true)
79
+ *
80
+ * @param setMessageInversionFlag function flagging whether a message inversion occurred
81
+ * between a NewNodeInfo message and an EnoughParticipant message.
48
82
  */
49
- abstract onRoundBeginCommunication(): Promise<void>;
83
+ protected setupServerCallbacks(setMessageInversionFlag: () => void): void;
50
84
  /**
51
- * Communication callback called the end of every training round.
52
- * @param weights The local weight update resulting for the current local training round
53
- * @returns aggregated weights or the local weights upon error
85
+ * Method called when the server notifies the client that there aren't enough
86
+ * participants (anymore) to start/continue training
87
+ * The method creates a promise that will resolve once the server notifies
88
+ * the client that the training can resume via a subsequent EnoughParticipants message
89
+ * @returns a promise which resolves when enough participants joined the session
54
90
  */
55
- abstract onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
56
- get nbOfParticipants(): number;
91
+ protected createPromiseForMoreParticipants(): Promise<void>;
92
+ protected waitForParticipantsIfNeeded(): Promise<void>;
93
+ /**
94
+ * Fetches the latest model available on the network's server, for the adequate task.
95
+ * @returns The latest model
96
+ */
97
+ getLatestModel(): Promise<Model>;
98
+ /**
99
+ * Number of contributors to a collaborative session
100
+ * If decentralized, it should be the number of peers
101
+ * If federated, it should be the number of participants excluding the server
102
+ * If local it should be 1
103
+ */
104
+ abstract getNbOfParticipants(): number;
57
105
  get ownId(): NodeID;
58
106
  get server(): EventConnection;
107
+ /**
108
+ * Whether the client should wait until more
109
+ * participants join the session, i.e. a promise has been created
110
+ */
111
+ get waitingForMoreParticipants(): boolean;
59
112
  }
113
+ export declare function shortenId(id: string): string;