@epfml/discojs 2.1.2-p20240515133413.0 → 2.1.2-p20240528164510.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,6 +24,7 @@ export class Base {
24
24
  * It defines the effective aggregation group, which is possibly a subset
25
25
  * of all active nodes, depending on the aggregation scheme.
26
26
  */
27
+ // communication round -> NodeID -> T
27
28
  contributions;
28
29
  /**
29
30
  * Emits the aggregation event whenever an aggregation step is performed.
@@ -1,23 +1,18 @@
1
- import type { Map } from 'immutable';
2
- import { Base as Aggregator } from './base.js';
3
- import type { Model, WeightsContainer, client } from '../index.js';
4
- /**
5
- * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
6
- */
1
+ import type { Map } from "immutable";
2
+ import { Base as Aggregator } from "./base.js";
3
+ import type { Model, WeightsContainer, client } from "../index.js";
4
+ /** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
7
5
  export declare class MeanAggregator extends Aggregator<WeightsContainer> {
6
+ #private;
8
7
  /**
9
- * The threshold t to fulfill to trigger an aggregation step. It can either be:
10
- * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
11
- * - absolute: t > 1, thus requiring t contributions
8
+ * @param threshold - how many contributions for trigger an aggregation step.
9
+ * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
10
+ * - absolute: t > 1, thus requiring t contributions
12
11
  */
13
- readonly threshold: number;
14
12
  constructor(model?: Model, roundCutoff?: number, threshold?: number);
15
- /**
16
- * Checks whether the contributions buffer is full, according to the set threshold.
17
- * @returns Whether the contributions buffer is full
18
- */
13
+ /** Checks whether the contributions buffer is full. */
19
14
  isFull(): boolean;
20
- add(nodeId: client.NodeID, contribution: WeightsContainer, round: number): boolean;
15
+ add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
21
16
  aggregate(): void;
22
17
  makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
23
18
  }
@@ -1,65 +1,51 @@
1
- import { AggregationStep, Base as Aggregator } from './base.js';
2
- import { aggregation } from '../index.js';
3
- /**
4
- * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
5
- */
1
+ import { AggregationStep, Base as Aggregator } from "./base.js";
2
+ import { aggregation } from "../index.js";
3
+ /** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
6
4
  export class MeanAggregator extends Aggregator {
5
+ #threshold;
7
6
  /**
8
- * The threshold t to fulfill to trigger an aggregation step. It can either be:
9
- * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
10
- * - absolute: t > 1, thus requiring t contributions
7
+ * @param threshold - how many contributions for trigger an aggregation step.
8
+ * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
9
+ * - absolute: t > 1, thus requiring t contributions
11
10
  */
12
- threshold;
11
+ // TODO no way to require a single contribution
13
12
  constructor(model, roundCutoff = 0, threshold = 1) {
13
+ if (threshold <= 0)
14
+ throw new Error("threshold must be striclty positive");
15
+ if (threshold > 1 && !Number.isInteger(threshold))
16
+ throw new Error("absolute thresholds must be integeral");
14
17
  super(model, roundCutoff, 1);
15
- // Default threshold is 100% of node participation
16
- if (threshold === undefined) {
17
- this.threshold = 1;
18
- // Threshold must be positive
19
- }
20
- else if (threshold <= 0) {
21
- throw new Error('threshold must be positive');
22
- // Thresholds greater than 1 are considered absolute instead of relative to the number of nodes
23
- }
24
- else if (threshold > 1 && Math.round(threshold) !== threshold) {
25
- throw new Error('absolute thresholds must integers');
26
- }
27
- else {
28
- this.threshold = threshold;
29
- }
18
+ this.#threshold = threshold;
30
19
  }
31
- /**
32
- * Checks whether the contributions buffer is full, according to the set threshold.
33
- * @returns Whether the contributions buffer is full
34
- */
20
+ /** Checks whether the contributions buffer is full. */
35
21
  isFull() {
36
- if (this.threshold <= 1) {
37
- const contribs = this.contributions.get(this.communicationRound);
38
- if (contribs === undefined) {
39
- return false;
40
- }
41
- return contribs.size >= this.threshold * this.nodes.size;
42
- }
43
- return this.contributions.size >= this.threshold;
22
+ const actualThreshold = this.#threshold <= 1
23
+ ? this.#threshold * this.nodes.size
24
+ : this.#threshold;
25
+ return (this.contributions.get(0)?.size ?? 0) >= actualThreshold;
44
26
  }
45
- add(nodeId, contribution, round) {
46
- if (this.nodes.has(nodeId) && this.isWithinRoundCutoff(round)) {
47
- this.log(this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, nodeId);
48
- this.contributions = this.contributions.setIn([0, nodeId], contribution);
49
- this.informant?.update();
50
- if (this.isFull()) {
51
- this.aggregate();
52
- }
53
- return true;
54
- }
55
- return false;
27
+ add(nodeId, contribution, round, currentContributions = 0) {
28
+ if (currentContributions !== 0)
29
+ throw new Error("only a single communication round");
30
+ if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
31
+ return false;
32
+ this.log(this.contributions.hasIn([0, nodeId])
33
+ ? AggregationStep.UPDATE
34
+ : AggregationStep.ADD, nodeId);
35
+ this.contributions = this.contributions.setIn([0, nodeId], contribution);
36
+ this.informant?.update();
37
+ if (this.isFull())
38
+ this.aggregate();
39
+ return true;
56
40
  }
57
41
  aggregate() {
42
+ const currentContributions = this.contributions.get(0);
43
+ if (currentContributions === undefined)
44
+ throw new Error("aggregating without any contribution");
58
45
  this.log(AggregationStep.AGGREGATE);
59
- const result = aggregation.avg(this.contributions.get(0)?.values());
60
- if (this.model !== undefined) {
46
+ const result = aggregation.avg(currentContributions.values());
47
+ if (this.model !== undefined)
61
48
  this.model.weights = result;
62
- }
63
49
  this.emit(result);
64
50
  }
65
51
  makePayloads(weights) {
@@ -1,6 +1,6 @@
1
- import { Map, List } from 'immutable';
2
- import { Base as Aggregator } from './base.js';
3
- import type { Model, WeightsContainer, client } from '../index.js';
1
+ import { Map, List } from "immutable";
2
+ import { Base as Aggregator } from "./base.js";
3
+ import type { Model, WeightsContainer, client } from "../index.js";
4
4
  /**
5
5
  * Aggregator implementing secure multi-party computation for decentralized learning.
6
6
  * An aggregation consists of two communication rounds:
@@ -12,12 +12,10 @@ export declare class SecureAggregator extends Aggregator<WeightsContainer> {
12
12
  private readonly maxShareValue;
13
13
  constructor(model?: Model, maxShareValue?: number);
14
14
  aggregate(): void;
15
- add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound: number): boolean;
15
+ add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound?: number): boolean;
16
16
  isFull(): boolean;
17
17
  makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
18
- /**
19
- * Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
20
- */
18
+ /** Generate N additive shares that aggregate to the secret weights array, where N is the number of peers. */
21
19
  generateAllShares(secret: WeightsContainer): List<WeightsContainer>;
22
20
  /**
23
21
  * Generates one share in the same shape as the secret that is populated with values randomly chosen from
@@ -1,7 +1,7 @@
1
- import { Map, List, Range } from 'immutable';
2
- import * as tf from '@tensorflow/tfjs';
3
- import { AggregationStep, Base as Aggregator } from './base.js';
4
- import { aggregation } from '../index.js';
1
+ import { Map, List, Range } from "immutable";
2
+ import * as tf from "@tensorflow/tfjs";
3
+ import { AggregationStep, Base as Aggregator } from "./base.js";
4
+ import { aggregation } from "../index.js";
5
5
  /**
6
6
  * Aggregator implementing secure multi-party computation for decentralized learning.
7
7
  * An aggregation consists of two communication rounds:
@@ -17,60 +17,72 @@ export class SecureAggregator extends Aggregator {
17
17
  }
18
18
  aggregate() {
19
19
  this.log(AggregationStep.AGGREGATE);
20
- if (this.communicationRound === 0) {
20
+ switch (this.communicationRound) {
21
21
  // Sum the received shares
22
- const result = aggregation.sum(this.contributions.get(0)?.values());
23
- this.emit(result);
24
- }
25
- else if (this.communicationRound === 1) {
22
+ case 0: {
23
+ const currentContributions = this.contributions.get(0);
24
+ if (currentContributions === undefined)
25
+ throw new Error("aggregating without any contribution");
26
+ const result = aggregation.sum(currentContributions.values());
27
+ this.emit(result);
28
+ break;
29
+ }
26
30
  // Average the received partial sums
27
- const result = aggregation.avg(this.contributions.get(1)?.values());
28
- if (this.model !== undefined) {
29
- this.model.weights = result;
31
+ case 1: {
32
+ const currentContributions = this.contributions.get(1);
33
+ if (currentContributions === undefined)
34
+ throw new Error("aggregating without any contribution");
35
+ const result = aggregation.avg(currentContributions.values());
36
+ if (this.model !== undefined)
37
+ this.model.weights = result;
38
+ this.emit(result);
39
+ break;
30
40
  }
31
- this.emit(result);
32
- }
33
- else {
34
- throw new Error('communication round is out of bounds');
41
+ default:
42
+ throw new Error("communication round is out of bounds");
35
43
  }
36
44
  }
37
45
  add(nodeId, contribution, round, communicationRound) {
38
- if (this.nodes.has(nodeId) && this.isWithinRoundCutoff(round)) {
39
- this.log(this.contributions.hasIn([communicationRound, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, nodeId);
40
- this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
41
- this.informant?.update();
42
- if (this.isFull()) {
43
- this.aggregate();
44
- }
45
- return true;
46
+ switch (communicationRound) {
47
+ case 0:
48
+ case 1:
49
+ break;
50
+ default:
51
+ throw new Error("requires communication round to be 0 or 1");
46
52
  }
47
- return false;
53
+ if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
54
+ return false;
55
+ this.log(this.contributions.hasIn([communicationRound, nodeId])
56
+ ? AggregationStep.UPDATE
57
+ : AggregationStep.ADD, nodeId);
58
+ this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
59
+ this.informant?.update();
60
+ if (this.isFull())
61
+ this.aggregate();
62
+ return true;
48
63
  }
49
64
  isFull() {
50
- const contribs = this.contributions.get(this.communicationRound);
51
- if (contribs === undefined) {
52
- return false;
53
- }
54
- return contribs.size === this.nodes.size;
65
+ return ((this.contributions.get(this.communicationRound)?.size ?? 0) ===
66
+ this.nodes.size);
55
67
  }
56
68
  makePayloads(weights) {
57
- if (this.communicationRound === 0) {
58
- const shares = this.generateAllShares(weights);
59
- // Abitrarily assign our shares to the available nodes
60
- return Map(List(this.nodes).zip(shares));
61
- }
62
- else {
69
+ switch (this.communicationRound) {
70
+ case 0: {
71
+ const shares = this.generateAllShares(weights);
72
+ // Abitrarily assign our shares to the available nodes
73
+ return Map(List(this.nodes).zip(shares));
74
+ }
63
75
  // Send our partial sum to every other nodes
64
- return this.nodes.toMap().map(() => weights);
76
+ case 1:
77
+ return this.nodes.toMap().map(() => weights);
78
+ default:
79
+ throw new Error("communication round is out of bounds");
65
80
  }
66
81
  }
67
- /**
68
- * Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
69
- */
82
+ /** Generate N additive shares that aggregate to the secret weights array, where N is the number of peers. */
70
83
  generateAllShares(secret) {
71
- if (this.nodes.size === 0) {
72
- throw new Error('too few participants to generate shares');
73
- }
84
+ if (this.nodes.size === 0)
85
+ throw new Error("too few participants to generate shares");
74
86
  // Generate N-1 shares
75
87
  const shares = Range(0, this.nodes.size - 1)
76
88
  .map(() => this.generateRandomShare(secret))
@@ -86,6 +98,6 @@ export class SecureAggregator extends Aggregator {
86
98
  const MAX_SEED_BITS = 47;
87
99
  const random = crypto.getRandomValues(new BigInt64Array(1))[0];
88
100
  const seed = Number(BigInt.asUintN(MAX_SEED_BITS, random));
89
- return secret.map((t) => tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue, 'float32', seed));
101
+ return secret.map((t) => tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue, "float32", seed));
90
102
  }
91
103
  }
@@ -1,7 +1,7 @@
1
1
  import { type client, type MetadataKey, type MetadataValue } from '../../index.js';
2
2
  import { type weights } from '../../serialization/index.js';
3
3
  import { type, type AssignNodeID, type ClientConnected } from '../messages.js';
4
- export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | RequestServerStatistics | ReceiveServerStatistics | ReceiveServerMetadata | AssignNodeID;
4
+ export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | ReceiveServerMetadata | AssignNodeID;
5
5
  export interface SendPayload {
6
6
  type: type.SendPayload;
7
7
  payload: weights.Encoded;
@@ -12,13 +12,6 @@ export interface ReceiveServerPayload {
12
12
  payload: weights.Encoded;
13
13
  round: number;
14
14
  }
15
- export interface RequestServerStatistics {
16
- type: type.RequestServerStatistics;
17
- }
18
- export interface ReceiveServerStatistics {
19
- type: type.ReceiveServerStatistics;
20
- statistics: Record<string, number>;
21
- }
22
15
  export interface ReceiveServerMetadata {
23
16
  type: type.ReceiveServerMetadata;
24
17
  nodeId: client.NodeID;
@@ -5,20 +5,11 @@ export function isMessageFederated(raw) {
5
5
  }
6
6
  switch (raw.type) {
7
7
  case type.ClientConnected:
8
- return true;
9
8
  case type.SendPayload:
10
- return true;
11
9
  case type.ReceiveServerPayload:
12
- return true;
13
- case type.RequestServerStatistics:
14
- return true;
15
- case type.ReceiveServerStatistics:
16
- return true;
17
10
  case type.ReceiveServerMetadata:
18
- return true;
19
11
  case type.AssignNodeID:
20
12
  return true;
21
- default:
22
- return false;
23
13
  }
14
+ return false;
24
15
  }
@@ -10,9 +10,7 @@ export declare enum type {
10
10
  Payload = 5,
11
11
  SendPayload = 6,
12
12
  ReceiveServerMetadata = 7,
13
- ReceiveServerPayload = 8,
14
- RequestServerStatistics = 9,
15
- ReceiveServerStatistics = 10
13
+ ReceiveServerPayload = 8
16
14
  }
17
15
  export interface ClientConnected {
18
16
  type: type.ClientConnected;
@@ -11,8 +11,6 @@ export var type;
11
11
  type[type["SendPayload"] = 6] = "SendPayload";
12
12
  type[type["ReceiveServerMetadata"] = 7] = "ReceiveServerMetadata";
13
13
  type[type["ReceiveServerPayload"] = 8] = "ReceiveServerPayload";
14
- type[type["RequestServerStatistics"] = 9] = "RequestServerStatistics";
15
- type[type["ReceiveServerStatistics"] = 10] = "ReceiveServerStatistics";
16
14
  })(type || (type = {}));
17
15
  export function hasMessageType(raw) {
18
16
  if (typeof raw !== 'object' || raw === null) {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "2.1.2-p20240515133413.0",
3
+ "version": "2.1.2-p20240528164510.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",