@epfml/discojs 3.0.1-p20250331130104.0 → 3.0.1-p20250401132959.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/client/client.d.ts +8 -8
- package/dist/client/client.js +29 -8
- package/dist/client/decentralized/decentralized_client.d.ts +1 -1
- package/dist/client/decentralized/decentralized_client.js +10 -8
- package/dist/client/decentralized/messages.d.ts +1 -0
- package/dist/client/federated/federated_client.d.ts +0 -2
- package/dist/client/federated/federated_client.js +2 -9
- package/dist/client/local_client.d.ts +0 -1
- package/dist/client/local_client.js +0 -3
- package/dist/client/messages.d.ts +2 -0
- package/dist/training/disco.d.ts +1 -0
- package/dist/training/disco.js +1 -0
- package/dist/training/trainer.js +1 -1
- package/package.json +2 -2
package/dist/client/client.d.ts
CHANGED
|
@@ -9,7 +9,9 @@ import { EventEmitter } from '../utils/event_emitter.js';
|
|
|
9
9
|
*/
|
|
10
10
|
export declare abstract class Client extends EventEmitter<{
|
|
11
11
|
'status': RoundStatus;
|
|
12
|
+
'participants': number;
|
|
12
13
|
}> {
|
|
14
|
+
#private;
|
|
13
15
|
readonly url: URL;
|
|
14
16
|
readonly task: Task<DataType>;
|
|
15
17
|
readonly aggregator: Aggregator;
|
|
@@ -22,13 +24,6 @@ export declare abstract class Client extends EventEmitter<{
|
|
|
22
24
|
* until the server signals that the training can resume
|
|
23
25
|
*/
|
|
24
26
|
protected promiseForMoreParticipants: Promise<void> | undefined;
|
|
25
|
-
/**
|
|
26
|
-
* When the server notifies the client that they can resume training
|
|
27
|
-
* after waiting for more participants, we want to be able to display what
|
|
28
|
-
* we were doing before waiting (training locally or updating our model).
|
|
29
|
-
* We use this attribute to store the status to rollback to when we stop waiting
|
|
30
|
-
*/
|
|
31
|
-
private previousStatus;
|
|
32
27
|
constructor(url: URL, // The network server's URL to connect to
|
|
33
28
|
task: Task<DataType>, // The client's corresponding task
|
|
34
29
|
aggregator: Aggregator);
|
|
@@ -101,7 +96,12 @@ export declare abstract class Client extends EventEmitter<{
|
|
|
101
96
|
* If federated, it should the number of participants excluding the server
|
|
102
97
|
* If local it should be 1
|
|
103
98
|
*/
|
|
104
|
-
|
|
99
|
+
get nbOfParticipants(): number;
|
|
100
|
+
/**
|
|
101
|
+
* Setter for the number of participants
|
|
102
|
+
* It emits the number of participants to the client
|
|
103
|
+
*/
|
|
104
|
+
set nbOfParticipants(nbOfParticipants: number);
|
|
105
105
|
get ownId(): NodeID;
|
|
106
106
|
get server(): EventConnection;
|
|
107
107
|
/**
|
package/dist/client/client.js
CHANGED
|
@@ -29,7 +29,9 @@ export class Client extends EventEmitter {
|
|
|
29
29
|
* we were doing before waiting (training locally or updating our model).
|
|
30
30
|
* We use this attribute to store the status to rollback to when we stop waiting
|
|
31
31
|
*/
|
|
32
|
-
previousStatus;
|
|
32
|
+
#previousStatus;
|
|
33
|
+
// Current number of participants including this client in the training session
|
|
34
|
+
#nbOfParticipants = 1;
|
|
33
35
|
constructor(url, // The network server's URL to connect to
|
|
34
36
|
task, // The client's corresponding task
|
|
35
37
|
aggregator) {
|
|
@@ -56,7 +58,7 @@ export class Client extends EventEmitter {
|
|
|
56
58
|
* the waiting status and once enough participants join, it can display the previous status again
|
|
57
59
|
*/
|
|
58
60
|
saveAndEmit(status) {
|
|
59
|
-
this.previousStatus = status;
|
|
61
|
+
this.#previousStatus = status;
|
|
60
62
|
this.emit("status", status);
|
|
61
63
|
}
|
|
62
64
|
/**
|
|
@@ -84,12 +86,13 @@ export class Client extends EventEmitter {
|
|
|
84
86
|
setupServerCallbacks(setMessageInversionFlag) {
|
|
85
87
|
// Setup an event callback if the server signals that we should
|
|
86
88
|
// wait for more participants
|
|
87
|
-
this.server.on(type.WaitingForMoreParticipants, () => {
|
|
89
|
+
this.server.on(type.WaitingForMoreParticipants, (event) => {
|
|
88
90
|
if (this.promiseForMoreParticipants !== undefined)
|
|
89
91
|
throw new Error("Server sent multiple WaitingForMoreParticipants messages");
|
|
90
92
|
debug(`[${shortenId(this.ownId)}] received WaitingForMoreParticipants message from server`);
|
|
91
93
|
// Display the waiting status right away
|
|
92
94
|
this.emit("status", "not enough participants");
|
|
95
|
+
this.nbOfParticipants = event.nbOfParticipants; // emits the `participants` event
|
|
93
96
|
// Upon receiving a WaitingForMoreParticipants message,
|
|
94
97
|
// the client will await for this promise to resolve before sending its
|
|
95
98
|
// local weight update
|
|
@@ -101,10 +104,10 @@ export class Client extends EventEmitter {
|
|
|
101
104
|
// and directly follows with an EnoughParticipants message when the 2nd participant joins
|
|
102
105
|
// However, the EnoughParticipants can arrive before the NewNodeInfo (which can be much bigger)
|
|
103
106
|
// so we check whether we received the EnoughParticipants before being assigned a node ID
|
|
104
|
-
this.server.once(type.EnoughParticipants, () => {
|
|
107
|
+
this.server.once(type.EnoughParticipants, (event) => {
|
|
105
108
|
if (this._ownId === undefined) {
|
|
106
|
-
debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
|
|
107
109
|
setMessageInversionFlag();
|
|
110
|
+
this.nbOfParticipants = event.nbOfParticipants;
|
|
108
111
|
}
|
|
109
112
|
});
|
|
110
113
|
}
|
|
@@ -118,11 +121,12 @@ export class Client extends EventEmitter {
|
|
|
118
121
|
async createPromiseForMoreParticipants() {
|
|
119
122
|
return new Promise((resolve) => {
|
|
120
123
|
// "once" is important because we can't resolve the same promise multiple times
|
|
121
|
-
this.server.once(type.EnoughParticipants, () => {
|
|
124
|
+
this.server.once(type.EnoughParticipants, (event) => {
|
|
122
125
|
debug(`[${shortenId(this.ownId)}] received EnoughParticipants message from server`);
|
|
123
126
|
// Emit the last status emitted before waiting if defined
|
|
124
|
-
if (this.previousStatus !== undefined)
|
|
125
|
-
this.emit("status", this.previousStatus);
|
|
127
|
+
if (this.#previousStatus !== undefined)
|
|
128
|
+
this.emit("status", this.#previousStatus);
|
|
129
|
+
this.nbOfParticipants = event.nbOfParticipants;
|
|
126
130
|
resolve();
|
|
127
131
|
});
|
|
128
132
|
});
|
|
@@ -154,6 +158,23 @@ export class Client extends EventEmitter {
|
|
|
154
158
|
const encoded = new Uint8Array(await response.arrayBuffer());
|
|
155
159
|
return await serialization.model.decode(encoded);
|
|
156
160
|
}
|
|
161
|
+
/**
|
|
162
|
+
* Number of contributors to a collaborative session
|
|
163
|
+
* If decentralized, it should be the number of peers
|
|
164
|
+
* If federated, it should the number of participants excluding the server
|
|
165
|
+
* If local it should be 1
|
|
166
|
+
*/
|
|
167
|
+
get nbOfParticipants() {
|
|
168
|
+
return this.#nbOfParticipants;
|
|
169
|
+
}
|
|
170
|
+
/**
|
|
171
|
+
* Setter for the number of participants
|
|
172
|
+
* It emits the number of participants to the client
|
|
173
|
+
*/
|
|
174
|
+
set nbOfParticipants(nbOfParticipants) {
|
|
175
|
+
this.#nbOfParticipants = nbOfParticipants;
|
|
176
|
+
this.emit("participants", nbOfParticipants);
|
|
177
|
+
}
|
|
157
178
|
get ownId() {
|
|
158
179
|
if (this._ownId === undefined) {
|
|
159
180
|
throw new Error('the node is not connected');
|
|
@@ -8,8 +8,8 @@ import { Client } from '../client.js';
|
|
|
8
8
|
*/
|
|
9
9
|
export declare class DecentralizedClient extends Client {
|
|
10
10
|
#private;
|
|
11
|
-
getNbOfParticipants(): number;
|
|
12
11
|
private get isDisconnected();
|
|
12
|
+
private setAggregatorNodes;
|
|
13
13
|
/**
|
|
14
14
|
* Public method called by disco.ts when starting training. This method sends
|
|
15
15
|
* a message to the server asking to join the task and be assigned a client ID.
|
|
@@ -20,14 +20,15 @@ export class DecentralizedClient extends Client {
|
|
|
20
20
|
*/
|
|
21
21
|
#pool;
|
|
22
22
|
#connections;
|
|
23
|
-
getNbOfParticipants() {
|
|
24
|
-
const nbOfParticipants = this.aggregator.nodes.size;
|
|
25
|
-
return nbOfParticipants === 0 ? 1 : nbOfParticipants;
|
|
26
|
-
}
|
|
27
23
|
// Used to handle timeouts and promise resolving after calling disconnect
|
|
28
24
|
get isDisconnected() {
|
|
29
25
|
return this._server === undefined;
|
|
30
26
|
}
|
|
27
|
+
setAggregatorNodes(nodes) {
|
|
28
|
+
this.aggregator.setNodes(nodes);
|
|
29
|
+
// Emits the `participants` event
|
|
30
|
+
this.nbOfParticipants = this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size;
|
|
31
|
+
}
|
|
31
32
|
/**
|
|
32
33
|
* Public method called by disco.ts when starting training. This method sends
|
|
33
34
|
* a message to the server asking to join the task and be assigned a client ID.
|
|
@@ -67,7 +68,8 @@ export class DecentralizedClient extends Client {
|
|
|
67
68
|
type: type.ClientConnected
|
|
68
69
|
};
|
|
69
70
|
this.server.send(msg);
|
|
70
|
-
const { id, waitForMoreParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
|
|
71
|
+
const { id, waitForMoreParticipants, nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo);
|
|
72
|
+
this.nbOfParticipants = nbOfParticipants;
|
|
71
73
|
// This should come right after receiving the message to make sure
|
|
72
74
|
// we don't miss a subsequent message from the server
|
|
73
75
|
// We check if the server is telling us to wait for more participants
|
|
@@ -92,7 +94,7 @@ export class DecentralizedClient extends Client {
|
|
|
92
94
|
this.#pool = undefined;
|
|
93
95
|
if (this.#connections !== undefined) {
|
|
94
96
|
const peers = this.#connections.keySeq().toSet();
|
|
95
|
-
this.aggregator.setNodes(this.aggregator.nodes.subtract(peers));
|
|
97
|
+
this.setAggregatorNodes(this.aggregator.nodes.subtract(peers));
|
|
96
98
|
}
|
|
97
99
|
// Disconnect from server
|
|
98
100
|
await this.server?.disconnect();
|
|
@@ -158,7 +160,7 @@ export class DecentralizedClient extends Client {
|
|
|
158
160
|
throw new Error('received peer list contains our own id');
|
|
159
161
|
}
|
|
160
162
|
// Store the list of peers for the current round including ourselves
|
|
161
|
-
this.aggregator.setNodes(peers.add(this.ownId));
|
|
163
|
+
this.setAggregatorNodes(peers.add(this.ownId));
|
|
162
164
|
this.aggregator.setRound(receivedMessage.aggregationRound); // the server gives us the round number
|
|
163
165
|
// Initiate peer to peer connections with each peer
|
|
164
166
|
// When connected, create a promise waiting for each peer's round contribution
|
|
@@ -171,7 +173,7 @@ export class DecentralizedClient extends Client {
|
|
|
171
173
|
}
|
|
172
174
|
catch (e) {
|
|
173
175
|
debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e);
|
|
174
|
-
this.aggregator.setNodes(Set(this.ownId));
|
|
176
|
+
this.setAggregatorNodes(Set(this.ownId));
|
|
175
177
|
this.#connections = Map();
|
|
176
178
|
}
|
|
177
179
|
}
|
|
@@ -5,8 +5,6 @@ import { Client } from "../client.js";
|
|
|
5
5
|
* a specific task in the federated setting.
|
|
6
6
|
*/
|
|
7
7
|
export declare class FederatedClient extends Client {
|
|
8
|
-
#private;
|
|
9
|
-
getNbOfParticipants(): number;
|
|
10
8
|
/**
|
|
11
9
|
* Initializes the connection to the server, gets our node ID
|
|
12
10
|
* as well as the latest training information: latest global model, current round and
|
|
@@ -16,13 +16,6 @@ const SERVER_NODE_ID = "federated-server-node-id";
|
|
|
16
16
|
* a specific task in the federated setting.
|
|
17
17
|
*/
|
|
18
18
|
export class FederatedClient extends Client {
|
|
19
|
-
// Total number of other federated contributors, including this client, excluding the server
|
|
20
|
-
// E.g., if 3 users are training a federated model, nbOfParticipants is 3
|
|
21
|
-
#nbOfParticipants = 1;
|
|
22
|
-
// the number of participants excluding the server
|
|
23
|
-
getNbOfParticipants() {
|
|
24
|
-
return this.#nbOfParticipants;
|
|
25
|
-
}
|
|
26
19
|
/**
|
|
27
20
|
* Initializes the connection to the server, gets our node ID
|
|
28
21
|
* as well as the latest training information: latest global model, current round and
|
|
@@ -70,7 +63,7 @@ export class FederatedClient extends Client {
|
|
|
70
63
|
this._ownId = id;
|
|
71
64
|
debug(`[${shortenId(id)}] joined session at round ${round} `);
|
|
72
65
|
this.aggregator.setRound(round);
|
|
73
|
-
this.#nbOfParticipants = nbOfParticipants;
|
|
66
|
+
this.nbOfParticipants = nbOfParticipants;
|
|
74
67
|
// Upon connecting, the server answers with a boolean
|
|
75
68
|
// which indicates whether there are enough participants or not
|
|
76
69
|
debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants);
|
|
@@ -129,7 +122,7 @@ export class FederatedClient extends Client {
|
|
|
129
122
|
this.server.send(msg);
|
|
130
123
|
debug(`[${shortenId(this.ownId)}] is waiting for server update for round ${this.aggregator.round + 1}`);
|
|
131
124
|
const { payload: payloadFromServer, round: serverRound, nbOfParticipants } = await waitMessage(this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update
|
|
132
|
-
this.#nbOfParticipants = nbOfParticipants; // Save the current participants
|
|
125
|
+
this.nbOfParticipants = nbOfParticipants; // Save the current participants
|
|
133
126
|
const serverResult = serialization.weights.decode(payloadFromServer);
|
|
134
127
|
this.aggregator.setRound(serverRound);
|
|
135
128
|
return serverResult;
|
|
@@ -5,7 +5,6 @@ import { Client } from "./client.js";
|
|
|
5
5
|
* with anyone. Thus LocalClient doesn't do anything during communication
|
|
6
6
|
*/
|
|
7
7
|
export declare class LocalClient extends Client {
|
|
8
|
-
getNbOfParticipants(): number;
|
|
9
8
|
onRoundBeginCommunication(): Promise<void>;
|
|
10
9
|
onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
|
|
11
10
|
}
|
|
@@ -4,9 +4,6 @@ import { Client } from "./client.js";
|
|
|
4
4
|
* with anyone. Thus LocalClient doesn't do anything during communication
|
|
5
5
|
*/
|
|
6
6
|
export class LocalClient extends Client {
|
|
7
|
-
getNbOfParticipants() {
|
|
8
|
-
return 1;
|
|
9
|
-
}
|
|
10
7
|
onRoundBeginCommunication() {
|
|
11
8
|
return Promise.resolve();
|
|
12
9
|
}
|
|
@@ -19,9 +19,11 @@ export interface ClientConnected {
|
|
|
19
19
|
}
|
|
20
20
|
export interface EnoughParticipants {
|
|
21
21
|
type: type.EnoughParticipants;
|
|
22
|
+
nbOfParticipants: number;
|
|
22
23
|
}
|
|
23
24
|
export interface WaitingForMoreParticipants {
|
|
24
25
|
type: type.WaitingForMoreParticipants;
|
|
26
|
+
nbOfParticipants: number;
|
|
25
27
|
}
|
|
26
28
|
export type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
|
|
27
29
|
export type NarrowMessage<D> = Extract<Message, {
|
package/dist/training/disco.d.ts
CHANGED
package/dist/training/disco.js
CHANGED
|
@@ -53,6 +53,7 @@ export class Disco extends EventEmitter {
|
|
|
53
53
|
this.trainer = new Trainer(task, client);
|
|
54
54
|
// Simply propagate the training status events emitted by the client
|
|
55
55
|
this.#client.on("status", (status) => this.emit("status", status));
|
|
56
|
+
this.#client.on("participants", (nbParticipants) => this.emit("participants", nbParticipants));
|
|
56
57
|
}
|
|
57
58
|
/** Train on dataset, yielding logs of every round. */
|
|
58
59
|
async *trainByRound(dataset) {
|
package/dist/training/trainer.js
CHANGED
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@epfml/discojs",
|
|
3
|
-
"version": "3.0.1-p20250331130104.0",
|
|
3
|
+
"version": "3.0.1-p20250401132959.0",
|
|
4
4
|
"type": "module",
|
|
5
5
|
"main": "dist/index.js",
|
|
6
6
|
"types": "dist/index.d.ts",
|
|
@@ -22,7 +22,7 @@
|
|
|
22
22
|
"@epfml/isomorphic-wrtc": "1",
|
|
23
23
|
"@jimp/core": "1",
|
|
24
24
|
"@jimp/plugin-resize": "1",
|
|
25
|
-
"@msgpack/msgpack": "
|
|
25
|
+
"@msgpack/msgpack": "3",
|
|
26
26
|
"@tensorflow/tfjs": "4",
|
|
27
27
|
"@xenova/transformers": "2",
|
|
28
28
|
"isomorphic-ws": "5",
|