@epfml/discojs 2.1.2-p20240515133413.0 → 2.1.2-p20240528164510.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.js +1 -0
- package/dist/aggregator/mean.d.ts +10 -15
- package/dist/aggregator/mean.js +36 -50
- package/dist/aggregator/secure.d.ts +5 -7
- package/dist/aggregator/secure.js +56 -44
- package/dist/client/federated/messages.d.ts +1 -8
- package/dist/client/federated/messages.js +1 -10
- package/dist/client/messages.d.ts +1 -3
- package/dist/client/messages.js +0 -2
- package/package.json +1 -1
package/dist/aggregator/base.js
CHANGED
|
@@ -24,6 +24,7 @@ export class Base {
|
|
|
24
24
|
* It defines the effective aggregation group, which is possibly a subset
|
|
25
25
|
* of all active nodes, depending on the aggregation scheme.
|
|
26
26
|
*/
|
|
27
|
+
// communication round -> NodeID -> T
|
|
27
28
|
contributions;
|
|
28
29
|
/**
|
|
29
30
|
* Emits the aggregation event whenever an aggregation step is performed.
|
|
@@ -1,23 +1,18 @@
|
|
|
1
|
-
import type { Map } from
|
|
2
|
-
import { Base as Aggregator } from
|
|
3
|
-
import type { Model, WeightsContainer, client } from
|
|
4
|
-
/**
|
|
5
|
-
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
|
|
6
|
-
*/
|
|
1
|
+
import type { Map } from "immutable";
|
|
2
|
+
import { Base as Aggregator } from "./base.js";
|
|
3
|
+
import type { Model, WeightsContainer, client } from "../index.js";
|
|
4
|
+
/** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
|
|
7
5
|
export declare class MeanAggregator extends Aggregator<WeightsContainer> {
|
|
6
|
+
#private;
|
|
8
7
|
/**
|
|
9
|
-
*
|
|
10
|
-
*
|
|
11
|
-
*
|
|
8
|
+
* @param threshold - how many contributions for trigger an aggregation step.
|
|
9
|
+
* - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
|
|
10
|
+
* - absolute: t > 1, thus requiring t contributions
|
|
12
11
|
*/
|
|
13
|
-
readonly threshold: number;
|
|
14
12
|
constructor(model?: Model, roundCutoff?: number, threshold?: number);
|
|
15
|
-
/**
|
|
16
|
-
* Checks whether the contributions buffer is full, according to the set threshold.
|
|
17
|
-
* @returns Whether the contributions buffer is full
|
|
18
|
-
*/
|
|
13
|
+
/** Checks whether the contributions buffer is full. */
|
|
19
14
|
isFull(): boolean;
|
|
20
|
-
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number): boolean;
|
|
15
|
+
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
|
|
21
16
|
aggregate(): void;
|
|
22
17
|
makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
|
|
23
18
|
}
|
package/dist/aggregator/mean.js
CHANGED
|
@@ -1,65 +1,51 @@
|
|
|
1
|
-
import { AggregationStep, Base as Aggregator } from
|
|
2
|
-
import { aggregation } from
|
|
3
|
-
/**
|
|
4
|
-
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
|
|
5
|
-
*/
|
|
1
|
+
import { AggregationStep, Base as Aggregator } from "./base.js";
|
|
2
|
+
import { aggregation } from "../index.js";
|
|
3
|
+
/** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
|
|
6
4
|
export class MeanAggregator extends Aggregator {
|
|
5
|
+
#threshold;
|
|
7
6
|
/**
|
|
8
|
-
*
|
|
9
|
-
*
|
|
10
|
-
*
|
|
7
|
+
* @param threshold - how many contributions for trigger an aggregation step.
|
|
8
|
+
* - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
|
|
9
|
+
* - absolute: t > 1, thus requiring t contributions
|
|
11
10
|
*/
|
|
12
|
-
|
|
11
|
+
// TODO no way to require a single contribution
|
|
13
12
|
constructor(model, roundCutoff = 0, threshold = 1) {
|
|
13
|
+
if (threshold <= 0)
|
|
14
|
+
throw new Error("threshold must be striclty positive");
|
|
15
|
+
if (threshold > 1 && !Number.isInteger(threshold))
|
|
16
|
+
throw new Error("absolute thresholds must be integeral");
|
|
14
17
|
super(model, roundCutoff, 1);
|
|
15
|
-
|
|
16
|
-
if (threshold === undefined) {
|
|
17
|
-
this.threshold = 1;
|
|
18
|
-
// Threshold must be positive
|
|
19
|
-
}
|
|
20
|
-
else if (threshold <= 0) {
|
|
21
|
-
throw new Error('threshold must be positive');
|
|
22
|
-
// Thresholds greater than 1 are considered absolute instead of relative to the number of nodes
|
|
23
|
-
}
|
|
24
|
-
else if (threshold > 1 && Math.round(threshold) !== threshold) {
|
|
25
|
-
throw new Error('absolute thresholds must integers');
|
|
26
|
-
}
|
|
27
|
-
else {
|
|
28
|
-
this.threshold = threshold;
|
|
29
|
-
}
|
|
18
|
+
this.#threshold = threshold;
|
|
30
19
|
}
|
|
31
|
-
/**
|
|
32
|
-
* Checks whether the contributions buffer is full, according to the set threshold.
|
|
33
|
-
* @returns Whether the contributions buffer is full
|
|
34
|
-
*/
|
|
20
|
+
/** Checks whether the contributions buffer is full. */
|
|
35
21
|
isFull() {
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
}
|
|
41
|
-
return contribs.size >= this.threshold * this.nodes.size;
|
|
42
|
-
}
|
|
43
|
-
return this.contributions.size >= this.threshold;
|
|
22
|
+
const actualThreshold = this.#threshold <= 1
|
|
23
|
+
? this.#threshold * this.nodes.size
|
|
24
|
+
: this.#threshold;
|
|
25
|
+
return (this.contributions.get(0)?.size ?? 0) >= actualThreshold;
|
|
44
26
|
}
|
|
45
|
-
add(nodeId, contribution, round) {
|
|
46
|
-
if (
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
27
|
+
add(nodeId, contribution, round, currentContributions = 0) {
|
|
28
|
+
if (currentContributions !== 0)
|
|
29
|
+
throw new Error("only a single communication round");
|
|
30
|
+
if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
|
|
31
|
+
return false;
|
|
32
|
+
this.log(this.contributions.hasIn([0, nodeId])
|
|
33
|
+
? AggregationStep.UPDATE
|
|
34
|
+
: AggregationStep.ADD, nodeId);
|
|
35
|
+
this.contributions = this.contributions.setIn([0, nodeId], contribution);
|
|
36
|
+
this.informant?.update();
|
|
37
|
+
if (this.isFull())
|
|
38
|
+
this.aggregate();
|
|
39
|
+
return true;
|
|
56
40
|
}
|
|
57
41
|
aggregate() {
|
|
42
|
+
const currentContributions = this.contributions.get(0);
|
|
43
|
+
if (currentContributions === undefined)
|
|
44
|
+
throw new Error("aggregating without any contribution");
|
|
58
45
|
this.log(AggregationStep.AGGREGATE);
|
|
59
|
-
const result = aggregation.avg(
|
|
60
|
-
if (this.model !== undefined)
|
|
46
|
+
const result = aggregation.avg(currentContributions.values());
|
|
47
|
+
if (this.model !== undefined)
|
|
61
48
|
this.model.weights = result;
|
|
62
|
-
}
|
|
63
49
|
this.emit(result);
|
|
64
50
|
}
|
|
65
51
|
makePayloads(weights) {
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { Map, List } from
|
|
2
|
-
import { Base as Aggregator } from
|
|
3
|
-
import type { Model, WeightsContainer, client } from
|
|
1
|
+
import { Map, List } from "immutable";
|
|
2
|
+
import { Base as Aggregator } from "./base.js";
|
|
3
|
+
import type { Model, WeightsContainer, client } from "../index.js";
|
|
4
4
|
/**
|
|
5
5
|
* Aggregator implementing secure multi-party computation for decentralized learning.
|
|
6
6
|
* An aggregation consists of two communication rounds:
|
|
@@ -12,12 +12,10 @@ export declare class SecureAggregator extends Aggregator<WeightsContainer> {
|
|
|
12
12
|
private readonly maxShareValue;
|
|
13
13
|
constructor(model?: Model, maxShareValue?: number);
|
|
14
14
|
aggregate(): void;
|
|
15
|
-
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound
|
|
15
|
+
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound?: number): boolean;
|
|
16
16
|
isFull(): boolean;
|
|
17
17
|
makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
|
|
18
|
-
/**
|
|
19
|
-
* Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
|
|
20
|
-
*/
|
|
18
|
+
/** Generate N additive shares that aggregate to the secret weights array, where N is the number of peers. */
|
|
21
19
|
generateAllShares(secret: WeightsContainer): List<WeightsContainer>;
|
|
22
20
|
/**
|
|
23
21
|
* Generates one share in the same shape as the secret that is populated with values randomly chosen from
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { Map, List, Range } from
|
|
2
|
-
import * as tf from
|
|
3
|
-
import { AggregationStep, Base as Aggregator } from
|
|
4
|
-
import { aggregation } from
|
|
1
|
+
import { Map, List, Range } from "immutable";
|
|
2
|
+
import * as tf from "@tensorflow/tfjs";
|
|
3
|
+
import { AggregationStep, Base as Aggregator } from "./base.js";
|
|
4
|
+
import { aggregation } from "../index.js";
|
|
5
5
|
/**
|
|
6
6
|
* Aggregator implementing secure multi-party computation for decentralized learning.
|
|
7
7
|
* An aggregation consists of two communication rounds:
|
|
@@ -17,60 +17,72 @@ export class SecureAggregator extends Aggregator {
|
|
|
17
17
|
}
|
|
18
18
|
aggregate() {
|
|
19
19
|
this.log(AggregationStep.AGGREGATE);
|
|
20
|
-
|
|
20
|
+
switch (this.communicationRound) {
|
|
21
21
|
// Sum the received shares
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
22
|
+
case 0: {
|
|
23
|
+
const currentContributions = this.contributions.get(0);
|
|
24
|
+
if (currentContributions === undefined)
|
|
25
|
+
throw new Error("aggregating without any contribution");
|
|
26
|
+
const result = aggregation.sum(currentContributions.values());
|
|
27
|
+
this.emit(result);
|
|
28
|
+
break;
|
|
29
|
+
}
|
|
26
30
|
// Average the received partial sums
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
31
|
+
case 1: {
|
|
32
|
+
const currentContributions = this.contributions.get(1);
|
|
33
|
+
if (currentContributions === undefined)
|
|
34
|
+
throw new Error("aggregating without any contribution");
|
|
35
|
+
const result = aggregation.avg(currentContributions.values());
|
|
36
|
+
if (this.model !== undefined)
|
|
37
|
+
this.model.weights = result;
|
|
38
|
+
this.emit(result);
|
|
39
|
+
break;
|
|
30
40
|
}
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
else {
|
|
34
|
-
throw new Error('communication round is out of bounds');
|
|
41
|
+
default:
|
|
42
|
+
throw new Error("communication round is out of bounds");
|
|
35
43
|
}
|
|
36
44
|
}
|
|
37
45
|
add(nodeId, contribution, round, communicationRound) {
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
}
|
|
45
|
-
return true;
|
|
46
|
+
switch (communicationRound) {
|
|
47
|
+
case 0:
|
|
48
|
+
case 1:
|
|
49
|
+
break;
|
|
50
|
+
default:
|
|
51
|
+
throw new Error("requires communication round to be 0 or 1");
|
|
46
52
|
}
|
|
47
|
-
|
|
53
|
+
if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
|
|
54
|
+
return false;
|
|
55
|
+
this.log(this.contributions.hasIn([communicationRound, nodeId])
|
|
56
|
+
? AggregationStep.UPDATE
|
|
57
|
+
: AggregationStep.ADD, nodeId);
|
|
58
|
+
this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
|
|
59
|
+
this.informant?.update();
|
|
60
|
+
if (this.isFull())
|
|
61
|
+
this.aggregate();
|
|
62
|
+
return true;
|
|
48
63
|
}
|
|
49
64
|
isFull() {
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
return false;
|
|
53
|
-
}
|
|
54
|
-
return contribs.size === this.nodes.size;
|
|
65
|
+
return ((this.contributions.get(this.communicationRound)?.size ?? 0) ===
|
|
66
|
+
this.nodes.size);
|
|
55
67
|
}
|
|
56
68
|
makePayloads(weights) {
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
69
|
+
switch (this.communicationRound) {
|
|
70
|
+
case 0: {
|
|
71
|
+
const shares = this.generateAllShares(weights);
|
|
72
|
+
// Abitrarily assign our shares to the available nodes
|
|
73
|
+
return Map(List(this.nodes).zip(shares));
|
|
74
|
+
}
|
|
63
75
|
// Send our partial sum to every other nodes
|
|
64
|
-
|
|
76
|
+
case 1:
|
|
77
|
+
return this.nodes.toMap().map(() => weights);
|
|
78
|
+
default:
|
|
79
|
+
throw new Error("communication round is out of bounds");
|
|
65
80
|
}
|
|
66
81
|
}
|
|
67
|
-
/**
|
|
68
|
-
* Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
|
|
69
|
-
*/
|
|
82
|
+
/** Generate N additive shares that aggregate to the secret weights array, where N is the number of peers. */
|
|
70
83
|
generateAllShares(secret) {
|
|
71
|
-
if (this.nodes.size === 0)
|
|
72
|
-
throw new Error(
|
|
73
|
-
}
|
|
84
|
+
if (this.nodes.size === 0)
|
|
85
|
+
throw new Error("too few participants to generate shares");
|
|
74
86
|
// Generate N-1 shares
|
|
75
87
|
const shares = Range(0, this.nodes.size - 1)
|
|
76
88
|
.map(() => this.generateRandomShare(secret))
|
|
@@ -86,6 +98,6 @@ export class SecureAggregator extends Aggregator {
|
|
|
86
98
|
const MAX_SEED_BITS = 47;
|
|
87
99
|
const random = crypto.getRandomValues(new BigInt64Array(1))[0];
|
|
88
100
|
const seed = Number(BigInt.asUintN(MAX_SEED_BITS, random));
|
|
89
|
-
return secret.map((t) => tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue,
|
|
101
|
+
return secret.map((t) => tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue, "float32", seed));
|
|
90
102
|
}
|
|
91
103
|
}
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { type client, type MetadataKey, type MetadataValue } from '../../index.js';
|
|
2
2
|
import { type weights } from '../../serialization/index.js';
|
|
3
3
|
import { type, type AssignNodeID, type ClientConnected } from '../messages.js';
|
|
4
|
-
export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload |
|
|
4
|
+
export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | ReceiveServerMetadata | AssignNodeID;
|
|
5
5
|
export interface SendPayload {
|
|
6
6
|
type: type.SendPayload;
|
|
7
7
|
payload: weights.Encoded;
|
|
@@ -12,13 +12,6 @@ export interface ReceiveServerPayload {
|
|
|
12
12
|
payload: weights.Encoded;
|
|
13
13
|
round: number;
|
|
14
14
|
}
|
|
15
|
-
export interface RequestServerStatistics {
|
|
16
|
-
type: type.RequestServerStatistics;
|
|
17
|
-
}
|
|
18
|
-
export interface ReceiveServerStatistics {
|
|
19
|
-
type: type.ReceiveServerStatistics;
|
|
20
|
-
statistics: Record<string, number>;
|
|
21
|
-
}
|
|
22
15
|
export interface ReceiveServerMetadata {
|
|
23
16
|
type: type.ReceiveServerMetadata;
|
|
24
17
|
nodeId: client.NodeID;
|
|
@@ -5,20 +5,11 @@ export function isMessageFederated(raw) {
|
|
|
5
5
|
}
|
|
6
6
|
switch (raw.type) {
|
|
7
7
|
case type.ClientConnected:
|
|
8
|
-
return true;
|
|
9
8
|
case type.SendPayload:
|
|
10
|
-
return true;
|
|
11
9
|
case type.ReceiveServerPayload:
|
|
12
|
-
return true;
|
|
13
|
-
case type.RequestServerStatistics:
|
|
14
|
-
return true;
|
|
15
|
-
case type.ReceiveServerStatistics:
|
|
16
|
-
return true;
|
|
17
10
|
case type.ReceiveServerMetadata:
|
|
18
|
-
return true;
|
|
19
11
|
case type.AssignNodeID:
|
|
20
12
|
return true;
|
|
21
|
-
default:
|
|
22
|
-
return false;
|
|
23
13
|
}
|
|
14
|
+
return false;
|
|
24
15
|
}
|
|
@@ -10,9 +10,7 @@ export declare enum type {
|
|
|
10
10
|
Payload = 5,
|
|
11
11
|
SendPayload = 6,
|
|
12
12
|
ReceiveServerMetadata = 7,
|
|
13
|
-
ReceiveServerPayload = 8
|
|
14
|
-
RequestServerStatistics = 9,
|
|
15
|
-
ReceiveServerStatistics = 10
|
|
13
|
+
ReceiveServerPayload = 8
|
|
16
14
|
}
|
|
17
15
|
export interface ClientConnected {
|
|
18
16
|
type: type.ClientConnected;
|
package/dist/client/messages.js
CHANGED
|
@@ -11,8 +11,6 @@ export var type;
|
|
|
11
11
|
type[type["SendPayload"] = 6] = "SendPayload";
|
|
12
12
|
type[type["ReceiveServerMetadata"] = 7] = "ReceiveServerMetadata";
|
|
13
13
|
type[type["ReceiveServerPayload"] = 8] = "ReceiveServerPayload";
|
|
14
|
-
type[type["RequestServerStatistics"] = 9] = "RequestServerStatistics";
|
|
15
|
-
type[type["ReceiveServerStatistics"] = 10] = "ReceiveServerStatistics";
|
|
16
14
|
})(type || (type = {}));
|
|
17
15
|
export function hasMessageType(raw) {
|
|
18
16
|
if (typeof raw !== 'object' || raw === null) {
|