@epfml/discojs 2.1.2-p20240723143623.0 → 2.1.2-p20240723160120.0

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (45)
  1. package/dist/aggregator/base.d.ts +8 -48
  2. package/dist/aggregator/base.js +6 -68
  3. package/dist/aggregator/get.d.ts +0 -2
  4. package/dist/aggregator/get.js +4 -4
  5. package/dist/aggregator/mean.d.ts +2 -2
  6. package/dist/aggregator/mean.js +3 -6
  7. package/dist/aggregator/secure.d.ts +2 -2
  8. package/dist/aggregator/secure.js +4 -7
  9. package/dist/client/base.d.ts +2 -1
  10. package/dist/client/base.js +0 -6
  11. package/dist/client/decentralized/base.d.ts +2 -2
  12. package/dist/client/decentralized/base.js +9 -8
  13. package/dist/client/federated/base.d.ts +1 -1
  14. package/dist/client/federated/base.js +2 -1
  15. package/dist/client/local.d.ts +3 -1
  16. package/dist/client/local.js +4 -1
  17. package/dist/default_tasks/cifar10.js +1 -2
  18. package/dist/default_tasks/mnist.js +0 -2
  19. package/dist/default_tasks/simple_face.js +0 -2
  20. package/dist/default_tasks/titanic.js +0 -2
  21. package/dist/index.d.ts +0 -1
  22. package/dist/index.js +0 -1
  23. package/dist/privacy.d.ts +8 -10
  24. package/dist/privacy.js +25 -40
  25. package/dist/task/training_information.d.ts +6 -2
  26. package/dist/task/training_information.js +17 -5
  27. package/dist/training/disco.d.ts +30 -28
  28. package/dist/training/disco.js +76 -61
  29. package/dist/training/index.d.ts +1 -1
  30. package/dist/training/index.js +1 -0
  31. package/dist/training/trainer.d.ts +16 -0
  32. package/dist/training/trainer.js +72 -0
  33. package/dist/weights/weights_container.d.ts +0 -5
  34. package/dist/weights/weights_container.js +0 -7
  35. package/package.json +1 -1
  36. package/dist/async_informant.d.ts +0 -15
  37. package/dist/async_informant.js +0 -42
  38. package/dist/training/trainer/distributed_trainer.d.ts +0 -20
  39. package/dist/training/trainer/distributed_trainer.js +0 -41
  40. package/dist/training/trainer/local_trainer.d.ts +0 -12
  41. package/dist/training/trainer/local_trainer.js +0 -24
  42. package/dist/training/trainer/trainer.d.ts +0 -32
  43. package/dist/training/trainer/trainer.js +0 -61
  44. package/dist/training/trainer/trainer_builder.d.ts +0 -23
  45. package/dist/training/trainer/trainer_builder.js +0 -47
@@ -1,5 +1,6 @@
1
1
  import { Map, Set } from 'immutable';
2
- import type { client, Model, AsyncInformant } from '../index.js';
2
+ import type { client } from '../index.js';
3
+ import { EventEmitter } from '../utils/event_emitter.js';
3
4
  export declare enum AggregationStep {
4
5
  ADD = 0,
5
6
  UPDATE = 1,
@@ -8,12 +9,13 @@ export declare enum AggregationStep {
8
9
  /**
9
10
  * Main, abstract, aggregator class whose role is to buffer contributions and to produce
10
11
  * a result based off their aggregation, whenever some defined condition is met.
12
+ *
13
+ * Emits an event whenever an aggregation step is performed.
14
+ * Users wait for this event to fetch the aggregation result.
11
15
  */
12
- export declare abstract class Base<T> {
13
- /**
14
- * The Model whose weights are updated on aggregation.
15
- */
16
- protected _model?: Model | undefined;
16
+ export declare abstract class Base<T> extends EventEmitter<{
17
+ 'aggregation': T;
18
+ }> {
17
19
  /**
18
20
  * The round cut-off for contributions.
19
21
  */
@@ -33,19 +35,6 @@ export declare abstract class Base<T> {
33
35
  * of all active nodes, depending on the aggregation scheme.
34
36
  */
35
37
  protected contributions: Map<number, Map<client.NodeID, T>>;
36
- /**
37
- * Emits the aggregation event whenever an aggregation step is performed.
38
- * Triggers the resolve of the result promise and the preparation for the
39
- * next aggregation round.
40
- */
41
- private readonly eventEmitter;
42
- protected informant?: AsyncInformant<T>;
43
- /**
44
- * The result promise which, on resolve, will contain the current aggregation result.
45
- * This promise should be fetched by any object making use of an aggregator, in order
46
- * to await upon aggregation.
47
- */
48
- protected result: Promise<T>;
49
38
  /**
50
39
  * The current aggregation round, used for assessing whether a node contribution is recent enough
51
40
  * or not.
@@ -59,10 +48,6 @@ export declare abstract class Base<T> {
59
48
  */
60
49
  protected _communicationRound: number;
61
50
  constructor(
62
- /**
63
- * The Model whose weights are updated on aggregation.
64
- */
65
- _model?: Model | undefined,
66
51
  /**
67
52
  * The round cut-off for contributions.
68
53
  */
@@ -85,7 +70,6 @@ export declare abstract class Base<T> {
85
70
  * Must store the aggregation's result in the aggregator's result promise.
86
71
  */
87
72
  abstract aggregate(): void;
88
- registerObserver(informant: AsyncInformant<T>): void;
89
73
  /**
90
74
  * Returns whether the given round is recent enough, dependent on the
91
75
  * aggregator's round cutoff.
@@ -99,11 +83,6 @@ export declare abstract class Base<T> {
99
83
  * @param from The node which triggered the logging message
100
84
  */
101
85
  log(step: AggregationStep, from?: client.NodeID): void;
102
- /**
103
- * Sets the aggregator's model.
104
- * @param model The new model
105
- */
106
- setModel(model: Model): void;
107
86
  /**
108
87
  * Adds a node's id to the set of active nodes. A node represents an active neighbor
109
88
  * peer/client within the network, whom we are communicating with during this aggregation
@@ -130,27 +109,12 @@ export declare abstract class Base<T> {
130
109
  * @param round The new round
131
110
  */
132
111
  setRound(round: number): void;
133
- /**
134
- * Emits the event containing the aggregation result, which allows the result
135
- * promise to resolve and for the next aggregation round to take place.
136
- * @param aggregated The aggregation result
137
- */
138
- protected emit(aggregated: T): void;
139
112
  /**
140
113
  * Updates the aggregator's state to proceed to the next communication round.
141
114
  * If all communication rounds were performed, proceeds to the next aggregation round
142
115
  * and empties the collection of stored contributions.
143
116
  */
144
117
  nextRound(): void;
145
- private makeResult;
146
- /**
147
- * Aggregation steps are performed asynchronously, yet can be awaited upon when required.
148
- * This function gives access to the current aggregation result's promise, which will
149
- * eventually resolve and contain the result of the very next aggregation step, at the
150
- * time of the function call.
151
- * @returns The promise containing the aggregation result
152
- */
153
- receiveResult(): Promise<T>;
154
118
  /**
155
119
  * Constructs the payloads sent to other nodes as contribution.
156
120
  * @param base Object from which the payload is computed
@@ -170,10 +134,6 @@ export declare abstract class Base<T> {
170
134
  * the amount of all active nodes times the number of communication rounds.
171
135
  */
172
136
  get size(): number;
173
- /**
174
- * The aggregator's current model.
175
- */
176
- get model(): Model | undefined;
177
137
  /**
178
138
  * The current communication round.
179
139
  */
@@ -9,9 +9,11 @@ export var AggregationStep;
9
9
  /**
10
10
  * Main, abstract, aggregator class whose role is to buffer contributions and to produce
11
11
  * a result based off their aggregation, whenever some defined condition is met.
12
+ *
13
+ * Emits an event whenever an aggregation step is performed.
14
+ * Users wait for this event to fetch the aggregation result.
12
15
  */
13
- export class Base {
14
- _model;
16
+ export class Base extends EventEmitter {
15
17
  roundCutoff;
16
18
  communicationRounds;
17
19
  /**
@@ -26,19 +28,6 @@ export class Base {
26
28
  */
27
29
  // communication round -> NodeID -> T
28
30
  contributions;
29
- /**
30
- * Emits the aggregation event whenever an aggregation step is performed.
31
- * Triggers the resolve of the result promise and the preparation for the
32
- * next aggregation round.
33
- */
34
- eventEmitter = new EventEmitter();
35
- informant;
36
- /**
37
- * The result promise which, on resolve, will contain the current aggregation result.
38
- * This promise should be fetched by any object making use of an aggregator, in order
39
- * to await upon aggregation.
40
- */
41
- result;
42
31
  /**
43
32
  * The current aggregation round, used for assessing whether a node contribution is recent enough
44
33
  * or not.
@@ -52,10 +41,6 @@ export class Base {
52
41
  */
53
42
  _communicationRound = 0;
54
43
  constructor(
55
- /**
56
- * The Model whose weights are updated on aggregation.
57
- */
58
- _model,
59
44
  /**
60
45
  * The round cut-off for contributions.
61
46
  */
@@ -64,21 +49,14 @@ export class Base {
64
49
  * The number of communication rounds occurring during any given aggregation round.
65
50
  */
66
51
  communicationRounds = 1) {
67
- this._model = _model;
52
+ super();
68
53
  this.roundCutoff = roundCutoff;
69
54
  this.communicationRounds = communicationRounds;
70
55
  this.contributions = Map();
71
56
  this._nodes = Set();
72
- // Make the initial result promise
73
- this.result = this.makeResult();
74
57
  // On every aggregation, update the object's state to match the current aggregation
75
58
  // and communication rounds.
76
- this.eventEmitter.on('aggregation', () => {
77
- this.nextRound();
78
- });
79
- }
80
- registerObserver(informant) {
81
- this.informant = informant;
59
+ this.on('aggregation', () => this.nextRound());
82
60
  }
83
61
  /**
84
62
  * Returns whether the given round is recent enough, dependent on the
@@ -115,13 +93,6 @@ export class Base {
115
93
  }
116
94
  }
117
95
  }
118
- /**
119
- * Sets the aggregator's model.
120
- * @param model The new model
121
- */
122
- setModel(model) {
123
- this._model = model;
124
- }
125
96
  /**
126
97
  * Adds a node's id to the set of active nodes. A node represents an active neighbor
127
98
  * peer/client within the network, whom we are communicating with during this aggregation
@@ -162,14 +133,6 @@ export class Base {
162
133
  this._round = round;
163
134
  }
164
135
  }
165
- /**
166
- * Emits the event containing the aggregation result, which allows the result
167
- * promise to resolve and for the next aggregation round to take place.
168
- * @param aggregated The aggregation result
169
- */
170
- emit(aggregated) {
171
- this.eventEmitter.emit('aggregation', aggregated);
172
- }
173
136
  /**
174
137
  * Updates the aggregator's state to proceed to the next communication round.
175
138
  * If all communication rounds were performed, proceeds to the next aggregation round
@@ -181,25 +144,6 @@ export class Base {
181
144
  this._round++;
182
145
  this.contributions = Map();
183
146
  }
184
- this.result = this.makeResult();
185
- this.informant?.update();
186
- }
187
- async makeResult() {
188
- return await new Promise((resolve) => {
189
- this.eventEmitter.once('aggregation', (w) => {
190
- resolve(w);
191
- });
192
- });
193
- }
194
- /**
195
- * Aggregation steps are performed asynchronously, yet can be awaited upon when required.
196
- * This function gives access to the current aggregation result's promise, which will
197
- * eventually resolve and contain the result of the very next aggregation step, at the
198
- * time of the function call.
199
- * @returns The promise containing the aggregation result
200
- */
201
- async receiveResult() {
202
- return await this.result;
203
147
  }
204
148
  /**
205
149
  * The set of node ids, representing our neighbors within the network.
@@ -223,12 +167,6 @@ export class Base {
223
167
  .map((m) => m.size)
224
168
  .reduce((totalSize, size) => totalSize + size) ?? 0;
225
169
  }
226
- /**
227
- * The aggregator's current model.
228
- */
229
- get model() {
230
- return this._model;
231
- }
232
170
  /**
233
171
  * The current communication round.
234
172
  */
@@ -1,8 +1,6 @@
1
1
  import type { Task } from '../index.js';
2
2
  import { aggregator } from '../index.js';
3
- import type { Model } from "../index.js";
4
3
  type AggregatorOptions = Partial<{
5
- model: Model;
6
4
  scheme: Task['trainingInformation']['scheme'];
7
5
  roundCutOff: number;
8
6
  threshold: number;
@@ -25,7 +25,7 @@ export function getAggregator(task, options = {}) {
25
25
  if (scheme === 'decentralized') {
26
26
  // If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
27
27
  options = {
28
- model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'relative',
28
+ roundCutOff: undefined, threshold: 1, thresholdType: 'relative',
29
29
  ...options
30
30
  };
31
31
  }
@@ -34,15 +34,15 @@ export function getAggregator(task, options = {}) {
34
34
  // so we set the aggregation threshold to 1 contribution
35
35
  // If scheme == 'local' then we only expect our own contribution
36
36
  options = {
37
- model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'absolute',
37
+ roundCutOff: undefined, threshold: 1, thresholdType: 'absolute',
38
38
  ...options
39
39
  };
40
40
  }
41
- return new aggregator.MeanAggregator(options.model, options.roundCutOff, options.threshold, options.thresholdType);
41
+ return new aggregator.MeanAggregator(options.roundCutOff, options.threshold, options.thresholdType);
42
42
  case 'secure':
43
43
  if (scheme !== 'decentralized') {
44
44
  throw new Error('secure aggregation is currently supported for decentralized only');
45
45
  }
46
- return new aggregator.SecureAggregator(options.model, task.trainingInformation.maxShareValue);
46
+ return new aggregator.SecureAggregator(task.trainingInformation.maxShareValue);
47
47
  }
48
48
  }
@@ -1,6 +1,6 @@
1
1
  import type { Map } from "immutable";
2
2
  import { Base as Aggregator } from "./base.js";
3
- import type { Model, WeightsContainer, client } from "../index.js";
3
+ import type { WeightsContainer, client } from "../index.js";
4
4
  type ThresholdType = 'relative' | 'absolute';
5
5
  /**
6
6
  * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
@@ -27,7 +27,7 @@ export declare class MeanAggregator extends Aggregator<WeightsContainer> {
27
27
  * If 0 then only accept contributions from the current round,
28
28
  * if 1 then the current round and the previous one, etc.
29
29
  */
30
- constructor(model?: Model, roundCutoff?: number, threshold?: number, thresholdType?: ThresholdType);
30
+ constructor(roundCutoff?: number, threshold?: number, thresholdType?: ThresholdType);
31
31
  /** Checks whether the contributions buffer is full. */
32
32
  isFull(): boolean;
33
33
  add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
@@ -26,12 +26,12 @@ export class MeanAggregator extends Aggregator {
26
26
  * If 0 then only accept contributions from the current round,
27
27
  * if 1 then the current round and the previous one, etc.
28
28
  */
29
- constructor(model, roundCutoff = 0, threshold = 1, thresholdType) {
29
+ constructor(roundCutoff = 0, threshold = 1, thresholdType) {
30
30
  if (threshold <= 0)
31
31
  throw new Error("threshold must be strictly positive");
32
32
  if (threshold > 1 && (!Number.isInteger(threshold)))
33
33
  throw new Error("absolute thresholds must be integral");
34
- super(model, roundCutoff, 1);
34
+ super(roundCutoff, 1);
35
35
  this.#threshold = threshold;
36
36
  if (threshold < 1) {
37
37
  // Throw exception if threshold and thresholdType are conflicting
@@ -81,7 +81,6 @@ export class MeanAggregator extends Aggregator {
81
81
  ? AggregationStep.UPDATE
82
82
  : AggregationStep.ADD, nodeId);
83
83
  this.contributions = this.contributions.setIn([0, nodeId], contribution);
84
- this.informant?.update();
85
84
  if (this.isFull())
86
85
  this.aggregate();
87
86
  return true;
@@ -92,9 +91,7 @@ export class MeanAggregator extends Aggregator {
92
91
  throw new Error("aggregating without any contribution");
93
92
  this.log(AggregationStep.AGGREGATE);
94
93
  const result = aggregation.avg(currentContributions.values());
95
- if (this.model !== undefined)
96
- this.model.weights = result;
97
- this.emit(result);
94
+ this.emit('aggregation', result);
98
95
  }
99
96
  makePayloads(weights) {
100
97
  // Communicate our local weights to every other node, be it a peer or a server
@@ -1,6 +1,6 @@
1
1
  import { Map, List } from "immutable";
2
2
  import { Base as Aggregator } from "./base.js";
3
- import type { Model, WeightsContainer, client } from "../index.js";
3
+ import type { WeightsContainer, client } from "../index.js";
4
4
  /**
5
5
  * Aggregator implementing secure multi-party computation for decentralized learning.
6
6
  * An aggregation consists of two communication rounds:
@@ -10,7 +10,7 @@ import type { Model, WeightsContainer, client } from "../index.js";
10
10
  */
11
11
  export declare class SecureAggregator extends Aggregator<WeightsContainer> {
12
12
  private readonly maxShareValue;
13
- constructor(model?: Model, maxShareValue?: number);
13
+ constructor(maxShareValue?: number);
14
14
  aggregate(): void;
15
15
  add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound?: number): boolean;
16
16
  isFull(): boolean;
@@ -11,8 +11,8 @@ import { aggregation } from "../index.js";
11
11
  */
12
12
  export class SecureAggregator extends Aggregator {
13
13
  maxShareValue;
14
- constructor(model, maxShareValue = 100) {
15
- super(model, 0, 2);
14
+ constructor(maxShareValue = 100) {
15
+ super(0, 2);
16
16
  this.maxShareValue = maxShareValue;
17
17
  }
18
18
  aggregate() {
@@ -24,7 +24,7 @@ export class SecureAggregator extends Aggregator {
24
24
  if (currentContributions === undefined)
25
25
  throw new Error("aggregating without any contribution");
26
26
  const result = aggregation.sum(currentContributions.values());
27
- this.emit(result);
27
+ this.emit('aggregation', result);
28
28
  break;
29
29
  }
30
30
  // Average the received partial sums
@@ -33,9 +33,7 @@ export class SecureAggregator extends Aggregator {
33
33
  if (currentContributions === undefined)
34
34
  throw new Error("aggregating without any contribution");
35
35
  const result = aggregation.avg(currentContributions.values());
36
- if (this.model !== undefined)
37
- this.model.weights = result;
38
- this.emit(result);
36
+ this.emit('aggregation', result);
39
37
  break;
40
38
  }
41
39
  default:
@@ -56,7 +54,6 @@ export class SecureAggregator extends Aggregator {
56
54
  ? AggregationStep.UPDATE
57
55
  : AggregationStep.ADD, nodeId);
58
56
  this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
59
- this.informant?.update();
60
57
  if (this.isFull())
61
58
  this.aggregate();
62
59
  return true;
@@ -67,8 +67,9 @@ export declare abstract class Base {
67
67
  * Communication callback called the end of every training round.
68
68
  * @param _weights The most recent local weight updates
69
69
  * @param _round The current training round
70
+ * @returns aggregated weights or the local weights upon error
70
71
  */
71
- onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<void>;
72
+ abstract onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<WeightsContainer>;
72
73
  get nbOfParticipants(): number;
73
74
  get ownId(): NodeID;
74
75
  get server(): EventConnection;
@@ -64,12 +64,6 @@ export class Base {
64
64
  * @param _round The current training round
65
65
  */
66
66
  async onRoundBeginCommunication(_weights, _round) { }
67
- /**
68
- * Communication callback called the end of every training round.
69
- * @param _weights The most recent local weight updates
70
- * @param _round The current training round
71
- */
72
- async onRoundEndCommunication(_weights, _round) { }
73
67
  // Number of contributors to a collaborative session
74
68
  // If decentralized, it should be the number of peers
75
69
  // If federated, it should the number of participants excluding the server
@@ -1,4 +1,4 @@
1
- import { type WeightsContainer } from '../../index.js';
1
+ import type { WeightsContainer } from "../../index.js";
2
2
  import { Client } from '../index.js';
3
3
  /**
4
4
  * Represents a decentralized client in a network of peers. Peers coordinate each other with the
@@ -45,5 +45,5 @@ export declare class Base extends Client {
45
45
  * @param round
46
46
  */
47
47
  private receivePayloads;
48
- onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<void>;
48
+ onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
49
49
  }
@@ -1,5 +1,5 @@
1
1
  import { Map, Set } from 'immutable';
2
- import { serialization } from '../../index.js';
2
+ import { serialization } from "../../index.js";
3
3
  import { Client } from '../index.js';
4
4
  import { type } from '../messages.js';
5
5
  import { timeout } from '../utils.js';
@@ -133,7 +133,7 @@ export class Base extends Client {
133
133
  }
134
134
  // Store the promise for the current round's aggregation result.
135
135
  // We will await for it to resolve at the end of the round when exchanging weight updates.
136
- this.aggregationResult = this.aggregator.receiveResult();
136
+ this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
137
137
  }
138
138
  /**
139
139
  * At each communication rounds, awaits peers contributions and add them to the client's aggregator.
@@ -163,12 +163,15 @@ export class Base extends Client {
163
163
  });
164
164
  }
165
165
  async onRoundEndCommunication(weights, round) {
166
- let result = weights;
166
+ if (this.aggregationResult === undefined) {
167
+ throw new TypeError('aggregation result promise is undefined');
168
+ }
167
169
  // Perform the required communication rounds. Each communication round consists in sending our local payload,
168
170
  // followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator.
169
171
  // A communication round's payload is the aggregation result of the previous communication round. The first
170
172
  // communication round simply sends our training result, i.e. model weights updates. This scheme allows for
171
173
  // the aggregator to define any complex multi-round aggregation mechanism.
174
+ let result = weights;
172
175
  for (let r = 0; r < this.aggregator.communicationRounds; r++) {
173
176
  // Generate our payloads for this communication round and send them to all ready connected peers
174
177
  if (this.connections !== undefined) {
@@ -198,9 +201,6 @@ export class Base extends Client {
198
201
  throw new Error('error while sending weights');
199
202
  }
200
203
  }
201
- if (this.aggregationResult === undefined) {
202
- throw new TypeError('aggregation result promise is undefined');
203
- }
204
204
  // Wait for aggregation before proceeding to the next communication round.
205
205
  // The current result will be used as payload for the eventual next communication round.
206
206
  try {
@@ -211,7 +211,7 @@ export class Base extends Client {
211
211
  }
212
212
  catch (e) {
213
213
  if (this.isDisconnected) {
214
- return;
214
+ return weights;
215
215
  }
216
216
  console.error(e);
217
217
  break;
@@ -219,10 +219,11 @@ export class Base extends Client {
219
219
  // There is at least one communication round remaining
220
220
  if (r < this.aggregator.communicationRounds - 1) {
221
221
  // Reuse the aggregation result
222
- this.aggregationResult = this.aggregator.receiveResult();
222
+ this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
223
223
  }
224
224
  }
225
225
  // Reset the peers list for the next round
226
226
  this.aggregator.resetNodes();
227
+ return await this.aggregationResult;
227
228
  }
228
229
  }
@@ -28,7 +28,7 @@ export declare class Base extends Client {
28
28
  */
29
29
  disconnect(): Promise<void>;
30
30
  onRoundBeginCommunication(): Promise<void>;
31
- onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<void>;
31
+ onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
32
32
  /**
33
33
  * Send a message containing our local weight updates to the federated server.
34
34
  * And waits for the server to reply with the most recent aggregated weights
@@ -68,7 +68,7 @@ export class Base extends Client {
68
68
  }
69
69
  onRoundBeginCommunication() {
70
70
  // Prepare the result promise for the incoming round
71
- this.aggregationResult = this.aggregator.receiveResult();
71
+ this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
72
72
  return Promise.resolve();
73
73
  }
74
74
  async onRoundEndCommunication(weights, round) {
@@ -90,6 +90,7 @@ export class Base extends Client {
90
90
  console.info(`[${this.ownId}] Server result is either stale or not received`);
91
91
  this.aggregator.nextRound();
92
92
  }
93
+ return await this.aggregationResult;
93
94
  }
94
95
  /**
95
96
  * Send a message containing our local weight updates to the federated server.
@@ -1,3 +1,5 @@
1
- import { Base } from './base.js';
1
+ import { WeightsContainer } from "../weights/weights_container.js";
2
+ import { Base } from "./base.js";
2
3
  export declare class Local extends Base {
4
+ onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
3
5
  }
@@ -1,3 +1,6 @@
1
- import { Base } from './base.js';
1
+ import { Base } from "./base.js";
2
2
  export class Local extends Base {
3
+ onRoundEndCommunication(weights) {
4
+ return Promise.resolve(weights);
5
+ }
3
6
  }
@@ -30,8 +30,7 @@ export const cifar10 = {
30
30
  IMAGE_W: 224,
31
31
  LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
32
32
  scheme: 'decentralized',
33
- noiseScale: undefined,
34
- clippingRadius: 20,
33
+ privacy: { clippingRadius: 20, noiseScale: 1 },
35
34
  decentralizedSecure: true,
36
35
  minimumReadyPeers: 3,
37
36
  maxShareValue: 100,
@@ -30,8 +30,6 @@ export const mnist = {
30
30
  preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
31
31
  LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
32
32
  scheme: 'decentralized',
33
- noiseScale: undefined,
34
- clippingRadius: undefined,
35
33
  decentralizedSecure: true,
36
34
  minimumReadyPeers: 3,
37
35
  maxShareValue: 100,
@@ -29,8 +29,6 @@ export const simpleFace = {
29
29
  IMAGE_W: 200,
30
30
  LABEL_LIST: ['child', 'adult'],
31
31
  scheme: 'federated', // secure aggregation not yet implemented for federated
32
- noiseScale: undefined,
33
- clippingRadius: undefined,
34
32
  tensorBackend: 'tfjs'
35
33
  }
36
34
  };
@@ -63,8 +63,6 @@ export const titanic = {
63
63
  'Survived'
64
64
  ],
65
65
  scheme: 'federated', // secure aggregation not yet implemented for FeAI
66
- noiseScale: undefined,
67
- clippingRadius: undefined,
68
66
  tensorBackend: 'tfjs'
69
67
  }
70
68
  };
package/dist/index.d.ts CHANGED
@@ -5,7 +5,6 @@ export * as privacy from './privacy.js';
5
5
  export * as client from './client/index.js';
6
6
  export * as aggregator from './aggregator/index.js';
7
7
  export { WeightsContainer, aggregation } from './weights/index.js';
8
- export { AsyncInformant } from './async_informant.js';
9
8
  export { Logger, ConsoleLogger } from './logging/index.js';
10
9
  export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
11
10
  export { Disco, RoundLogs } from './training/index.js';
package/dist/index.js CHANGED
@@ -5,7 +5,6 @@ export * as privacy from './privacy.js';
5
5
  export * as client from './client/index.js';
6
6
  export * as aggregator from './aggregator/index.js';
7
7
  export { WeightsContainer, aggregation } from './weights/index.js';
8
- export { AsyncInformant } from './async_informant.js';
9
8
  export { ConsoleLogger } from './logging/index.js';
10
9
  export { Memory, Empty as EmptyMemory } from './memory/index.js';
11
10
  export { Disco } from './training/index.js';
package/dist/privacy.d.ts CHANGED
@@ -1,11 +1,9 @@
1
- import type { Task, WeightsContainer } from './index.js';
1
+ import type { WeightsContainer } from "./index.js";
2
+ /** Scramble weights */
3
+ export declare function addNoise(weights: WeightsContainer, deviation: number): WeightsContainer;
2
4
  /**
3
- * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
4
- * The previous round's weights are the last weights pulled from server/peers.
5
- * The current round's weights are obtained after a single round of training, from the previous round's weights.
6
- * @param updatedWeights weights from the current round
7
- * @param staleWeights weights from the previous round
8
- * @param task the task
9
- * @returns the noised weights for the current round
10
- */
11
- export declare function addDifferentialPrivacy(updatedWeights: WeightsContainer, staleWeights: WeightsContainer, task: Task): WeightsContainer;
5
+ * Keep weights' norm within radius
6
+ *
7
+ * @param radius maximum norm
8
+ **/
9
+ export declare function clipNorm(weights: WeightsContainer, radius: number): Promise<WeightsContainer>;