@epfml/discojs 3.0.1-p20241001093123.0 → 3.0.1-p20241014092014.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
- package/dist/aggregator/{base.js → aggregator.js} +48 -36
- package/dist/aggregator/get.d.ts +2 -2
- package/dist/aggregator/get.js +4 -4
- package/dist/aggregator/index.d.ts +1 -4
- package/dist/aggregator/index.js +1 -1
- package/dist/aggregator/mean.d.ts +4 -4
- package/dist/aggregator/mean.js +5 -15
- package/dist/aggregator/secure.d.ts +4 -4
- package/dist/aggregator/secure.js +7 -17
- package/dist/client/client.d.ts +71 -17
- package/dist/client/client.js +115 -14
- package/dist/client/decentralized/decentralized_client.d.ts +11 -13
- package/dist/client/decentralized/decentralized_client.js +121 -84
- package/dist/client/decentralized/messages.d.ts +10 -4
- package/dist/client/decentralized/messages.js +7 -6
- package/dist/client/federated/federated_client.d.ts +1 -13
- package/dist/client/federated/federated_client.js +15 -94
- package/dist/client/federated/messages.d.ts +2 -7
- package/dist/client/local_client.d.ts +1 -0
- package/dist/client/local_client.js +3 -0
- package/dist/client/messages.d.ts +14 -7
- package/dist/client/messages.js +13 -11
- package/dist/default_tasks/cifar10.js +1 -1
- package/dist/default_tasks/lus_covid.js +1 -0
- package/dist/default_tasks/mnist.js +1 -1
- package/dist/default_tasks/simple_face.js +1 -0
- package/dist/default_tasks/titanic.js +1 -0
- package/dist/default_tasks/wikitext.js +1 -0
- package/dist/task/training_information.d.ts +1 -2
- package/dist/task/training_information.js +6 -8
- package/dist/training/disco.d.ts +4 -1
- package/dist/training/trainer.js +1 -1
- package/dist/utils/event_emitter.d.ts +3 -3
- package/dist/utils/event_emitter.js +10 -9
- package/package.json +1 -1
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import createDebug from "debug";
|
|
2
2
|
import { serialization } from "../../index.js";
|
|
3
|
-
import { Client } from "../client.js";
|
|
3
|
+
import { Client, shortenId } from "../client.js";
|
|
4
4
|
import { type } from "../messages.js";
|
|
5
5
|
import { waitMessage, waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
|
|
6
6
|
import * as messages from "./messages.js";
|
|
@@ -19,38 +19,10 @@ export class FederatedClient extends Client {
|
|
|
19
19
|
// Total number of other federated contributors, including this client, excluding the server
|
|
20
20
|
// E.g., if 3 users are training a federated model, nbOfParticipants is 3
|
|
21
21
|
#nbOfParticipants = 1;
|
|
22
|
-
/**
|
|
23
|
-
* When the server notifies clients to pause and wait until more
|
|
24
|
-
* participants join, we rely on this promise to wait
|
|
25
|
-
* until the server signals that the training can resume
|
|
26
|
-
*/
|
|
27
|
-
#promiseForMoreParticipants = undefined;
|
|
28
|
-
/**
|
|
29
|
-
* When the server notifies the client that they can resume training
|
|
30
|
-
* after waiting for more participants, we want to be able to display what
|
|
31
|
-
* we were doing before waiting (training locally or updating our model).
|
|
32
|
-
* We use this attribute to store the status to rollback to when we stop waiting
|
|
33
|
-
*/
|
|
34
|
-
#previousStatus = undefined;
|
|
35
|
-
/**
|
|
36
|
-
* Whether the client should wait until more
|
|
37
|
-
* participants join the session, i.e. a promise has been created
|
|
38
|
-
*/
|
|
39
|
-
get #waitingForMoreParticipants() {
|
|
40
|
-
return this.#promiseForMoreParticipants !== undefined;
|
|
41
|
-
}
|
|
42
22
|
// the number of participants excluding the server
|
|
43
|
-
|
|
23
|
+
getNbOfParticipants() {
|
|
44
24
|
return this.#nbOfParticipants;
|
|
45
25
|
}
|
|
46
|
-
/**
|
|
47
|
-
* Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
48
|
-
*/
|
|
49
|
-
async connectServer(url) {
|
|
50
|
-
const server = await WebSocketServer.connect(url, messages.isMessageFederated, // can only receive federated message types from the server
|
|
51
|
-
messages.isMessageFederated);
|
|
52
|
-
return server;
|
|
53
|
-
}
|
|
54
26
|
/**
|
|
55
27
|
* Initializes the connection to the server, gets our node ID
|
|
56
28
|
* as well as the latest training information: latest global model, current round and
|
|
@@ -70,31 +42,12 @@ export class FederatedClient extends Client {
|
|
|
70
42
|
throw new Error(`unknown protocol: ${this.url.protocol}`);
|
|
71
43
|
}
|
|
72
44
|
serverURL.pathname += `federated/${this.task.id}`;
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
debug(`[${id.slice(0, 4)}] received WaitingForMoreParticipants message from server`);
|
|
78
|
-
// Display the waiting status right away
|
|
79
|
-
this.emit("status", "Waiting for more participants");
|
|
80
|
-
// Upon receiving a WaitingForMoreParticipants message,
|
|
81
|
-
// the client will await for this promise to resolve before sending its
|
|
82
|
-
// local weight update
|
|
83
|
-
this.#promiseForMoreParticipants = this.waitForMoreParticipants();
|
|
84
|
-
});
|
|
85
|
-
// As an example assume we need at least 2 participants to train,
|
|
86
|
-
// When two participants join almost at the same time, the server
|
|
87
|
-
// sends a NewFederatedNodeInfo with waitForMoreParticipants=true to the first participant
|
|
88
|
-
// and directly follows with an EnoughParticipants message when the 2nd participant joins
|
|
89
|
-
// However, the EnoughParticipants can arrive before the NewFederatedNodeInfo (which is much bigger)
|
|
90
|
-
// so we check whether we received the EnoughParticipants before being assigned a node ID
|
|
45
|
+
// Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
46
|
+
this._server = await WebSocketServer.connect(serverURL, messages.isMessageFederated, // can only receive federated message types from the server
|
|
47
|
+
messages.isMessageFederated);
|
|
48
|
+
// c.f. setupServerCallbacks doc for explanation
|
|
91
49
|
let receivedEnoughParticipants = false;
|
|
92
|
-
this.
|
|
93
|
-
if (this._ownId === undefined) {
|
|
94
|
-
debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
|
|
95
|
-
receivedEnoughParticipants = true;
|
|
96
|
-
}
|
|
97
|
-
});
|
|
50
|
+
this.setupServerCallbacks(() => receivedEnoughParticipants = true);
|
|
98
51
|
this.aggregator.registerNode(SERVER_NODE_ID);
|
|
99
52
|
const msg = {
|
|
100
53
|
type: type.ClientConnected,
|
|
@@ -109,40 +62,21 @@ export class FederatedClient extends Client {
|
|
|
109
62
|
if (waitForMoreParticipants && !receivedEnoughParticipants) {
|
|
110
63
|
// Create a promise that resolves when enough participants join
|
|
111
64
|
// The client will await this promise before sending its local weight update
|
|
112
|
-
this
|
|
65
|
+
this.promiseForMoreParticipants = this.createPromiseForMoreParticipants();
|
|
113
66
|
}
|
|
114
67
|
if (this._ownId !== undefined) {
|
|
115
68
|
throw new Error('received id from server but was already received');
|
|
116
69
|
}
|
|
117
70
|
this._ownId = id;
|
|
118
|
-
debug(`[${id
|
|
71
|
+
debug(`[${shortenId(id)}] joined session at round ${round} `);
|
|
119
72
|
this.aggregator.setRound(round);
|
|
120
73
|
this.#nbOfParticipants = nbOfParticipants;
|
|
121
74
|
// Upon connecting, the server answers with a boolean
|
|
122
75
|
// which indicates whether there are enough participants or not
|
|
123
|
-
debug(`[${this.ownId
|
|
76
|
+
debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants);
|
|
124
77
|
model.weights = serialization.weights.decode(payload);
|
|
125
78
|
return model;
|
|
126
79
|
}
|
|
127
|
-
/**
|
|
128
|
-
* Method called when the server notifies the client that there aren't enough
|
|
129
|
-
* participants (anymore) to start/continue training
|
|
130
|
-
* The method creates a promise that will resolve once the server notifies
|
|
131
|
-
* the client that the training can resume via a subsequent EnoughParticipants message
|
|
132
|
-
* @returns a promise which resolves when enough participants joined the session
|
|
133
|
-
*/
|
|
134
|
-
async waitForMoreParticipants() {
|
|
135
|
-
return new Promise((resolve) => {
|
|
136
|
-
// "once" is important because we can't resolve the same promise multiple times
|
|
137
|
-
this.server.once(type.EnoughParticipants, () => {
|
|
138
|
-
debug(`[${this.ownId.slice(0, 4)}] received EnoughParticipants message from server`);
|
|
139
|
-
// Emit the last status emitted before waiting if defined
|
|
140
|
-
if (this.#previousStatus !== undefined)
|
|
141
|
-
this.emit("status", this.#previousStatus);
|
|
142
|
-
resolve();
|
|
143
|
-
});
|
|
144
|
-
});
|
|
145
|
-
}
|
|
146
80
|
/**
|
|
147
81
|
* Disconnection process when user quits the task.
|
|
148
82
|
*/
|
|
@@ -155,10 +89,7 @@ export class FederatedClient extends Client {
|
|
|
155
89
|
onRoundBeginCommunication() {
|
|
156
90
|
// Prepare the result promise for the incoming round
|
|
157
91
|
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
158
|
-
|
|
159
|
-
// Once enough new participants join we can display the previous status again
|
|
160
|
-
this.#previousStatus = "Training the model on the data you connected";
|
|
161
|
-
this.emit("status", this.#previousStatus);
|
|
92
|
+
this.saveAndEmit("local training");
|
|
162
93
|
return Promise.resolve();
|
|
163
94
|
}
|
|
164
95
|
/**
|
|
@@ -176,18 +107,8 @@ export class FederatedClient extends Client {
|
|
|
176
107
|
throw new Error("local aggregation result was not set");
|
|
177
108
|
}
|
|
178
109
|
// First we check if we are waiting for more participants before sending our weight update
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
debug(`[${this.ownId.slice(0, 4)}] is awaiting the promise for more participants`);
|
|
182
|
-
this.emit("status", "Waiting for more participants");
|
|
183
|
-
await this.#promiseForMoreParticipants;
|
|
184
|
-
// Make sure to set the promise back to undefined once resolved
|
|
185
|
-
this.#promiseForMoreParticipants = undefined;
|
|
186
|
-
}
|
|
187
|
-
// Save the status in case participants leave and we switch to waiting for more participants
|
|
188
|
-
// Once enough new participants join we can display the previous status again
|
|
189
|
-
this.#previousStatus = "Updating the model with other participants' models";
|
|
190
|
-
this.emit("status", this.#previousStatus);
|
|
110
|
+
await this.waitForParticipantsIfNeeded();
|
|
111
|
+
this.saveAndEmit("updating model");
|
|
191
112
|
// Send our local contribution to the server
|
|
192
113
|
// and receive the server global update for this round as an answer to our contribution
|
|
193
114
|
const payloadToServer = this.aggregator.makePayloads(weights).first();
|
|
@@ -198,9 +119,9 @@ export class FederatedClient extends Client {
|
|
|
198
119
|
};
|
|
199
120
|
// Need to await the resulting global model right after sending our local contribution
|
|
200
121
|
// to make sure we don't miss it
|
|
201
|
-
debug(`[${this.ownId
|
|
122
|
+
debug(`[${shortenId(this.ownId)}] sent its local update to the server for round ${this.aggregator.round}`);
|
|
202
123
|
this.server.send(msg);
|
|
203
|
-
debug(`[${this.ownId
|
|
124
|
+
debug(`[${shortenId(this.ownId)}] is waiting for server update for round ${this.aggregator.round + 1}`);
|
|
204
125
|
const { payload: payloadFromServer, round: serverRound, nbOfParticipants } = await waitMessage(this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update
|
|
205
126
|
this.#nbOfParticipants = nbOfParticipants; // Save the current participants
|
|
206
127
|
const serverResult = serialization.weights.decode(payloadFromServer);
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import { type weights } from '../../serialization/index.js';
|
|
2
2
|
import { type NodeID } from '..//types.js';
|
|
3
|
-
import { type
|
|
3
|
+
import { type } from '../messages.js';
|
|
4
|
+
import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js';
|
|
4
5
|
export type MessageFederated = ClientConnected | NewFederatedNodeInfo | SendPayload | ReceiveServerPayload | WaitingForMoreParticipants | EnoughParticipants;
|
|
5
6
|
export interface NewFederatedNodeInfo {
|
|
6
7
|
type: type.NewFederatedNodeInfo;
|
|
@@ -21,10 +22,4 @@ export interface ReceiveServerPayload {
|
|
|
21
22
|
round: number;
|
|
22
23
|
nbOfParticipants: number;
|
|
23
24
|
}
|
|
24
|
-
export interface EnoughParticipants {
|
|
25
|
-
type: type.EnoughParticipants;
|
|
26
|
-
}
|
|
27
|
-
export interface WaitingForMoreParticipants {
|
|
28
|
-
type: type.WaitingForMoreParticipants;
|
|
29
|
-
}
|
|
30
25
|
export declare function isMessageFederated(raw: unknown): raw is MessageFederated;
|
|
@@ -5,6 +5,7 @@ import { Client } from "./client.js";
|
|
|
5
5
|
* with anyone. Thus LocalClient doesn't do anything during communication
|
|
6
6
|
*/
|
|
7
7
|
export declare class LocalClient extends Client {
|
|
8
|
+
getNbOfParticipants(): number;
|
|
8
9
|
onRoundBeginCommunication(): Promise<void>;
|
|
9
10
|
onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
|
|
10
11
|
}
|
|
@@ -4,6 +4,9 @@ import { Client } from "./client.js";
|
|
|
4
4
|
* with anyone. Thus LocalClient doesn't do anything during communication
|
|
5
5
|
*/
|
|
6
6
|
export class LocalClient extends Client {
|
|
7
|
+
getNbOfParticipants() {
|
|
8
|
+
return 1;
|
|
9
|
+
}
|
|
7
10
|
onRoundBeginCommunication() {
|
|
8
11
|
return Promise.resolve();
|
|
9
12
|
}
|
|
@@ -3,19 +3,26 @@ import type * as federated from './federated/messages.js';
|
|
|
3
3
|
export declare enum type {
|
|
4
4
|
ClientConnected = 0,
|
|
5
5
|
NewDecentralizedNodeInfo = 1,
|
|
6
|
-
|
|
6
|
+
JoinRound = 2,
|
|
7
7
|
PeerIsReady = 3,
|
|
8
8
|
PeersForRound = 4,
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
9
|
+
SignalForPeer = 5,
|
|
10
|
+
Payload = 6,
|
|
11
|
+
NewFederatedNodeInfo = 7,
|
|
12
|
+
WaitingForMoreParticipants = 8,
|
|
13
|
+
EnoughParticipants = 9,
|
|
14
|
+
SendPayload = 10,
|
|
15
|
+
ReceiveServerPayload = 11
|
|
15
16
|
}
|
|
16
17
|
export interface ClientConnected {
|
|
17
18
|
type: type.ClientConnected;
|
|
18
19
|
}
|
|
20
|
+
export interface EnoughParticipants {
|
|
21
|
+
type: type.EnoughParticipants;
|
|
22
|
+
}
|
|
23
|
+
export interface WaitingForMoreParticipants {
|
|
24
|
+
type: type.WaitingForMoreParticipants;
|
|
25
|
+
}
|
|
19
26
|
export type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
|
|
20
27
|
export type NarrowMessage<D> = Extract<Message, {
|
|
21
28
|
type: D;
|
package/dist/client/messages.js
CHANGED
|
@@ -9,34 +9,36 @@ export var type;
|
|
|
9
9
|
// answers with its peer id and also tells the client whether we are waiting
|
|
10
10
|
// for more participants before starting training
|
|
11
11
|
type[type["NewDecentralizedNodeInfo"] = 1] = "NewDecentralizedNodeInfo";
|
|
12
|
-
// Message
|
|
13
|
-
//
|
|
14
|
-
type[type["
|
|
12
|
+
// Message sent by peers to the server to signal they want to
|
|
13
|
+
// join the next round
|
|
14
|
+
type[type["JoinRound"] = 2] = "JoinRound";
|
|
15
15
|
// Message sent by nodes to server signaling they are ready to
|
|
16
16
|
// start the next round
|
|
17
17
|
type[type["PeerIsReady"] = 3] = "PeerIsReady";
|
|
18
18
|
// Sent by the server to participating peers containing the list
|
|
19
19
|
// of peers for the round
|
|
20
20
|
type[type["PeersForRound"] = 4] = "PeersForRound";
|
|
21
|
+
// Message forwarded by the server from a client to another client
|
|
22
|
+
// to establish a peer-to-peer (WebRTC) connection
|
|
23
|
+
type[type["SignalForPeer"] = 5] = "SignalForPeer";
|
|
21
24
|
// The weight update
|
|
22
|
-
type[type["Payload"] =
|
|
25
|
+
type[type["Payload"] = 6] = "Payload";
|
|
23
26
|
/* Federated */
|
|
24
27
|
// The server answers the ClientConnected message with the necessary information
|
|
25
28
|
// to start training: node id, latest model global weights, current round etc
|
|
26
|
-
type[type["NewFederatedNodeInfo"] =
|
|
29
|
+
type[type["NewFederatedNodeInfo"] = 7] = "NewFederatedNodeInfo";
|
|
27
30
|
// Message sent by server to notify clients that there are not enough
|
|
28
31
|
// participants to continue training
|
|
29
|
-
type[type["WaitingForMoreParticipants"] =
|
|
32
|
+
type[type["WaitingForMoreParticipants"] = 8] = "WaitingForMoreParticipants";
|
|
30
33
|
// Message sent by server to notify clients that there are now enough
|
|
31
34
|
// participants to start training collaboratively
|
|
32
|
-
type[type["EnoughParticipants"] =
|
|
33
|
-
type[type["SendPayload"] =
|
|
34
|
-
type[type["ReceiveServerPayload"] =
|
|
35
|
+
type[type["EnoughParticipants"] = 9] = "EnoughParticipants";
|
|
36
|
+
type[type["SendPayload"] = 10] = "SendPayload";
|
|
37
|
+
type[type["ReceiveServerPayload"] = 11] = "ReceiveServerPayload";
|
|
35
38
|
})(type || (type = {}));
|
|
36
39
|
export function hasMessageType(raw) {
|
|
37
|
-
if (typeof raw !== 'object' || raw === null)
|
|
40
|
+
if (typeof raw !== 'object' || raw === null)
|
|
38
41
|
return false;
|
|
39
|
-
}
|
|
40
42
|
const o = raw;
|
|
41
43
|
if (!('type' in o && typeof o.type === 'number' && o.type in type)) {
|
|
42
44
|
return false;
|
|
@@ -29,8 +29,8 @@ export const cifar10 = {
|
|
|
29
29
|
IMAGE_W: 224,
|
|
30
30
|
LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
|
|
31
31
|
scheme: 'decentralized',
|
|
32
|
+
aggregationStrategy: 'mean',
|
|
32
33
|
privacy: { clippingRadius: 20, noiseScale: 1 },
|
|
33
|
-
decentralizedSecure: true,
|
|
34
34
|
minNbOfParticipants: 3,
|
|
35
35
|
maxShareValue: 100,
|
|
36
36
|
tensorBackend: 'tfjs'
|
|
@@ -29,7 +29,7 @@ export const mnist = {
|
|
|
29
29
|
preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
|
|
30
30
|
LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
|
|
31
31
|
scheme: 'decentralized',
|
|
32
|
-
|
|
32
|
+
aggregationStrategy: 'secure',
|
|
33
33
|
minNbOfParticipants: 3,
|
|
34
34
|
maxShareValue: 100,
|
|
35
35
|
tensorBackend: 'tfjs'
|
|
@@ -25,6 +25,7 @@ export const wikitext = {
|
|
|
25
25
|
dataType: 'text',
|
|
26
26
|
preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
|
|
27
27
|
scheme: 'federated',
|
|
28
|
+
aggregationStrategy: 'mean',
|
|
28
29
|
minNbOfParticipants: 2,
|
|
29
30
|
epochs: 6,
|
|
30
31
|
// Unused by wikitext because data already comes split
|
|
@@ -18,10 +18,9 @@ export interface TrainingInformation {
|
|
|
18
18
|
LABEL_LIST?: string[];
|
|
19
19
|
scheme: 'decentralized' | 'federated' | 'local';
|
|
20
20
|
privacy?: Privacy;
|
|
21
|
-
decentralizedSecure?: boolean;
|
|
22
21
|
maxShareValue?: number;
|
|
23
22
|
minNbOfParticipants: number;
|
|
24
|
-
|
|
23
|
+
aggregationStrategy?: 'mean' | 'secure';
|
|
25
24
|
tokenizer?: string | PreTrainedTokenizer;
|
|
26
25
|
maxSequenceLength?: number;
|
|
27
26
|
tensorBackend: 'tfjs' | 'gpt';
|
|
@@ -24,7 +24,7 @@ export function isTrainingInformation(raw) {
|
|
|
24
24
|
if (typeof raw !== 'object' || raw === null) {
|
|
25
25
|
return false;
|
|
26
26
|
}
|
|
27
|
-
const { IMAGE_H, IMAGE_W, LABEL_LIST,
|
|
27
|
+
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregationStrategy, batchSize, dataType, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
|
|
28
28
|
if (typeof dataType !== 'string' ||
|
|
29
29
|
typeof epochs !== 'number' ||
|
|
30
30
|
typeof batchSize !== 'number' ||
|
|
@@ -33,8 +33,7 @@ export function isTrainingInformation(raw) {
|
|
|
33
33
|
typeof minNbOfParticipants !== 'number' ||
|
|
34
34
|
(tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
|
|
35
35
|
(maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
|
|
36
|
-
(
|
|
37
|
-
(decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
|
|
36
|
+
(aggregationStrategy !== undefined && typeof aggregationStrategy !== 'string') ||
|
|
38
37
|
(privacy !== undefined && !isPrivacy(privacy)) ||
|
|
39
38
|
(maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
|
|
40
39
|
(IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
|
|
@@ -45,8 +44,8 @@ export function isTrainingInformation(raw) {
|
|
|
45
44
|
(preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
|
|
46
45
|
return false;
|
|
47
46
|
}
|
|
48
|
-
if (
|
|
49
|
-
switch (
|
|
47
|
+
if (aggregationStrategy !== undefined) {
|
|
48
|
+
switch (aggregationStrategy) {
|
|
50
49
|
case 'mean': break;
|
|
51
50
|
case 'secure': break;
|
|
52
51
|
default: return false;
|
|
@@ -58,7 +57,7 @@ export function isTrainingInformation(raw) {
|
|
|
58
57
|
case 'text': break;
|
|
59
58
|
default: return false;
|
|
60
59
|
}
|
|
61
|
-
//
|
|
60
|
+
// interdependencies on data type
|
|
62
61
|
if (dataType === 'image') {
|
|
63
62
|
if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
|
|
64
63
|
return false;
|
|
@@ -87,10 +86,9 @@ export function isTrainingInformation(raw) {
|
|
|
87
86
|
IMAGE_W,
|
|
88
87
|
IMAGE_H,
|
|
89
88
|
LABEL_LIST,
|
|
90
|
-
|
|
89
|
+
aggregationStrategy,
|
|
91
90
|
batchSize,
|
|
92
91
|
dataType,
|
|
93
|
-
decentralizedSecure,
|
|
94
92
|
privacy,
|
|
95
93
|
epochs,
|
|
96
94
|
inputColumns,
|
package/dist/training/disco.d.ts
CHANGED
|
@@ -7,7 +7,10 @@ interface DiscoConfig {
|
|
|
7
7
|
scheme: TrainingInformation["scheme"];
|
|
8
8
|
logger: Logger;
|
|
9
9
|
}
|
|
10
|
-
export type RoundStatus =
|
|
10
|
+
export type RoundStatus = 'not enough participants' | // Server notification to wait for more participants
|
|
11
|
+
'updating model' | // fetching/aggregating local updates into a global model
|
|
12
|
+
'local training' | // Training the model locally
|
|
13
|
+
'connecting to peers';
|
|
11
14
|
/**
|
|
12
15
|
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
13
16
|
* a convenient object providing a reduced yet complete API that wraps model training and
|
package/dist/training/trainer.js
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
type Listener<T> = (_: T) => void
|
|
1
|
+
type Listener<T> = (_: T) => void | Promise<void>;
|
|
2
2
|
/**
|
|
3
3
|
* Call handlers on given events
|
|
4
4
|
*
|
|
5
5
|
* @typeParam I object/mapping from event name to emitted value type
|
|
6
6
|
*/
|
|
7
7
|
export declare class EventEmitter<I extends Record<string, unknown>> {
|
|
8
|
-
private
|
|
8
|
+
#private;
|
|
9
9
|
/**
|
|
10
10
|
* @param initialListeners object/mapping of event name to listener, as if using `on` on created instance
|
|
11
11
|
*/
|
|
@@ -13,7 +13,7 @@ export declare class EventEmitter<I extends Record<string, unknown>> {
|
|
|
13
13
|
[E in keyof I]?: Listener<I[E]>;
|
|
14
14
|
});
|
|
15
15
|
/**
|
|
16
|
-
* Register listener to call on event
|
|
16
|
+
* Register listener to call on event.
|
|
17
17
|
*
|
|
18
18
|
* @param event event name to listen to
|
|
19
19
|
* @param listener handler to call
|
|
@@ -6,7 +6,8 @@ import { List } from 'immutable';
|
|
|
6
6
|
* @typeParam I object/mapping from event name to emitted value type
|
|
7
7
|
*/
|
|
8
8
|
export class EventEmitter {
|
|
9
|
-
|
|
9
|
+
// List of callbacks to run per event
|
|
10
|
+
#listeners = {};
|
|
10
11
|
/**
|
|
11
12
|
* @param initialListeners object/mapping of event name to listener, as if using `on` on created instance
|
|
12
13
|
*/
|
|
@@ -19,14 +20,14 @@ export class EventEmitter {
|
|
|
19
20
|
}
|
|
20
21
|
}
|
|
21
22
|
/**
|
|
22
|
-
* Register listener to call on event
|
|
23
|
+
* Register listener to call on event.
|
|
23
24
|
*
|
|
24
25
|
* @param event event name to listen to
|
|
25
26
|
* @param listener handler to call
|
|
26
27
|
*/
|
|
27
28
|
on(event, listener) {
|
|
28
|
-
const eventListeners = this
|
|
29
|
-
this
|
|
29
|
+
const eventListeners = this.#listeners[event] ?? List();
|
|
30
|
+
this.#listeners[event] = eventListeners.push([false, listener]);
|
|
30
31
|
}
|
|
31
32
|
/**
|
|
32
33
|
* Register listener to call once on next event
|
|
@@ -35,8 +36,8 @@ export class EventEmitter {
|
|
|
35
36
|
* @param listener handler to call next time
|
|
36
37
|
*/
|
|
37
38
|
once(event, listener) {
|
|
38
|
-
const eventListeners = this
|
|
39
|
-
this
|
|
39
|
+
const eventListeners = this.#listeners[event] ?? List();
|
|
40
|
+
this.#listeners[event] = eventListeners.push([true, listener]);
|
|
40
41
|
}
|
|
41
42
|
/**
|
|
42
43
|
* Send value to registered listeners of event name
|
|
@@ -45,9 +46,9 @@ export class EventEmitter {
|
|
|
45
46
|
* @param value what to call listeners with
|
|
46
47
|
*/
|
|
47
48
|
emit(event, value) {
|
|
48
|
-
const eventListeners = this
|
|
49
|
-
this
|
|
50
|
-
eventListeners.forEach(([_, listener]) => { listener(value); });
|
|
49
|
+
const eventListeners = this.#listeners[event] ?? List();
|
|
50
|
+
this.#listeners[event] = eventListeners.filterNot(([once]) => once);
|
|
51
|
+
eventListeners.forEach(async ([_, listener]) => { await listener(value); });
|
|
51
52
|
}
|
|
52
53
|
}
|
|
53
54
|
/** `EventEmitter` for all events */
|