@epfml/discojs 3.0.1-p20240902100041.0 → 3.0.1-p20240904094219.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +16 -2
- package/dist/aggregator/base.js +25 -3
- package/dist/aggregator/mean.d.ts +1 -0
- package/dist/aggregator/mean.js +11 -6
- package/dist/aggregator/secure.js +1 -1
- package/dist/client/{base.d.ts → client.d.ts} +13 -30
- package/dist/client/{base.js → client.js} +10 -20
- package/dist/client/decentralized/{base.d.ts → decentralized_client.d.ts} +5 -5
- package/dist/client/decentralized/{base.js → decentralized_client.js} +20 -16
- package/dist/client/decentralized/index.d.ts +1 -1
- package/dist/client/decentralized/index.js +1 -1
- package/dist/client/decentralized/messages.d.ts +7 -2
- package/dist/client/decentralized/messages.js +4 -2
- package/dist/client/event_connection.js +2 -2
- package/dist/client/federated/federated_client.d.ts +44 -0
- package/dist/client/federated/federated_client.js +210 -0
- package/dist/client/federated/index.d.ts +1 -1
- package/dist/client/federated/index.js +1 -1
- package/dist/client/federated/messages.d.ts +17 -2
- package/dist/client/federated/messages.js +3 -1
- package/dist/client/index.d.ts +2 -2
- package/dist/client/index.js +2 -2
- package/dist/client/local_client.d.ts +10 -0
- package/dist/client/local_client.js +14 -0
- package/dist/client/messages.d.ts +6 -8
- package/dist/client/messages.js +23 -7
- package/dist/client/utils.js +1 -1
- package/dist/default_tasks/cifar10.js +1 -2
- package/dist/default_tasks/lus_covid.js +1 -1
- package/dist/default_tasks/mnist.js +1 -2
- package/dist/default_tasks/simple_face.js +2 -2
- package/dist/default_tasks/titanic.js +2 -2
- package/dist/default_tasks/wikitext.js +1 -1
- package/dist/index.d.ts +4 -2
- package/dist/index.js +1 -1
- package/dist/logging/logger.d.ts +1 -1
- package/dist/serialization/model.js +18 -9
- package/dist/task/index.d.ts +0 -1
- package/dist/task/index.js +0 -1
- package/dist/task/task.d.ts +0 -2
- package/dist/task/task.js +2 -4
- package/dist/task/training_information.d.ts +1 -2
- package/dist/task/training_information.js +3 -5
- package/dist/training/disco.d.ts +14 -16
- package/dist/training/disco.js +22 -46
- package/dist/training/index.d.ts +1 -1
- package/dist/training/trainer.d.ts +3 -2
- package/dist/training/trainer.js +12 -5
- package/dist/utils/event_emitter.js +1 -3
- package/package.json +1 -1
- package/dist/client/federated/base.d.ts +0 -38
- package/dist/client/federated/base.js +0 -130
- package/dist/client/local.d.ts +0 -5
- package/dist/client/local.js +0 -6
- package/dist/memory/base.d.ts +0 -111
- package/dist/memory/base.js +0 -9
- package/dist/memory/empty.d.ts +0 -20
- package/dist/memory/empty.js +0 -43
- package/dist/memory/index.d.ts +0 -2
- package/dist/memory/index.js +0 -2
- package/dist/task/digest.d.ts +0 -5
- package/dist/task/digest.js +0 -14
|
@@ -58,13 +58,22 @@ export declare abstract class Base<T> extends EventEmitter<{
|
|
|
58
58
|
communicationRounds?: number);
|
|
59
59
|
/**
|
|
60
60
|
* Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
|
|
61
|
+
* The aggregation round is increased whenever a new global model is obtained and local models are updated.
|
|
62
|
+
* Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation
|
|
63
|
+
* which requires multiple steps to obtain a global model)
|
|
61
64
|
* The contribution will be aggregated during the next aggregation step.
|
|
62
65
|
* @param nodeId The node's id
|
|
63
66
|
* @param contribution The node's contribution
|
|
64
|
-
* @param round
|
|
65
|
-
* @param communicationRound
|
|
67
|
+
* @param round aggregation round of the contribution was made
|
|
68
|
+
* @param communicationRound communication round the contribution was made within the aggregation round
|
|
69
|
+
* @returns boolean, true if the contribution has been successfully taken into account or False if it has been rejected
|
|
66
70
|
*/
|
|
67
71
|
abstract add(nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean;
|
|
72
|
+
/**
|
|
73
|
+
* Evaluates whether a given participant contribution can be used in the current aggregation round
|
|
74
|
+
* the boolean returned by `this.add` is obtained via `this.isValidContribution`
|
|
75
|
+
*/
|
|
76
|
+
isValidContribution(nodeId: client.NodeID, round: number): boolean;
|
|
68
77
|
/**
|
|
69
78
|
* Performs an aggregation step over the received node contributions.
|
|
70
79
|
* Must store the aggregation's result in the aggregator's result promise.
|
|
@@ -91,6 +100,11 @@ export declare abstract class Base<T> extends EventEmitter<{
|
|
|
91
100
|
* @returns True is the node wasn't already in the list of nodes, False if already included
|
|
92
101
|
*/
|
|
93
102
|
registerNode(nodeId: client.NodeID): boolean;
|
|
103
|
+
/**
|
|
104
|
+
* Remove a node's id from the set of active nodes.
|
|
105
|
+
* @param nodeId The node to be removed
|
|
106
|
+
*/
|
|
107
|
+
removeNode(nodeId: client.NodeID): void;
|
|
94
108
|
/**
|
|
95
109
|
* Overwrites the current set of active nodes with the given one. A node represents
|
|
96
110
|
* an active neighbor peer/client within the network, whom we are communicating with
|
package/dist/aggregator/base.js
CHANGED
|
@@ -60,6 +60,21 @@ export class Base extends EventEmitter {
|
|
|
60
60
|
// and communication rounds.
|
|
61
61
|
this.on('aggregation', () => this.nextRound());
|
|
62
62
|
}
|
|
63
|
+
/**
|
|
64
|
+
* Evaluates whether a given participant contribution can be used in the current aggregation round
|
|
65
|
+
* the boolean returned by `this.add` is obtained via `this.isValidContribution`
|
|
66
|
+
*/
|
|
67
|
+
isValidContribution(nodeId, round) {
|
|
68
|
+
if (!this.nodes.has(nodeId)) {
|
|
69
|
+
debug("Contribution rejected because node id is not registered");
|
|
70
|
+
return false;
|
|
71
|
+
}
|
|
72
|
+
if (!this.isWithinRoundCutoff(round)) {
|
|
73
|
+
debug(`Contribution rejected because round ${round} is not within round cutoff`);
|
|
74
|
+
return false;
|
|
75
|
+
}
|
|
76
|
+
return true;
|
|
77
|
+
}
|
|
63
78
|
/**
|
|
64
79
|
* Returns whether the given round is recent enough, dependent on the
|
|
65
80
|
* aggregator's round cutoff.
|
|
@@ -77,16 +92,16 @@ export class Base extends EventEmitter {
|
|
|
77
92
|
log(step, from) {
|
|
78
93
|
switch (step) {
|
|
79
94
|
case AggregationStep.ADD:
|
|
80
|
-
debug(`
|
|
95
|
+
debug(`Adding contribution from node ${from ?? '"unknown"'} for aggregation round ${this.round} and communication round ${this.communicationRound}`);
|
|
81
96
|
break;
|
|
82
97
|
case AggregationStep.UPDATE:
|
|
83
98
|
if (from === undefined) {
|
|
84
99
|
return;
|
|
85
100
|
}
|
|
86
|
-
debug(`
|
|
101
|
+
debug(`Updating contribution from node ${from} for aggregation round ${this.round} and communication round ${this.communicationRound}`);
|
|
87
102
|
break;
|
|
88
103
|
case AggregationStep.AGGREGATE:
|
|
89
|
-
debug(`
|
|
104
|
+
debug(`Buffer is full. Aggregating weights for round aggregation round ${this.round} and communication round ${this.communicationRound}`);
|
|
90
105
|
break;
|
|
91
106
|
default: {
|
|
92
107
|
const _ = step;
|
|
@@ -108,6 +123,13 @@ export class Base extends EventEmitter {
|
|
|
108
123
|
}
|
|
109
124
|
return false;
|
|
110
125
|
}
|
|
126
|
+
/**
|
|
127
|
+
* Remove a node's id from the set of active nodes.
|
|
128
|
+
* @param nodeId The node to be removed
|
|
129
|
+
*/
|
|
130
|
+
removeNode(nodeId) {
|
|
131
|
+
this._nodes = this._nodes.delete(nodeId);
|
|
132
|
+
}
|
|
111
133
|
/**
|
|
112
134
|
* Overwrites the current set of active nodes with the given one. A node represents
|
|
113
135
|
* an active neighbor peer/client within the network, whom we are communicating with
|
|
@@ -30,6 +30,7 @@ export declare class MeanAggregator extends Aggregator<WeightsContainer> {
|
|
|
30
30
|
constructor(roundCutoff?: number, threshold?: number, thresholdType?: ThresholdType);
|
|
31
31
|
/** Checks whether the contributions buffer is full. */
|
|
32
32
|
isFull(): boolean;
|
|
33
|
+
set minNbOfParticipants(minNbOfParticipants: number);
|
|
33
34
|
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
|
|
34
35
|
aggregate(): void;
|
|
35
36
|
makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
|
package/dist/aggregator/mean.js
CHANGED
|
@@ -9,6 +9,7 @@ const debug = createDebug("discojs:aggregator:mean");
|
|
|
9
9
|
export class MeanAggregator extends Aggregator {
|
|
10
10
|
#threshold;
|
|
11
11
|
#thresholdType;
|
|
12
|
+
#minNbOfParticipants;
|
|
12
13
|
/**
|
|
13
14
|
* Create a mean aggregator that averages all weight updates received when a specified threshold is met.
|
|
14
15
|
* By default, initializes an aggregator that waits for 100% of the nodes' contributions and that
|
|
@@ -65,21 +66,24 @@ export class MeanAggregator extends Aggregator {
|
|
|
65
66
|
}
|
|
66
67
|
/** Checks whether the contributions buffer is full. */
|
|
67
68
|
isFull() {
|
|
69
|
+
// Make sure that we are over the minimum number of participants
|
|
70
|
+
// if specified
|
|
71
|
+
if (this.#minNbOfParticipants !== undefined &&
|
|
72
|
+
this.nodes.size < this.#minNbOfParticipants)
|
|
73
|
+
return false;
|
|
68
74
|
const thresholdValue = this.#thresholdType == 'relative'
|
|
69
75
|
? this.#threshold * this.nodes.size
|
|
70
76
|
: this.#threshold;
|
|
71
77
|
return (this.contributions.get(0)?.size ?? 0) >= thresholdValue;
|
|
72
78
|
}
|
|
79
|
+
set minNbOfParticipants(minNbOfParticipants) {
|
|
80
|
+
this.#minNbOfParticipants = minNbOfParticipants;
|
|
81
|
+
}
|
|
73
82
|
add(nodeId, contribution, round, currentContributions = 0) {
|
|
74
83
|
if (currentContributions !== 0)
|
|
75
84
|
throw new Error("only a single communication round");
|
|
76
|
-
if (!this.
|
|
77
|
-
if (!this.nodes.has(nodeId))
|
|
78
|
-
debug(`contribution rejected because node ${nodeId} is not registered`);
|
|
79
|
-
if (!this.isWithinRoundCutoff(round))
|
|
80
|
-
debug(`contribution rejected because round ${round} is not within cutoff`);
|
|
85
|
+
if (!this.isValidContribution(nodeId, round))
|
|
81
86
|
return false;
|
|
82
|
-
}
|
|
83
87
|
this.log(this.contributions.hasIn([0, nodeId])
|
|
84
88
|
? AggregationStep.UPDATE
|
|
85
89
|
: AggregationStep.ADD, nodeId);
|
|
@@ -94,6 +98,7 @@ export class MeanAggregator extends Aggregator {
|
|
|
94
98
|
throw new Error("aggregating without any contribution");
|
|
95
99
|
this.log(AggregationStep.AGGREGATE);
|
|
96
100
|
const result = aggregation.avg(currentContributions.values());
|
|
101
|
+
// Emitting the event runs the superclass' callback to increment the round
|
|
97
102
|
this.emit('aggregation', result);
|
|
98
103
|
}
|
|
99
104
|
makePayloads(weights) {
|
|
@@ -48,7 +48,7 @@ export class SecureAggregator extends Aggregator {
|
|
|
48
48
|
default:
|
|
49
49
|
throw new Error("requires communication round to be 0 or 1");
|
|
50
50
|
}
|
|
51
|
-
if (!this.
|
|
51
|
+
if (!this.isValidContribution(nodeId, round))
|
|
52
52
|
return false;
|
|
53
53
|
this.log(this.contributions.hasIn([communicationRound, nodeId])
|
|
54
54
|
? AggregationStep.UPDATE
|
|
@@ -1,23 +1,17 @@
|
|
|
1
|
-
import type { Model, Task, WeightsContainer } from '../index.js';
|
|
1
|
+
import type { Model, Task, WeightsContainer, RoundStatus } from '../index.js';
|
|
2
2
|
import type { NodeID } from './types.js';
|
|
3
3
|
import type { EventConnection } from './event_connection.js';
|
|
4
4
|
import type { Aggregator } from '../aggregator/index.js';
|
|
5
|
+
import { EventEmitter } from '../utils/event_emitter.js';
|
|
5
6
|
/**
|
|
6
7
|
* Main, abstract, class representing a Disco client in a network, which handles
|
|
7
8
|
* communication with other nodes, be it peers or a server.
|
|
8
9
|
*/
|
|
9
|
-
export declare abstract class
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
*/
|
|
10
|
+
export declare abstract class Client extends EventEmitter<{
|
|
11
|
+
'status': RoundStatus;
|
|
12
|
+
}> {
|
|
13
13
|
readonly url: URL;
|
|
14
|
-
/**
|
|
15
|
-
* The client's corresponding task.
|
|
16
|
-
*/
|
|
17
14
|
readonly task: Task;
|
|
18
|
-
/**
|
|
19
|
-
* The client's aggregator.
|
|
20
|
-
*/
|
|
21
15
|
readonly aggregator: Aggregator;
|
|
22
16
|
/**
|
|
23
17
|
* Own ID provided by the network's server.
|
|
@@ -31,23 +25,15 @@ export declare abstract class Base {
|
|
|
31
25
|
* The aggregator's result produced after aggregation.
|
|
32
26
|
*/
|
|
33
27
|
protected aggregationResult?: Promise<WeightsContainer>;
|
|
34
|
-
constructor(
|
|
35
|
-
|
|
36
|
-
* The network server's URL to connect to.
|
|
37
|
-
*/
|
|
38
|
-
url: URL,
|
|
39
|
-
/**
|
|
40
|
-
* The client's corresponding task.
|
|
41
|
-
*/
|
|
42
|
-
task: Task,
|
|
43
|
-
/**
|
|
44
|
-
* The client's aggregator.
|
|
45
|
-
*/
|
|
28
|
+
constructor(url: URL, // The network server's URL to connect to
|
|
29
|
+
task: Task, // The client's corresponding task
|
|
46
30
|
aggregator: Aggregator);
|
|
47
31
|
/**
|
|
48
32
|
* Handles the connection process from the client to any sort of network server.
|
|
33
|
+
* This method is overriden by the federated and decentralized clients
|
|
34
|
+
* By default, it fetches and returns the server's base model
|
|
49
35
|
*/
|
|
50
|
-
connect(): Promise<
|
|
36
|
+
connect(): Promise<Model>;
|
|
51
37
|
/**
|
|
52
38
|
* Handles the disconnection process of the client from any sort of network server.
|
|
53
39
|
*/
|
|
@@ -59,17 +45,14 @@ export declare abstract class Base {
|
|
|
59
45
|
getLatestModel(): Promise<Model>;
|
|
60
46
|
/**
|
|
61
47
|
* Communication callback called at the beginning of every training round.
|
|
62
|
-
* @param _weights The most recent local weight updates
|
|
63
|
-
* @param _round The current training round
|
|
64
48
|
*/
|
|
65
|
-
onRoundBeginCommunication(
|
|
49
|
+
abstract onRoundBeginCommunication(): Promise<void>;
|
|
66
50
|
/**
|
|
67
51
|
* Communication callback called the end of every training round.
|
|
68
|
-
* @param
|
|
69
|
-
* @param _round The current training round
|
|
52
|
+
* @param weights The local weight update resulting for the current local training round
|
|
70
53
|
* @returns aggregated weights or the local weights upon error
|
|
71
54
|
*/
|
|
72
|
-
abstract onRoundEndCommunication(
|
|
55
|
+
abstract onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
|
|
73
56
|
get nbOfParticipants(): number;
|
|
74
57
|
get ownId(): NodeID;
|
|
75
58
|
get server(): EventConnection;
|
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
import axios from 'axios';
|
|
2
2
|
import { serialization } from '../index.js';
|
|
3
|
+
import { EventEmitter } from '../utils/event_emitter.js';
|
|
3
4
|
/**
|
|
4
5
|
* Main, abstract, class representing a Disco client in a network, which handles
|
|
5
6
|
* communication with other nodes, be it peers or a server.
|
|
6
7
|
*/
|
|
7
|
-
export class
|
|
8
|
+
export class Client extends EventEmitter {
|
|
8
9
|
url;
|
|
9
10
|
task;
|
|
10
11
|
aggregator;
|
|
@@ -20,27 +21,22 @@ export class Base {
|
|
|
20
21
|
* The aggregator's result produced after aggregation.
|
|
21
22
|
*/
|
|
22
23
|
aggregationResult;
|
|
23
|
-
constructor(
|
|
24
|
-
|
|
25
|
-
* The network server's URL to connect to.
|
|
26
|
-
*/
|
|
27
|
-
url,
|
|
28
|
-
/**
|
|
29
|
-
* The client's corresponding task.
|
|
30
|
-
*/
|
|
31
|
-
task,
|
|
32
|
-
/**
|
|
33
|
-
* The client's aggregator.
|
|
34
|
-
*/
|
|
24
|
+
constructor(url, // The network server's URL to connect to
|
|
25
|
+
task, // The client's corresponding task
|
|
35
26
|
aggregator) {
|
|
27
|
+
super();
|
|
36
28
|
this.url = url;
|
|
37
29
|
this.task = task;
|
|
38
30
|
this.aggregator = aggregator;
|
|
39
31
|
}
|
|
40
32
|
/**
|
|
41
33
|
* Handles the connection process from the client to any sort of network server.
|
|
34
|
+
* This method is overriden by the federated and decentralized clients
|
|
35
|
+
* By default, it fetches and returns the server's base model
|
|
42
36
|
*/
|
|
43
|
-
async connect() {
|
|
37
|
+
async connect() {
|
|
38
|
+
return this.getLatestModel();
|
|
39
|
+
}
|
|
44
40
|
/**
|
|
45
41
|
* Handles the disconnection process of the client from any sort of network server.
|
|
46
42
|
*/
|
|
@@ -58,12 +54,6 @@ export class Base {
|
|
|
58
54
|
const response = await axios.get(url.href, { responseType: 'arraybuffer' });
|
|
59
55
|
return await serialization.model.decode(new Uint8Array(response.data));
|
|
60
56
|
}
|
|
61
|
-
/**
|
|
62
|
-
* Communication callback called at the beginning of every training round.
|
|
63
|
-
* @param _weights The most recent local weight updates
|
|
64
|
-
* @param _round The current training round
|
|
65
|
-
*/
|
|
66
|
-
async onRoundBeginCommunication(_weights, _round) { }
|
|
67
57
|
// Number of contributors to a collaborative session
|
|
68
58
|
// If decentralized, it should be the number of peers
|
|
69
59
|
// If federated, it should the number of participants excluding the server
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import type { WeightsContainer } from "../../index.js";
|
|
1
|
+
import type { Model, WeightsContainer } from "../../index.js";
|
|
2
2
|
import { Client } from '../index.js';
|
|
3
3
|
/**
|
|
4
4
|
* Represents a decentralized client in a network of peers. Peers coordinate each other with the
|
|
@@ -6,7 +6,7 @@ import { Client } from '../index.js';
|
|
|
6
6
|
* with the server is based off regular WebSockets, whereas peer-to-peer communication uses
|
|
7
7
|
* WebRTC for Node.js.
|
|
8
8
|
*/
|
|
9
|
-
export declare class
|
|
9
|
+
export declare class DecentralizedClient extends Client {
|
|
10
10
|
/**
|
|
11
11
|
* The pool of peers to communicate with during the current training round.
|
|
12
12
|
*/
|
|
@@ -21,7 +21,7 @@ export declare class Base extends Client {
|
|
|
21
21
|
* create peer-to-peer WebRTC connections with peers. The server is used to exchange
|
|
22
22
|
* peers network information.
|
|
23
23
|
*/
|
|
24
|
-
connect(): Promise<
|
|
24
|
+
connect(): Promise<Model>;
|
|
25
25
|
/**
|
|
26
26
|
* Create a WebSocket connection with the server
|
|
27
27
|
* The client then waits for the server to forward it other client's network information.
|
|
@@ -37,7 +37,7 @@ export declare class Base extends Client {
|
|
|
37
37
|
* and waits for it to resolve.
|
|
38
38
|
*
|
|
39
39
|
*/
|
|
40
|
-
onRoundBeginCommunication(
|
|
40
|
+
onRoundBeginCommunication(): Promise<void>;
|
|
41
41
|
/**
|
|
42
42
|
* At each communication rounds, awaits peers contributions and add them to the client's aggregator.
|
|
43
43
|
* This method is used as callback by getPeers when connecting to the rounds' peers
|
|
@@ -45,5 +45,5 @@ export declare class Base extends Client {
|
|
|
45
45
|
* @param round
|
|
46
46
|
*/
|
|
47
47
|
private receivePayloads;
|
|
48
|
-
onRoundEndCommunication(weights: WeightsContainer
|
|
48
|
+
onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
|
|
49
49
|
}
|
|
@@ -14,7 +14,7 @@ const debug = createDebug("discojs:client:decentralized");
|
|
|
14
14
|
* with the server is based off regular WebSockets, whereas peer-to-peer communication uses
|
|
15
15
|
* WebRTC for Node.js.
|
|
16
16
|
*/
|
|
17
|
-
export class
|
|
17
|
+
export class DecentralizedClient extends Client {
|
|
18
18
|
/**
|
|
19
19
|
* The pool of peers to communicate with during the current training round.
|
|
20
20
|
*/
|
|
@@ -33,6 +33,7 @@ export class Base extends Client {
|
|
|
33
33
|
* peers network information.
|
|
34
34
|
*/
|
|
35
35
|
async connect() {
|
|
36
|
+
const model = await super.connect(); // Get the server base model
|
|
36
37
|
const serverURL = new URL('', this.url.href);
|
|
37
38
|
switch (this.url.protocol) {
|
|
38
39
|
case 'http:':
|
|
@@ -44,19 +45,20 @@ export class Base extends Client {
|
|
|
44
45
|
default:
|
|
45
46
|
throw new Error(`unknown protocol: ${this.url.protocol}`);
|
|
46
47
|
}
|
|
47
|
-
serverURL.pathname += `
|
|
48
|
+
serverURL.pathname += `decentralized/${this.task.id}`;
|
|
48
49
|
this._server = await this.connectServer(serverURL);
|
|
49
50
|
const msg = {
|
|
50
51
|
type: type.ClientConnected
|
|
51
52
|
};
|
|
52
53
|
this.server.send(msg);
|
|
53
|
-
const
|
|
54
|
-
debug(`[${
|
|
54
|
+
const { id } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
|
|
55
|
+
debug(`[${id}] assigned id generated by server`);
|
|
55
56
|
if (this._ownId !== undefined) {
|
|
56
57
|
throw new Error('received id from server but was already received');
|
|
57
58
|
}
|
|
58
|
-
this._ownId =
|
|
59
|
-
this.pool = new PeerPool(
|
|
59
|
+
this._ownId = id;
|
|
60
|
+
this.pool = new PeerPool(id);
|
|
61
|
+
return model;
|
|
60
62
|
}
|
|
61
63
|
/**
|
|
62
64
|
* Create a WebSocket connection with the server
|
|
@@ -96,21 +98,22 @@ export class Base extends Client {
|
|
|
96
98
|
* and waits for it to resolve.
|
|
97
99
|
*
|
|
98
100
|
*/
|
|
99
|
-
async onRoundBeginCommunication(
|
|
101
|
+
async onRoundBeginCommunication() {
|
|
100
102
|
if (this.server === undefined) {
|
|
101
103
|
throw new Error("peer's server is undefined, make sure to call `client.connect()` first");
|
|
102
104
|
}
|
|
103
105
|
if (this.pool === undefined) {
|
|
104
106
|
throw new Error('peer pool is undefined, make sure to call `client.connect()` first');
|
|
105
107
|
}
|
|
108
|
+
this.emit("status", "Retrieving peers' information");
|
|
106
109
|
// Reset peers list at each round of training to make sure client works with an updated peers
|
|
107
110
|
// list, maintained by the server. Adds any received weights to the aggregator.
|
|
108
|
-
// this.connections = await this.waitForPeers(round)
|
|
109
111
|
// Tell the server we are ready for the next round
|
|
110
112
|
const readyMessage = { type: type.PeerIsReady };
|
|
111
113
|
this.server.send(readyMessage);
|
|
112
114
|
// Wait for the server to answer with the list of peers for the round
|
|
113
115
|
try {
|
|
116
|
+
debug(`[${this.ownId}] is waiting for peer list for round ${this.aggregator.round}`);
|
|
114
117
|
const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound, undefined, "Timeout waiting for the round's peer list");
|
|
115
118
|
const peers = Set(receivedMessage.peers);
|
|
116
119
|
if (this.ownId !== undefined && peers.has(this.ownId)) {
|
|
@@ -124,18 +127,19 @@ export class Base extends Client {
|
|
|
124
127
|
// Init receipt of peers weights
|
|
125
128
|
// this awaits the peer's weight update and adds it to
|
|
126
129
|
// our aggregator upon reception
|
|
127
|
-
(conn) => { this.receivePayloads(conn, round); });
|
|
128
|
-
debug(`[${this.ownId}] received peers for round ${round}: %o`, connections.keySeq().toJS());
|
|
130
|
+
(conn) => { this.receivePayloads(conn, this.aggregator.round); });
|
|
131
|
+
debug(`[${this.ownId}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS());
|
|
129
132
|
this.connections = connections;
|
|
130
133
|
}
|
|
131
134
|
catch (e) {
|
|
132
|
-
debug(`[${this.ownId}] while beginning round: %o`, e);
|
|
135
|
+
debug(`Error for [${this.ownId}] while beginning round: %o`, e);
|
|
133
136
|
this.aggregator.setNodes(Set(this.ownId));
|
|
134
137
|
this.connections = Map();
|
|
135
138
|
}
|
|
136
139
|
// Store the promise for the current round's aggregation result.
|
|
137
140
|
// We will await for it to resolve at the end of the round when exchanging weight updates.
|
|
138
141
|
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
142
|
+
this.emit("status", "Training the model on the data you connected");
|
|
139
143
|
}
|
|
140
144
|
/**
|
|
141
145
|
* At each communication rounds, awaits peers contributions and add them to the client's aggregator.
|
|
@@ -156,18 +160,18 @@ export class Base extends Client {
|
|
|
156
160
|
}
|
|
157
161
|
}
|
|
158
162
|
catch (e) {
|
|
159
|
-
if (this.isDisconnected)
|
|
163
|
+
if (this.isDisconnected)
|
|
160
164
|
return;
|
|
161
|
-
}
|
|
162
|
-
debug(`[${this.ownId}] while receiving payloads: %o`, e);
|
|
165
|
+
debug(`Error for [${this.ownId}] while receiving payloads: %o`, e);
|
|
163
166
|
}
|
|
164
167
|
} while (++currentCommunicationRounds < this.aggregator.communicationRounds);
|
|
165
168
|
});
|
|
166
169
|
}
|
|
167
|
-
async onRoundEndCommunication(weights
|
|
170
|
+
async onRoundEndCommunication(weights) {
|
|
168
171
|
if (this.aggregationResult === undefined) {
|
|
169
172
|
throw new TypeError('aggregation result promise is undefined');
|
|
170
173
|
}
|
|
174
|
+
this.emit("status", "Updating the model with other participants' models");
|
|
171
175
|
// Perform the required communication rounds. Each communication round consists in sending our local payload,
|
|
172
176
|
// followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator.
|
|
173
177
|
// A communication round's payload is the aggregation result of the previous communication round. The first
|
|
@@ -181,7 +185,7 @@ export class Base extends Client {
|
|
|
181
185
|
try {
|
|
182
186
|
await Promise.all(payloads.map(async (payload, id) => {
|
|
183
187
|
if (id === this.ownId) {
|
|
184
|
-
this.aggregator.add(this.ownId, payload, round, r);
|
|
188
|
+
this.aggregator.add(this.ownId, payload, this.aggregator.round, r);
|
|
185
189
|
}
|
|
186
190
|
else {
|
|
187
191
|
const peer = this.connections?.get(id);
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
export {
|
|
1
|
+
export { DecentralizedClient } from './decentralized_client.js';
|
|
2
2
|
export * as messages from './messages.js';
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
export {
|
|
1
|
+
export { DecentralizedClient } from './decentralized_client.js';
|
|
2
2
|
export * as messages from './messages.js';
|
|
@@ -1,7 +1,12 @@
|
|
|
1
1
|
import { weights } from '../../serialization/index.js';
|
|
2
2
|
import { type SignalData } from './peer.js';
|
|
3
3
|
import { type NodeID } from '../types.js';
|
|
4
|
-
import { type, type ClientConnected
|
|
4
|
+
import { type, type ClientConnected } from '../messages.js';
|
|
5
|
+
export interface NewDecentralizedNodeInfo {
|
|
6
|
+
type: type.NewDecentralizedNodeInfo;
|
|
7
|
+
id: NodeID;
|
|
8
|
+
waitForMoreParticipants: boolean;
|
|
9
|
+
}
|
|
5
10
|
export interface SignalForPeer {
|
|
6
11
|
type: type.SignalForPeer;
|
|
7
12
|
peer: NodeID;
|
|
@@ -20,7 +25,7 @@ export interface Payload {
|
|
|
20
25
|
round: number;
|
|
21
26
|
payload: weights.Encoded;
|
|
22
27
|
}
|
|
23
|
-
export type MessageFromServer =
|
|
28
|
+
export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | PeersForRound;
|
|
24
29
|
export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady;
|
|
25
30
|
export type PeerMessage = Payload;
|
|
26
31
|
export declare function isMessageFromServer(o: unknown): o is MessageFromServer;
|
|
@@ -6,8 +6,10 @@ export function isMessageFromServer(o) {
|
|
|
6
6
|
return false;
|
|
7
7
|
}
|
|
8
8
|
switch (o.type) {
|
|
9
|
-
case type.
|
|
10
|
-
return 'id' in o && isNodeID(o.id)
|
|
9
|
+
case type.NewDecentralizedNodeInfo:
|
|
10
|
+
return 'id' in o && isNodeID(o.id) &&
|
|
11
|
+
'waitForMoreParticipants' in o &&
|
|
12
|
+
typeof o.waitForMoreParticipants === 'boolean';
|
|
11
13
|
case type.SignalForPeer:
|
|
12
14
|
return 'peer' in o && isNodeID(o.peer) &&
|
|
13
15
|
'signal' in o; // TODO check signal content?
|
|
@@ -95,8 +95,8 @@ export class WebSocketServer extends EventEmitter {
|
|
|
95
95
|
}
|
|
96
96
|
disconnect() {
|
|
97
97
|
return new Promise((resolve, reject) => {
|
|
98
|
-
this.socket.
|
|
99
|
-
this.socket.
|
|
98
|
+
this.socket.onclose = () => resolve();
|
|
99
|
+
this.socket.onerror = (e) => reject(e.message);
|
|
100
100
|
this.socket.close();
|
|
101
101
|
});
|
|
102
102
|
}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import type { Model, WeightsContainer } from "../../index.js";
|
|
2
|
+
import { Client } from "../client.js";
|
|
3
|
+
/**
|
|
4
|
+
* Client class that communicates with a centralized, federated server, when training
|
|
5
|
+
* a specific task in the federated setting.
|
|
6
|
+
*/
|
|
7
|
+
export declare class FederatedClient extends Client {
|
|
8
|
+
#private;
|
|
9
|
+
get nbOfParticipants(): number;
|
|
10
|
+
/**
|
|
11
|
+
* Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
12
|
+
*/
|
|
13
|
+
private connectServer;
|
|
14
|
+
/**
|
|
15
|
+
* Initializes the connection to the server, gets our node ID
|
|
16
|
+
* as well as the latest training information: latest global model, current round and
|
|
17
|
+
* whether we are waiting for more participants.
|
|
18
|
+
*/
|
|
19
|
+
connect(): Promise<Model>;
|
|
20
|
+
/**
|
|
21
|
+
* Method called when the server notifies the client that there aren't enough
|
|
22
|
+
* participants (anymore) to start/continue training
|
|
23
|
+
* The method creates a promise that will resolve once the server notifies
|
|
24
|
+
* the client that the training can resume via a subsequent EnoughParticipants message
|
|
25
|
+
* @returns a promise which resolves when enough participants joined the session
|
|
26
|
+
*/
|
|
27
|
+
private waitForMoreParticipants;
|
|
28
|
+
/**
|
|
29
|
+
* Disconnection process when user quits the task.
|
|
30
|
+
*/
|
|
31
|
+
disconnect(): Promise<void>;
|
|
32
|
+
onRoundBeginCommunication(): Promise<void>;
|
|
33
|
+
/**
|
|
34
|
+
* Send the local weight update to the server and waits (indefinitely) for the server global update
|
|
35
|
+
*
|
|
36
|
+
* If the waitingForMoreParticipants flag is set, we first wait (also indefinitely) until the
|
|
37
|
+
* server notifies us that the training can resume.
|
|
38
|
+
*
|
|
39
|
+
// NB: For now, we suppose a fully-federated setting.
|
|
40
|
+
* @param weights Local weights sent to the server at the end of the local training round
|
|
41
|
+
* @returns the new global weights sent by the server
|
|
42
|
+
*/
|
|
43
|
+
onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
|
|
44
|
+
}
|