@epfml/discojs 2.2.2-p20240703101552.0 → 3.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. package/dist/aggregator/base.d.ts +9 -48
  2. package/dist/aggregator/base.js +8 -69
  3. package/dist/aggregator/get.d.ts +23 -11
  4. package/dist/aggregator/get.js +40 -23
  5. package/dist/aggregator/index.d.ts +1 -1
  6. package/dist/aggregator/index.js +1 -1
  7. package/dist/aggregator/mean.d.ts +25 -6
  8. package/dist/aggregator/mean.js +62 -17
  9. package/dist/aggregator/secure.d.ts +2 -2
  10. package/dist/aggregator/secure.js +4 -7
  11. package/dist/client/base.d.ts +3 -3
  12. package/dist/client/base.js +6 -8
  13. package/dist/client/decentralized/base.d.ts +27 -10
  14. package/dist/client/decentralized/base.js +123 -86
  15. package/dist/client/decentralized/peer.js +7 -12
  16. package/dist/client/decentralized/peer_pool.js +6 -2
  17. package/dist/client/event_connection.d.ts +1 -1
  18. package/dist/client/event_connection.js +3 -3
  19. package/dist/client/federated/base.d.ts +5 -21
  20. package/dist/client/federated/base.js +38 -61
  21. package/dist/client/federated/messages.d.ts +2 -10
  22. package/dist/client/federated/messages.js +0 -1
  23. package/dist/client/index.d.ts +1 -1
  24. package/dist/client/index.js +1 -1
  25. package/dist/client/local.d.ts +3 -1
  26. package/dist/client/local.js +4 -1
  27. package/dist/client/messages.d.ts +1 -2
  28. package/dist/client/messages.js +8 -3
  29. package/dist/client/utils.d.ts +4 -2
  30. package/dist/client/utils.js +18 -3
  31. package/dist/dataset/data/data.d.ts +1 -1
  32. package/dist/dataset/data/data.js +13 -2
  33. package/dist/dataset/data/preprocessing/image_preprocessing.js +6 -4
  34. package/dist/default_tasks/cifar10.js +1 -2
  35. package/dist/default_tasks/lus_covid.js +0 -5
  36. package/dist/default_tasks/mnist.js +15 -14
  37. package/dist/default_tasks/simple_face.js +0 -2
  38. package/dist/default_tasks/titanic.js +2 -4
  39. package/dist/default_tasks/wikitext.js +7 -1
  40. package/dist/index.d.ts +0 -1
  41. package/dist/index.js +0 -1
  42. package/dist/models/gpt/config.js +1 -1
  43. package/dist/privacy.d.ts +8 -10
  44. package/dist/privacy.js +25 -40
  45. package/dist/task/task_handler.js +10 -2
  46. package/dist/task/training_information.d.ts +7 -4
  47. package/dist/task/training_information.js +25 -6
  48. package/dist/training/disco.d.ts +30 -28
  49. package/dist/training/disco.js +75 -73
  50. package/dist/training/index.d.ts +1 -1
  51. package/dist/training/index.js +1 -0
  52. package/dist/training/trainer.d.ts +16 -0
  53. package/dist/training/trainer.js +72 -0
  54. package/dist/types.d.ts +0 -2
  55. package/dist/weights/weights_container.d.ts +0 -5
  56. package/dist/weights/weights_container.js +0 -7
  57. package/package.json +1 -1
  58. package/dist/async_informant.d.ts +0 -15
  59. package/dist/async_informant.js +0 -42
  60. package/dist/training/trainer/distributed_trainer.d.ts +0 -20
  61. package/dist/training/trainer/distributed_trainer.js +0 -41
  62. package/dist/training/trainer/local_trainer.d.ts +0 -12
  63. package/dist/training/trainer/local_trainer.js +0 -24
  64. package/dist/training/trainer/trainer.d.ts +0 -32
  65. package/dist/training/trainer/trainer.js +0 -61
  66. package/dist/training/trainer/trainer_builder.d.ts +0 -23
  67. package/dist/training/trainer/trainer_builder.js +0 -47
@@ -1,5 +1,6 @@
1
1
  import { Map, Set } from 'immutable';
2
- import type { client, Model, AsyncInformant } from '../index.js';
2
+ import type { client } from '../index.js';
3
+ import { EventEmitter } from '../utils/event_emitter.js';
3
4
  export declare enum AggregationStep {
4
5
  ADD = 0,
5
6
  UPDATE = 1,
@@ -8,12 +9,13 @@ export declare enum AggregationStep {
8
9
  /**
9
10
  * Main, abstract, aggregator class whose role is to buffer contributions and to produce
10
11
  * a result based off their aggregation, whenever some defined condition is met.
12
+ *
13
+ * Emits an event whenever an aggregation step is performed.
14
+ * Users wait for this event to fetch the aggregation result.
11
15
  */
12
- export declare abstract class Base<T> {
13
- /**
14
- * The Model whose weights are updated on aggregation.
15
- */
16
- protected _model?: Model | undefined;
16
+ export declare abstract class Base<T> extends EventEmitter<{
17
+ 'aggregation': T;
18
+ }> {
17
19
  /**
18
20
  * The round cut-off for contributions.
19
21
  */
@@ -33,19 +35,6 @@ export declare abstract class Base<T> {
33
35
  * of all active nodes, depending on the aggregation scheme.
34
36
  */
35
37
  protected contributions: Map<number, Map<client.NodeID, T>>;
36
- /**
37
- * Emits the aggregation event whenever an aggregation step is performed.
38
- * Triggers the resolve of the result promise and the preparation for the
39
- * next aggregation round.
40
- */
41
- private readonly eventEmitter;
42
- protected informant?: AsyncInformant<T>;
43
- /**
44
- * The result promise which, on resolve, will contain the current aggregation result.
45
- * This promise should be fetched by any object making use of an aggregator, in order
46
- * to await upon aggregation.
47
- */
48
- protected result: Promise<T>;
49
38
  /**
50
39
  * The current aggregation round, used for assessing whether a node contribution is recent enough
51
40
  * or not.
@@ -59,10 +48,6 @@ export declare abstract class Base<T> {
59
48
  */
60
49
  protected _communicationRound: number;
61
50
  constructor(
62
- /**
63
- * The Model whose weights are updated on aggregation.
64
- */
65
- _model?: Model | undefined,
66
51
  /**
67
52
  * The round cut-off for contributions.
68
53
  */
@@ -85,7 +70,6 @@ export declare abstract class Base<T> {
85
70
  * Must store the aggregation's result in the aggregator's result promise.
86
71
  */
87
72
  abstract aggregate(): void;
88
- registerObserver(informant: AsyncInformant<T>): void;
89
73
  /**
90
74
  * Returns whether the given round is recent enough, dependent on the
91
75
  * aggregator's round cutoff.
@@ -99,16 +83,12 @@ export declare abstract class Base<T> {
99
83
  * @param from The node which triggered the logging message
100
84
  */
101
85
  log(step: AggregationStep, from?: client.NodeID): void;
102
- /**
103
- * Sets the aggregator's TF.js model.
104
- * @param model The new TF.js model
105
- */
106
- setModel(model: Model): void;
107
86
  /**
108
87
  * Adds a node's id to the set of active nodes. A node represents an active neighbor
109
88
  * peer/client within the network, whom we are communicating with during this aggregation
110
89
  * round.
111
90
  * @param nodeId The node to be added
91
+ * @returns True is the node wasn't already in the list of nodes, False if already included
112
92
  */
113
93
  registerNode(nodeId: client.NodeID): boolean;
114
94
  /**
@@ -129,27 +109,12 @@ export declare abstract class Base<T> {
129
109
  * @param round The new round
130
110
  */
131
111
  setRound(round: number): void;
132
- /**
133
- * Emits the event containing the aggregation result, which allows the result
134
- * promise to resolve and for the next aggregation round to take place.
135
- * @param aggregated The aggregation result
136
- */
137
- protected emit(aggregated: T): void;
138
112
  /**
139
113
  * Updates the aggregator's state to proceed to the next communication round.
140
114
  * If all communication rounds were performed, proceeds to the next aggregation round
141
115
  * and empties the collection of stored contributions.
142
116
  */
143
117
  nextRound(): void;
144
- private makeResult;
145
- /**
146
- * Aggregation steps are performed asynchronously, yet can be awaited upon when required.
147
- * This function gives access to the current aggregation result's promise, which will
148
- * eventually resolve and contain the result of the very next aggregation step, at the
149
- * time of the function call.
150
- * @returns The promise containing the aggregation result
151
- */
152
- receiveResult(): Promise<T>;
153
118
  /**
154
119
  * Constructs the payloads sent to other nodes as contribution.
155
120
  * @param base Object from which the payload is computed
@@ -169,10 +134,6 @@ export declare abstract class Base<T> {
169
134
  * the amount of all active nodes times the number of communication rounds.
170
135
  */
171
136
  get size(): number;
172
- /**
173
- * The aggregator's current model.
174
- */
175
- get model(): Model | undefined;
176
137
  /**
177
138
  * The current communication round.
178
139
  */
@@ -9,9 +9,11 @@ export var AggregationStep;
9
9
  /**
10
10
  * Main, abstract, aggregator class whose role is to buffer contributions and to produce
11
11
  * a result based off their aggregation, whenever some defined condition is met.
12
+ *
13
+ * Emits an event whenever an aggregation step is performed.
14
+ * Users wait for this event to fetch the aggregation result.
12
15
  */
13
- export class Base {
14
- _model;
16
+ export class Base extends EventEmitter {
15
17
  roundCutoff;
16
18
  communicationRounds;
17
19
  /**
@@ -26,19 +28,6 @@ export class Base {
26
28
  */
27
29
  // communication round -> NodeID -> T
28
30
  contributions;
29
- /**
30
- * Emits the aggregation event whenever an aggregation step is performed.
31
- * Triggers the resolve of the result promise and the preparation for the
32
- * next aggregation round.
33
- */
34
- eventEmitter = new EventEmitter();
35
- informant;
36
- /**
37
- * The result promise which, on resolve, will contain the current aggregation result.
38
- * This promise should be fetched by any object making use of an aggregator, in order
39
- * to await upon aggregation.
40
- */
41
- result;
42
31
  /**
43
32
  * The current aggregation round, used for assessing whether a node contribution is recent enough
44
33
  * or not.
@@ -52,10 +41,6 @@ export class Base {
52
41
  */
53
42
  _communicationRound = 0;
54
43
  constructor(
55
- /**
56
- * The Model whose weights are updated on aggregation.
57
- */
58
- _model,
59
44
  /**
60
45
  * The round cut-off for contributions.
61
46
  */
@@ -64,21 +49,14 @@ export class Base {
64
49
  * The number of communication rounds occurring during any given aggregation round.
65
50
  */
66
51
  communicationRounds = 1) {
67
- this._model = _model;
52
+ super();
68
53
  this.roundCutoff = roundCutoff;
69
54
  this.communicationRounds = communicationRounds;
70
55
  this.contributions = Map();
71
56
  this._nodes = Set();
72
- // Make the initial result promise
73
- this.result = this.makeResult();
74
57
  // On every aggregation, update the object's state to match the current aggregation
75
58
  // and communication rounds.
76
- this.eventEmitter.on('aggregation', () => {
77
- this.nextRound();
78
- });
79
- }
80
- registerObserver(informant) {
81
- this.informant = informant;
59
+ this.on('aggregation', () => this.nextRound());
82
60
  }
83
61
  /**
84
62
  * Returns whether the given round is recent enough, dependent on the
@@ -97,7 +75,7 @@ export class Base {
97
75
  log(step, from) {
98
76
  switch (step) {
99
77
  case AggregationStep.ADD:
100
- console.log(`> Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
78
+ console.log(`Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
101
79
  break;
102
80
  case AggregationStep.UPDATE:
103
81
  if (from === undefined) {
@@ -115,18 +93,12 @@ export class Base {
115
93
  }
116
94
  }
117
95
  }
118
- /**
119
- * Sets the aggregator's TF.js model.
120
- * @param model The new TF.js model
121
- */
122
- setModel(model) {
123
- this._model = model;
124
- }
125
96
  /**
126
97
  * Adds a node's id to the set of active nodes. A node represents an active neighbor
127
98
  * peer/client within the network, whom we are communicating with during this aggregation
128
99
  * round.
129
100
  * @param nodeId The node to be added
101
+ * @returns True is the node wasn't already in the list of nodes, False if already included
130
102
  */
131
103
  registerNode(nodeId) {
132
104
  if (!this.nodes.has(nodeId)) {
@@ -161,14 +133,6 @@ export class Base {
161
133
  this._round = round;
162
134
  }
163
135
  }
164
- /**
165
- * Emits the event containing the aggregation result, which allows the result
166
- * promise to resolve and for the next aggregation round to take place.
167
- * @param aggregated The aggregation result
168
- */
169
- emit(aggregated) {
170
- this.eventEmitter.emit('aggregation', aggregated);
171
- }
172
136
  /**
173
137
  * Updates the aggregator's state to proceed to the next communication round.
174
138
  * If all communication rounds were performed, proceeds to the next aggregation round
@@ -180,25 +144,6 @@ export class Base {
180
144
  this._round++;
181
145
  this.contributions = Map();
182
146
  }
183
- this.result = this.makeResult();
184
- this.informant?.update();
185
- }
186
- async makeResult() {
187
- return await new Promise((resolve) => {
188
- this.eventEmitter.once('aggregation', (w) => {
189
- resolve(w);
190
- });
191
- });
192
- }
193
- /**
194
- * Aggregation steps are performed asynchronously, yet can be awaited upon when required.
195
- * This function gives access to the current aggregation result's promise, which will
196
- * eventually resolve and contain the result of the very next aggregation step, at the
197
- * time of the function call.
198
- * @returns The promise containing the aggregation result
199
- */
200
- async receiveResult() {
201
- return await this.result;
202
147
  }
203
148
  /**
204
149
  * The set of node ids, representing our neighbors within the network.
@@ -222,12 +167,6 @@ export class Base {
222
167
  .map((m) => m.size)
223
168
  .reduce((totalSize, size) => totalSize + size) ?? 0;
224
169
  }
225
- /**
226
- * The aggregator's current model.
227
- */
228
- get model() {
229
- return this._model;
230
- }
231
170
  /**
232
171
  * The current communication round.
233
172
  */
@@ -1,16 +1,28 @@
1
1
  import type { Task } from '../index.js';
2
2
  import { aggregator } from '../index.js';
3
+ type AggregatorOptions = Partial<{
4
+ scheme: Task['trainingInformation']['scheme'];
5
+ roundCutOff: number;
6
+ threshold: number;
7
+ thresholdType: 'relative' | 'absolute';
8
+ }>;
3
9
  /**
4
- * Enumeration of the available types of aggregator.
5
- */
6
- export declare enum AggregatorChoice {
7
- MEAN = 0,
8
- SECURE = 1,
9
- BANDIT = 2
10
- }
11
- /**
12
- * Provides the aggregator object adequate to the given task.
13
- * @param task The task
10
+ * Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
11
+ * Here is the ordered list of parameters used to define the aggregator and its default behavior:
12
+ * task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
13
+ *
14
+ * If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
15
+ * Otherwise, we default to a MeanAggregator for both training schemes.
16
+ *
17
+ * For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
18
+ * Unless specified otherwise, for federated learning or local training the aggregator default to waiting
19
+ * for a single contribution to trigger a model update.
20
+ * (the server's model update for federated learning or our own contribution if training locally)
21
+ * For decentralized learning the aggregator defaults to waiting for every nodes' contribution to trigger a model update.
22
+ *
23
+ * @param task The task object associated with the current training session
24
+ * @param options Options passed down to the aggregator's constructor
14
25
  * @returns The aggregator
15
26
  */
16
- export declare function getAggregator(task: Task): aggregator.Aggregator;
27
+ export declare function getAggregator(task: Task, options?: AggregatorOptions): aggregator.Aggregator;
28
+ export {};
@@ -1,31 +1,48 @@
1
1
  import { aggregator } from '../index.js';
2
2
  /**
3
- * Enumeration of the available types of aggregator.
4
- */
5
- export var AggregatorChoice;
6
- (function (AggregatorChoice) {
7
- AggregatorChoice[AggregatorChoice["MEAN"] = 0] = "MEAN";
8
- AggregatorChoice[AggregatorChoice["SECURE"] = 1] = "SECURE";
9
- AggregatorChoice[AggregatorChoice["BANDIT"] = 2] = "BANDIT";
10
- })(AggregatorChoice || (AggregatorChoice = {}));
11
- /**
12
- * Provides the aggregator object adequate to the given task.
13
- * @param task The task
3
+ * Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
4
+ * Here is the ordered list of parameters used to define the aggregator and its default behavior:
5
+ * task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
6
+ *
7
+ * If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
8
+ * Otherwise, we default to a MeanAggregator for both training schemes.
9
+ *
10
+ * For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
11
+ * Unless specified otherwise, for federated learning or local training the aggregator default to waiting
12
+ * for a single contribution to trigger a model update.
13
+ * (the server's model update for federated learning or our own contribution if training locally)
14
+ * For decentralized learning the aggregator defaults to waiting for every nodes' contribution to trigger a model update.
15
+ *
16
+ * @param task The task object associated with the current training session
17
+ * @param options Options passed down to the aggregator's constructor
14
18
  * @returns The aggregator
15
19
  */
16
- export function getAggregator(task) {
17
- const error = new Error('not implemented');
18
- switch (task.trainingInformation.aggregator) {
19
- case AggregatorChoice.MEAN:
20
- return new aggregator.MeanAggregator();
21
- case AggregatorChoice.BANDIT:
22
- throw error;
23
- case AggregatorChoice.SECURE:
24
- if (task.trainingInformation.scheme !== 'decentralized') {
20
+ export function getAggregator(task, options = {}) {
21
+ const aggregatorType = task.trainingInformation.aggregator ?? 'mean';
22
+ const scheme = options.scheme ?? task.trainingInformation.scheme;
23
+ switch (aggregatorType) {
24
+ case 'mean':
25
+ if (scheme === 'decentralized') {
26
+ // If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
27
+ options = {
28
+ roundCutOff: undefined, threshold: 1, thresholdType: 'relative',
29
+ ...options
30
+ };
31
+ }
32
+ else {
33
+ // If scheme == 'federated' then we only expect the server's contribution at each round
34
+ // so we set the aggregation threshold to 1 contribution
35
+ // If scheme == 'local' then we only expect our own contribution
36
+ options = {
37
+ roundCutOff: undefined, threshold: 1, thresholdType: 'absolute',
38
+ ...options
39
+ };
40
+ }
41
+ return new aggregator.MeanAggregator(options.roundCutOff, options.threshold, options.thresholdType);
42
+ case 'secure':
43
+ if (scheme !== 'decentralized') {
25
44
  throw new Error('secure aggregation is currently supported for decentralized only');
26
45
  }
27
- return new aggregator.SecureAggregator();
28
- default:
29
- return new aggregator.MeanAggregator();
46
+ return new aggregator.SecureAggregator(task.trainingInformation.maxShareValue);
30
47
  }
31
48
  }
@@ -3,5 +3,5 @@ import type { Base } from './base.js';
3
3
  export { Base as AggregatorBase, AggregationStep } from './base.js';
4
4
  export { MeanAggregator } from './mean.js';
5
5
  export { SecureAggregator } from './secure.js';
6
- export { getAggregator, AggregatorChoice } from './get.js';
6
+ export { getAggregator } from './get.js';
7
7
  export type Aggregator = Base<WeightsContainer>;
@@ -1,4 +1,4 @@
1
1
  export { Base as AggregatorBase, AggregationStep } from './base.js';
2
2
  export { MeanAggregator } from './mean.js';
3
3
  export { SecureAggregator } from './secure.js';
4
- export { getAggregator, AggregatorChoice } from './get.js';
4
+ export { getAggregator } from './get.js';
@@ -1,18 +1,37 @@
1
1
  import type { Map } from "immutable";
2
2
  import { Base as Aggregator } from "./base.js";
3
- import type { Model, WeightsContainer, client } from "../index.js";
4
- /** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
3
+ import type { WeightsContainer, client } from "../index.js";
4
+ type ThresholdType = 'relative' | 'absolute';
5
+ /**
6
+ * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
7
+ *
8
+ */
5
9
  export declare class MeanAggregator extends Aggregator<WeightsContainer> {
6
10
  #private;
7
11
  /**
8
- * @param threshold - how many contributions for trigger an aggregation step.
9
- * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
10
- * - absolute: t > 1, thus requiring t contributions
12
+ * Create a mean aggregator that averages all weight updates received when a specified threshold is met.
13
+ * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that
14
+ * only accepts contributions from the current round (drops contributions from previous rounds).
15
+ *
16
+ * @param threshold - how many contributions trigger an aggregation step.
17
+ * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions.
18
+ * Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`.
19
+ * It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions
20
+ * Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update),
21
+ * set `threshold = 1` and `thresholdType = 'absolute'`
22
+ * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1,
23
+ * If `threshold != 1` then the specified thresholdType is ignored and overwritten
24
+ * If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution
25
+ * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions,
26
+ * @param roundCutoff - from how many past rounds do we still accept contributions.
27
+ * If 0 then only accept contributions from the current round,
28
+ * if 1 then the current round and the previous one, etc.
11
29
  */
12
- constructor(model?: Model, roundCutoff?: number, threshold?: number);
30
+ constructor(roundCutoff?: number, threshold?: number, thresholdType?: ThresholdType);
13
31
  /** Checks whether the contributions buffer is full. */
14
32
  isFull(): boolean;
15
33
  add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
16
34
  aggregate(): void;
17
35
  makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
18
36
  }
37
+ export {};
@@ -1,39 +1,86 @@
1
1
  import { AggregationStep, Base as Aggregator } from "./base.js";
2
2
  import { aggregation } from "../index.js";
3
- /** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
3
+ /**
4
+ * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
5
+ *
6
+ */
4
7
  export class MeanAggregator extends Aggregator {
5
8
  #threshold;
9
+ #thresholdType;
6
10
  /**
7
- * @param threshold - how many contributions for trigger an aggregation step.
8
- * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
9
- * - absolute: t > 1, thus requiring t contributions
11
+ * Create a mean aggregator that averages all weight updates received when a specified threshold is met.
12
+ * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that
13
+ * only accepts contributions from the current round (drops contributions from previous rounds).
14
+ *
15
+ * @param threshold - how many contributions trigger an aggregation step.
16
+ * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions.
17
+ * Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`.
18
+ * It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions
19
+ * Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update),
20
+ * set `threshold = 1` and `thresholdType = 'absolute'`
21
+ * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1,
22
+ * If `threshold != 1` then the specified thresholdType is ignored and overwritten
23
+ * If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution
24
+ * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions,
25
+ * @param roundCutoff - from how many past rounds do we still accept contributions.
26
+ * If 0 then only accept contributions from the current round,
27
+ * if 1 then the current round and the previous one, etc.
10
28
  */
11
- // TODO no way to require a single contribution
12
- constructor(model, roundCutoff = 0, threshold = 1) {
29
+ constructor(roundCutoff = 0, threshold = 1, thresholdType) {
13
30
  if (threshold <= 0)
14
- throw new Error("threshold must be striclty positive");
15
- if (threshold > 1 && !Number.isInteger(threshold))
16
- throw new Error("absolute thresholds must be integeral");
17
- super(model, roundCutoff, 1);
31
+ throw new Error("threshold must be strictly positive");
32
+ if (threshold > 1 && (!Number.isInteger(threshold)))
33
+ throw new Error("absolute thresholds must be integral");
34
+ super(roundCutoff, 1);
18
35
  this.#threshold = threshold;
36
+ if (threshold < 1) {
37
+ // Throw exception if threshold and thresholdType are conflicting
38
+ if (thresholdType === 'absolute') {
39
+ throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`);
40
+ }
41
+ this.#thresholdType = 'relative';
42
+ }
43
+ else if (threshold > 1) {
44
+ // Throw exception if threshold and thresholdType are conflicting
45
+ if (thresholdType === 'relative') {
46
+ throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`);
47
+ }
48
+ this.#thresholdType = 'absolute';
49
+ }
50
+ // remaining case: threshold == 1
51
+ else {
52
+ // Print a warning regarding the default behavior when thresholdType is not specified
53
+ if (thresholdType === undefined) {
54
+ console.warn("[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " +
55
+ "To instead wait for a single contribution, set thresholdType = 'absolute'");
56
+ this.#thresholdType = 'relative';
57
+ }
58
+ else {
59
+ this.#thresholdType = thresholdType;
60
+ }
61
+ }
19
62
  }
20
63
  /** Checks whether the contributions buffer is full. */
21
64
  isFull() {
22
- const actualThreshold = this.#threshold <= 1
65
+ const thresholdValue = this.#thresholdType == 'relative'
23
66
  ? this.#threshold * this.nodes.size
24
67
  : this.#threshold;
25
- return (this.contributions.get(0)?.size ?? 0) >= actualThreshold;
68
+ return (this.contributions.get(0)?.size ?? 0) >= thresholdValue;
26
69
  }
27
70
  add(nodeId, contribution, round, currentContributions = 0) {
28
71
  if (currentContributions !== 0)
29
72
  throw new Error("only a single communication round");
30
- if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
73
+ if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) {
74
+ if (!this.nodes.has(nodeId))
75
+ console.warn("Contribution rejected because node id is not registered");
76
+ if (!this.isWithinRoundCutoff(round))
77
+ console.warn(`Contribution rejected because round ${round} is not within round cutoff`);
31
78
  return false;
79
+ }
32
80
  this.log(this.contributions.hasIn([0, nodeId])
33
81
  ? AggregationStep.UPDATE
34
82
  : AggregationStep.ADD, nodeId);
35
83
  this.contributions = this.contributions.setIn([0, nodeId], contribution);
36
- this.informant?.update();
37
84
  if (this.isFull())
38
85
  this.aggregate();
39
86
  return true;
@@ -44,9 +91,7 @@ export class MeanAggregator extends Aggregator {
44
91
  throw new Error("aggregating without any contribution");
45
92
  this.log(AggregationStep.AGGREGATE);
46
93
  const result = aggregation.avg(currentContributions.values());
47
- if (this.model !== undefined)
48
- this.model.weights = result;
49
- this.emit(result);
94
+ this.emit('aggregation', result);
50
95
  }
51
96
  makePayloads(weights) {
52
97
  // Communicate our local weights to every other node, be it a peer or a server
@@ -1,6 +1,6 @@
1
1
  import { Map, List } from "immutable";
2
2
  import { Base as Aggregator } from "./base.js";
3
- import type { Model, WeightsContainer, client } from "../index.js";
3
+ import type { WeightsContainer, client } from "../index.js";
4
4
  /**
5
5
  * Aggregator implementing secure multi-party computation for decentralized learning.
6
6
  * An aggregation consists of two communication rounds:
@@ -10,7 +10,7 @@ import type { Model, WeightsContainer, client } from "../index.js";
10
10
  */
11
11
  export declare class SecureAggregator extends Aggregator<WeightsContainer> {
12
12
  private readonly maxShareValue;
13
- constructor(model?: Model, maxShareValue?: number);
13
+ constructor(maxShareValue?: number);
14
14
  aggregate(): void;
15
15
  add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound?: number): boolean;
16
16
  isFull(): boolean;
@@ -11,8 +11,8 @@ import { aggregation } from "../index.js";
11
11
  */
12
12
  export class SecureAggregator extends Aggregator {
13
13
  maxShareValue;
14
- constructor(model, maxShareValue = 100) {
15
- super(model, 0, 2);
14
+ constructor(maxShareValue = 100) {
15
+ super(0, 2);
16
16
  this.maxShareValue = maxShareValue;
17
17
  }
18
18
  aggregate() {
@@ -24,7 +24,7 @@ export class SecureAggregator extends Aggregator {
24
24
  if (currentContributions === undefined)
25
25
  throw new Error("aggregating without any contribution");
26
26
  const result = aggregation.sum(currentContributions.values());
27
- this.emit(result);
27
+ this.emit('aggregation', result);
28
28
  break;
29
29
  }
30
30
  // Average the received partial sums
@@ -33,9 +33,7 @@ export class SecureAggregator extends Aggregator {
33
33
  if (currentContributions === undefined)
34
34
  throw new Error("aggregating without any contribution");
35
35
  const result = aggregation.avg(currentContributions.values());
36
- if (this.model !== undefined)
37
- this.model.weights = result;
38
- this.emit(result);
36
+ this.emit('aggregation', result);
39
37
  break;
40
38
  }
41
39
  default:
@@ -56,7 +54,6 @@ export class SecureAggregator extends Aggregator {
56
54
  ? AggregationStep.UPDATE
57
55
  : AggregationStep.ADD, nodeId);
58
56
  this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
59
- this.informant?.update();
60
57
  if (this.isFull())
61
58
  this.aggregate();
62
59
  return true;
@@ -1,4 +1,3 @@
1
- import type { Set } from 'immutable';
2
1
  import type { Model, Task, WeightsContainer } from '../index.js';
3
2
  import type { NodeID } from './types.js';
4
3
  import type { EventConnection } from './event_connection.js';
@@ -68,9 +67,10 @@ export declare abstract class Base {
68
67
  * Communication callback called the end of every training round.
69
68
  * @param _weights The most recent local weight updates
70
69
  * @param _round The current training round
70
+ * @returns aggregated weights or the local weights upon error
71
71
  */
72
- onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<void>;
73
- get nodes(): Set<NodeID>;
72
+ abstract onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<WeightsContainer>;
73
+ get nbOfParticipants(): number;
74
74
  get ownId(): NodeID;
75
75
  get server(): EventConnection;
76
76
  }
@@ -64,14 +64,12 @@ export class Base {
64
64
  * @param _round The current training round
65
65
  */
66
66
  async onRoundBeginCommunication(_weights, _round) { }
67
- /**
68
- * Communication callback called the end of every training round.
69
- * @param _weights The most recent local weight updates
70
- * @param _round The current training round
71
- */
72
- async onRoundEndCommunication(_weights, _round) { }
73
- get nodes() {
74
- return this.aggregator.nodes;
67
+ // Number of contributors to a collaborative session
68
+ // If decentralized, it should be the number of peers
69
+ // If federated, it should the number of participants excluding the server
70
+ // If local it should be 1
71
+ get nbOfParticipants() {
72
+ return this.aggregator.nodes.size; // overriden by the federated client
75
73
  }
76
74
  get ownId() {
77
75
  if (this._ownId === undefined) {