@epfml/discojs 3.0.1-p20240902094132.0 → 3.0.1-p20240902162912.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +16 -2
- package/dist/aggregator/base.js +25 -3
- package/dist/aggregator/mean.d.ts +1 -0
- package/dist/aggregator/mean.js +11 -6
- package/dist/aggregator/secure.js +1 -1
- package/dist/client/{base.d.ts → client.d.ts} +13 -30
- package/dist/client/{base.js → client.js} +10 -20
- package/dist/client/decentralized/{base.d.ts → decentralized_client.d.ts} +5 -5
- package/dist/client/decentralized/{base.js → decentralized_client.js} +20 -16
- package/dist/client/decentralized/index.d.ts +1 -1
- package/dist/client/decentralized/index.js +1 -1
- package/dist/client/decentralized/messages.d.ts +7 -2
- package/dist/client/decentralized/messages.js +4 -2
- package/dist/client/event_connection.js +2 -2
- package/dist/client/federated/federated_client.d.ts +44 -0
- package/dist/client/federated/federated_client.js +210 -0
- package/dist/client/federated/index.d.ts +1 -1
- package/dist/client/federated/index.js +1 -1
- package/dist/client/federated/messages.d.ts +17 -2
- package/dist/client/federated/messages.js +3 -1
- package/dist/client/index.d.ts +2 -2
- package/dist/client/index.js +2 -2
- package/dist/client/local_client.d.ts +10 -0
- package/dist/client/local_client.js +14 -0
- package/dist/client/messages.d.ts +6 -8
- package/dist/client/messages.js +23 -7
- package/dist/client/utils.js +1 -1
- package/dist/default_tasks/cifar10.js +1 -1
- package/dist/default_tasks/lus_covid.js +1 -0
- package/dist/default_tasks/mnist.js +1 -1
- package/dist/default_tasks/simple_face.js +2 -1
- package/dist/default_tasks/titanic.js +2 -1
- package/dist/default_tasks/wikitext.js +1 -0
- package/dist/index.d.ts +4 -1
- package/dist/index.js +1 -0
- package/dist/logging/logger.d.ts +1 -1
- package/dist/task/index.d.ts +0 -1
- package/dist/task/index.js +0 -1
- package/dist/task/task.d.ts +0 -2
- package/dist/task/task.js +2 -4
- package/dist/task/training_information.d.ts +1 -1
- package/dist/task/training_information.js +3 -3
- package/dist/training/disco.d.ts +11 -12
- package/dist/training/disco.js +19 -34
- package/dist/training/index.d.ts +1 -1
- package/dist/training/trainer.d.ts +3 -2
- package/dist/training/trainer.js +12 -5
- package/dist/utils/event_emitter.js +1 -3
- package/package.json +1 -1
- package/dist/client/federated/base.d.ts +0 -38
- package/dist/client/federated/base.js +0 -130
- package/dist/client/local.d.ts +0 -5
- package/dist/client/local.js +0 -6
- package/dist/task/digest.d.ts +0 -5
- package/dist/task/digest.js +0 -14
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
import createDebug from "debug";
|
|
2
|
+
import { serialization } from "../../index.js";
|
|
3
|
+
import { Client } from "../client.js";
|
|
4
|
+
import { type } from "../messages.js";
|
|
5
|
+
import { waitMessage, waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
|
|
6
|
+
import * as messages from "./messages.js";
|
|
7
|
+
/** Debug logger scoped to the "discojs:client:federated" namespace. */
const debug = createDebug("discojs:client:federated");
/**
 * Node id under which the federated server is registered with the aggregator.
 * The server takes part in the network as a node of its own, and in the
 * federated setting handled by FederatedClient it is the only node we talk to,
 * so the actual string value is arbitrary.
 */
const SERVER_NODE_ID = "federated-server-node-id";
|
|
14
|
+
/**
 * Client class that communicates with a centralized, federated server when training
 * a specific task in the federated setting.
 */
export class FederatedClient extends Client {
    // Total number of federated contributors, including this client but excluding the server.
    // E.g., if 3 users are training a federated model, nbOfParticipants is 3
    #nbOfParticipants = 1;
    /**
     * When the server notifies clients to pause and wait until more
     * participants join, we rely on this promise to wait until the server
     * signals that the training can resume. `undefined` when not waiting.
     */
    #promiseForMoreParticipants = undefined;
    /**
     * When the server notifies the client that they can resume training
     * after waiting for more participants, we want to be able to display what
     * we were doing before waiting (training locally or updating our model).
     * This stores the status to roll back to when we stop waiting.
     */
    #previousStatus = undefined;
    /**
     * Whether the client should wait until more participants join the
     * session, i.e. a promise for more participants has been created.
     */
    get #waitingForMoreParticipants() {
        return this.#promiseForMoreParticipants !== undefined;
    }
    /** The number of participants, excluding the server. */
    get nbOfParticipants() {
        return this.#nbOfParticipants;
    }
    /**
     * Opens a new WebSocket connection with the server and listens to new messages over the channel.
     *
     * @param url WebSocket URL of the federated endpoint for this task
     * @returns the connected server channel
     */
    async connectServer(url) {
        const server = await WebSocketServer.connect(url, 
        // can only receive federated message types from the server
        messages.isMessageFederated, 
        // NOTE(review): presumably validates outgoing messages — confirm in event_connection.js
        messages.isMessageFederated);
        return server;
    }
    /**
     * Initializes the connection to the server, gets our node ID
     * as well as the latest training information: latest global model, current round and
     * whether we are waiting for more participants.
     *
     * @returns the model obtained from `Client.connect`, with its weights replaced by
     *          the global weights sent by the server
     * @throws Error if the client URL protocol is neither http: nor https:, or if
     *         a node id had already been assigned
     */
    async connect() {
        const model = await super.connect(); // Get the server base model
        const serverURL = new URL("", this.url.href);
        switch (this.url.protocol) {
            case "http:":
                serverURL.protocol = "ws:";
                break;
            case "https:":
                serverURL.protocol = "wss:";
                break;
            default:
                throw new Error(`unknown protocol: ${this.url.protocol}`);
        }
        serverURL.pathname += `federated/${this.task.id}`;
        this._server = await this.connectServer(serverURL);
        // Setup an event callback if the server signals that we should
        // wait for more participants
        this.server.on(type.WaitingForMoreParticipants, () => {
            // This message may arrive before NewFederatedNodeInfo, i.e. before an id
            // is assigned to us, so we must not read the `id` received further below
            // (it would still be in its temporal dead zone and throw a ReferenceError).
            const shortId = this._ownId?.slice(0, 4) ?? "????";
            debug(`[${shortId}] received WaitingForMoreParticipants message from server`);
            // Display the waiting status right away
            this.emit("status", "Waiting for more participants");
            // Upon receiving a WaitingForMoreParticipants message,
            // the client will await for this promise to resolve before sending its
            // local weight update
            this.#promiseForMoreParticipants = this.waitForMoreParticipants();
        });
        // As an example assume we need at least 2 participants to train,
        // When two participants join almost at the same time, the server
        // sends a NewFederatedNodeInfo with waitForMoreParticipants=true to the first participant
        // and directly follows with an EnoughParticipants message when the 2nd participant joins
        // However, the EnoughParticipants can arrive before the NewFederatedNodeInfo (which is much bigger)
        // so we check whether we received the EnoughParticipants before being assigned a node ID
        let receivedEnoughParticipants = false;
        this.server.once(type.EnoughParticipants, () => {
            if (this._ownId === undefined) {
                debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
                receivedEnoughParticipants = true;
            }
        });
        this.aggregator.registerNode(SERVER_NODE_ID);
        const msg = {
            type: type.ClientConnected,
        };
        this.server.send(msg);
        const { id, waitForMoreParticipants, payload, round, nbOfParticipants } = await waitMessageWithTimeout(this.server, type.NewFederatedNodeInfo);
        // This should come right after receiving the message to make sure
        // we don't miss a subsequent message from the server
        // We check if the server is telling us to wait for more participants
        // and we also check if a EnoughParticipant message ended up arriving
        // before the NewFederatedNodeInfo
        if (waitForMoreParticipants && !receivedEnoughParticipants) {
            // Create a promise that resolves when enough participants join
            // The client will await this promise before sending its local weight update
            this.#promiseForMoreParticipants = this.waitForMoreParticipants();
        }
        if (this._ownId !== undefined) {
            throw new Error('received id from server but was already received');
        }
        this._ownId = id;
        debug(`[${id.slice(0, 4)}] joined session at round ${round} `);
        this.aggregator.setRound(round);
        this.#nbOfParticipants = nbOfParticipants;
        // Upon connecting, the server answers with a boolean
        // which indicates whether there are enough participants or not
        debug(`[${this.ownId.slice(0, 4)}] upon connecting, wait for participant flag %o`, this.#waitingForMoreParticipants);
        model.weights = serialization.weights.decode(payload);
        return model;
    }
    /**
     * Method called when the server notifies the client that there aren't enough
     * participants (anymore) to start/continue training.
     * The method creates a promise that will resolve once the server notifies
     * the client that the training can resume via a subsequent EnoughParticipants message.
     *
     * @returns a promise which resolves when enough participants joined the session
     */
    async waitForMoreParticipants() {
        return new Promise((resolve) => {
            // "once" is important because we can't resolve the same promise multiple times
            this.server.once(type.EnoughParticipants, () => {
                debug(`[${this.ownId.slice(0, 4)}] received EnoughParticipants message from server`);
                // Emit the last status emitted before waiting if defined
                if (this.#previousStatus !== undefined)
                    this.emit("status", this.#previousStatus);
                resolve();
            });
        });
    }
    /**
     * Disconnection process when user quits the task.
     * Closes the server connection and resets the per-session state.
     */
    async disconnect() {
        await this.server.disconnect();
        this._server = undefined;
        this._ownId = undefined;
        // A pending wait listens for EnoughParticipants on the connection we just
        // closed, so it can never be resolved by a later session; drop it so a
        // reconnect does not block forever in onRoundEndCommunication.
        this.#promiseForMoreParticipants = undefined;
        this.aggregator.setNodes(this.aggregator.nodes.delete(SERVER_NODE_ID));
    }
    /**
     * Prepares the aggregation result promise for the incoming round and
     * emits the "training" status.
     */
    onRoundBeginCommunication() {
        // Prepare the result promise for the incoming round
        this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
        // Save the status in case participants leave and we switch to waiting for more participants
        // Once enough new participants join we can display the previous status again
        this.#previousStatus = "Training the model on the data you connected";
        this.emit("status", this.#previousStatus);
        return Promise.resolve();
    }
    /**
     * Sends the local weight update to the server and waits (indefinitely) for the
     * server's global update.
     *
     * If the server told us to wait for more participants, we first wait (also
     * indefinitely) until it notifies us that the training can resume.
     *
     * NB: For now, we suppose a fully-federated setting.
     *
     * @param weights Local weights sent to the server at the end of the local training round
     * @returns the new global weights sent by the server
     * @throws Error if no local aggregation result was set (onRoundBeginCommunication sets it)
     */
    async onRoundEndCommunication(weights) {
        if (this.aggregationResult === undefined) {
            throw new Error("local aggregation result was not set");
        }
        // First we check if we are waiting for more participants before sending our weight update
        if (this.#waitingForMoreParticipants) {
            // wait for the promise to resolve, which takes as long as it takes for new participants to join
            debug(`[${this.ownId.slice(0, 4)}] is awaiting the promise for more participants`);
            this.emit("status", "Waiting for more participants");
            await this.#promiseForMoreParticipants;
            // Make sure to set the promise back to undefined once resolved
            this.#promiseForMoreParticipants = undefined;
        }
        // Save the status in case participants leave and we switch to waiting for more participants
        // Once enough new participants join we can display the previous status again
        this.#previousStatus = "Updating the model with other participants' models";
        this.emit("status", this.#previousStatus);
        // Send our local contribution to the server
        // and receive the server global update for this round as an answer to our contribution
        const payloadToServer = this.aggregator.makePayloads(weights).first();
        const msg = {
            type: type.SendPayload,
            payload: await serialization.weights.encode(payloadToServer),
            round: this.aggregator.round,
        };
        // Need to await the resulting global model right after sending our local contribution
        // to make sure we don't miss it
        debug(`[${this.ownId.slice(0, 4)}] sent its local update to the server for round ${this.aggregator.round}`);
        this.server.send(msg);
        debug(`[${this.ownId.slice(0, 4)}] is waiting for server update for round ${this.aggregator.round + 1}`);
        const { payload: payloadFromServer, round: serverRound, nbOfParticipants } = await waitMessage(this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update
        this.#nbOfParticipants = nbOfParticipants; // Save the current participants
        const serverResult = serialization.weights.decode(payloadFromServer);
        this.aggregator.setRound(serverRound);
        return serverResult;
    }
}
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
export {
|
|
1
|
+
export { FederatedClient } from './federated_client.js';
|
|
2
2
|
export * as messages from './messages.js';
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
export {
|
|
1
|
+
export { FederatedClient } from './federated_client.js';
|
|
2
2
|
export * as messages from './messages.js';
|
|
@@ -1,6 +1,15 @@
|
|
|
1
1
|
import { type weights } from '../../serialization/index.js';
|
|
2
|
-
import { type
|
|
3
|
-
|
|
2
|
+
import { type NodeID } from '..//types.js';
|
|
3
|
+
import { type, type ClientConnected } from '../messages.js';
|
|
4
|
+
export type MessageFederated = ClientConnected | NewFederatedNodeInfo | SendPayload | ReceiveServerPayload | WaitingForMoreParticipants | EnoughParticipants;
|
|
5
|
+
/**
 * Server's answer to a ClientConnected message in the federated setting:
 * everything a client needs to start training.
 */
export interface NewFederatedNodeInfo {
    type: type.NewFederatedNodeInfo;
    /** Node id the server assigns to the connecting client. */
    id: NodeID;
    /** Whether the client must wait for more participants before contributing. */
    waitForMoreParticipants: boolean;
    /** Encoded weights of the latest global model. */
    payload: weights.Encoded;
    /** Current round of the session. */
    round: number;
    /** Number of participants, excluding the server. */
    nbOfParticipants: number;
}
|
|
4
13
|
export interface SendPayload {
|
|
5
14
|
type: type.SendPayload;
|
|
6
15
|
payload: weights.Encoded;
|
|
@@ -12,4 +21,10 @@ export interface ReceiveServerPayload {
|
|
|
12
21
|
round: number;
|
|
13
22
|
nbOfParticipants: number;
|
|
14
23
|
}
|
|
24
|
+
/** Sent by the server once there are enough participants to train collaboratively. */
export interface EnoughParticipants {
    type: type.EnoughParticipants;
}
/** Sent by the server when there are not enough participants to continue training. */
export interface WaitingForMoreParticipants {
    type: type.WaitingForMoreParticipants;
}
|
|
15
30
|
export declare function isMessageFederated(raw: unknown): raw is MessageFederated;
|
|
@@ -5,9 +5,11 @@ export function isMessageFederated(raw) {
|
|
|
5
5
|
}
|
|
6
6
|
switch (raw.type) {
|
|
7
7
|
case type.ClientConnected:
|
|
8
|
+
case type.NewFederatedNodeInfo:
|
|
8
9
|
case type.SendPayload:
|
|
9
10
|
case type.ReceiveServerPayload:
|
|
10
|
-
case type.
|
|
11
|
+
case type.WaitingForMoreParticipants:
|
|
12
|
+
case type.EnoughParticipants:
|
|
11
13
|
return true;
|
|
12
14
|
}
|
|
13
15
|
return false;
|
package/dist/client/index.d.ts
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
export {
|
|
1
|
+
export { Client } from './client.js';
|
|
2
2
|
export * from './types.js';
|
|
3
3
|
export * as aggregator from '../aggregator/index.js';
|
|
4
4
|
export * as decentralized from './decentralized/index.js';
|
|
5
5
|
export * as federated from './federated/index.js';
|
|
6
6
|
export * as messages from './messages.js';
|
|
7
7
|
export { getClient, timeout } from './utils.js';
|
|
8
|
-
export {
|
|
8
|
+
export { LocalClient } from './local_client.js';
|
package/dist/client/index.js
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
export {
|
|
1
|
+
export { Client } from './client.js';
|
|
2
2
|
export * from './types.js';
|
|
3
3
|
export * as aggregator from '../aggregator/index.js';
|
|
4
4
|
export * as decentralized from './decentralized/index.js';
|
|
5
5
|
export * as federated from './federated/index.js';
|
|
6
6
|
export * as messages from './messages.js';
|
|
7
7
|
export { getClient, timeout } from './utils.js';
|
|
8
|
-
export {
|
|
8
|
+
export { LocalClient } from './local_client.js';
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
import { WeightsContainer } from "../index.js";
|
|
2
|
+
import { Client } from "./client.js";
|
|
3
|
+
/**
 * Client for a Disco user who trains exclusively on their own local data and
 * never collaborates; its communication hooks therefore do nothing.
 */
export declare class LocalClient extends Client {
    /** Resolves immediately: there is nothing to set up before a local round. */
    onRoundBeginCommunication(): Promise<void>;
    /** Resolves to the given local weights, unchanged. */
    onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import { Client } from "./client.js";
|
|
2
|
+
/**
 * Client for a Disco user who trains exclusively on their own local data and
 * never collaborates; every communication hook is therefore a no-op.
 */
export class LocalClient extends Client {
    /** Nothing to prepare before a local round. */
    async onRoundBeginCommunication() { }
    /**
     * No aggregation takes place.
     *
     * @param weights the locally trained weights
     * @returns a promise resolving to `weights` unchanged
     */
    async onRoundEndCommunication(weights) {
        return weights;
    }
}
|
|
@@ -1,23 +1,21 @@
|
|
|
1
1
|
import type * as decentralized from './decentralized/messages.js';
|
|
2
2
|
import type * as federated from './federated/messages.js';
|
|
3
|
-
import { type NodeID } from './types.js';
|
|
4
3
|
export declare enum type {
    /** Client -> server: first point of contact to join a task. */
    ClientConnected = 0,
    /** Decentralized: server's answer to ClientConnected (peer id, whether to wait for more participants). */
    NewDecentralizedNodeInfo = 1,
    /** Decentralized: forwarded by the server between clients to establish a WebRTC connection. */
    SignalForPeer = 2,
    /** Decentralized: node -> server, ready to start the next round. */
    PeerIsReady = 3,
    /** Decentralized: server -> participating peers, the list of peers for the round. */
    PeersForRound = 4,
    /** Decentralized: the weight update. */
    Payload = 5,
    /** Federated: server's answer to ClientConnected (node id, latest global weights, current round). */
    NewFederatedNodeInfo = 6,
    /** Federated: not enough participants to continue training. */
    WaitingForMoreParticipants = 7,
    /** Federated: enough participants to start training collaboratively. */
    EnoughParticipants = 8,
    /** Federated: client -> server, local contribution for the round. */
    SendPayload = 9,
    /** Federated: server -> client, global update for the round. */
    ReceiveServerPayload = 10
}
|
|
14
16
|
export interface ClientConnected {
|
|
15
17
|
type: type.ClientConnected;
|
|
16
18
|
}
|
|
17
|
-
export interface AssignNodeID {
|
|
18
|
-
type: type.AssignNodeID;
|
|
19
|
-
id: NodeID;
|
|
20
|
-
}
|
|
21
19
|
export type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
|
|
22
20
|
export type NarrowMessage<D> = Extract<Message, {
|
|
23
21
|
type: D;
|
package/dist/client/messages.js
CHANGED
|
@@ -1,21 +1,37 @@
|
|
|
1
1
|
export var type;
(function (type) {
    // Message kinds in declaration order: each kind's numeric value is its index
    // in this list, and `type` receives both the name -> value and the
    // value -> name mappings.
    const kinds = [
        // Sent from client to server as first point of contact to join a task.
        // The server answers with a node id in a NewFederatedNodeInfo
        // or NewDecentralizedNodeInfo message
        "ClientConnected",
        /* Decentralized */
        // When a user joins a task with a ClientConnected message, the server
        // answers with its peer id and also tells the client whether we are waiting
        // for more participants before starting training
        "NewDecentralizedNodeInfo",
        // Message forwarded by the server from a client to another client
        // to establish a peer-to-peer (WebRTC) connection
        "SignalForPeer",
        // Message sent by nodes to server signaling they are ready to
        // start the next round
        "PeerIsReady",
        // Sent by the server to participating peers containing the list
        // of peers for the round
        "PeersForRound",
        // The weight update
        "Payload",
        /* Federated */
        // The server answers the ClientConnected message with the necessary information
        // to start training: node id, latest model global weights, current round etc
        "NewFederatedNodeInfo",
        // Message sent by server to notify clients that there are not enough
        // participants to continue training
        "WaitingForMoreParticipants",
        // Message sent by server to notify clients that there are now enough
        // participants to start training collaboratively
        "EnoughParticipants",
        "SendPayload",
        "ReceiveServerPayload",
    ];
    for (const [value, name] of kinds.entries()) {
        type[name] = value;
        type[value] = name;
    }
})(type || (type = {}));
|
|
20
36
|
export function hasMessageType(raw) {
|
|
21
37
|
if (typeof raw !== 'object' || raw === null) {
|
package/dist/client/utils.js
CHANGED
|
@@ -13,7 +13,7 @@ export function getClient(trainingScheme, serverURL, task, aggregator) {
|
|
|
13
13
|
case 'federated':
|
|
14
14
|
return new clients.federated.FederatedClient(serverURL, task, aggregator);
|
|
15
15
|
case 'local':
|
|
16
|
-
return new clients.
|
|
16
|
+
return new clients.LocalClient(serverURL, task, aggregator);
|
|
17
17
|
default: {
|
|
18
18
|
const _ = trainingScheme;
|
|
19
19
|
throw new Error('should never happen');
|
|
@@ -28,7 +28,8 @@ export const simpleFace = {
|
|
|
28
28
|
IMAGE_H: 200,
|
|
29
29
|
IMAGE_W: 200,
|
|
30
30
|
LABEL_LIST: ['child', 'adult'],
|
|
31
|
-
scheme: 'federated',
|
|
31
|
+
scheme: 'federated',
|
|
32
|
+
minNbOfParticipants: 2,
|
|
32
33
|
tensorBackend: 'tfjs'
|
|
33
34
|
}
|
|
34
35
|
};
|
|
@@ -26,6 +26,7 @@ export const wikitext = {
|
|
|
26
26
|
modelID: 'llm-raw-model',
|
|
27
27
|
preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
|
|
28
28
|
scheme: 'federated',
|
|
29
|
+
minNbOfParticipants: 2,
|
|
29
30
|
epochs: 6,
|
|
30
31
|
// Unused by wikitext because data already comes split
|
|
31
32
|
// But if set to 0 then the webapp doesn't display the validation metrics
|
package/dist/index.d.ts
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
export * as data from './dataset/index.js';
|
|
2
2
|
export * as serialization from './serialization/index.js';
|
|
3
|
+
export { Encoded as EncodedModel } from './serialization/model.js';
|
|
4
|
+
export { Encoded as EncodedWeights } from './serialization/weights.js';
|
|
3
5
|
export * as training from './training/index.js';
|
|
4
6
|
export * as privacy from './privacy.js';
|
|
5
7
|
export * as client from './client/index.js';
|
|
@@ -7,13 +9,14 @@ export * as aggregator from './aggregator/index.js';
|
|
|
7
9
|
export { WeightsContainer, aggregation } from './weights/index.js';
|
|
8
10
|
export { Logger, ConsoleLogger } from './logging/index.js';
|
|
9
11
|
export { Memory, type ModelInfo, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
|
|
10
|
-
export { Disco, RoundLogs } from './training/index.js';
|
|
12
|
+
export { Disco, RoundLogs, RoundStatus } from './training/index.js';
|
|
11
13
|
export { Validator } from './validation/index.js';
|
|
12
14
|
export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js';
|
|
13
15
|
export * as models from './models/index.js';
|
|
14
16
|
export * from './task/index.js';
|
|
15
17
|
export * as defaultTasks from './default_tasks/index.js';
|
|
16
18
|
export * as async_iterator from "./utils/async_iterator.js";
|
|
19
|
+
export { EventEmitter } from "./utils/event_emitter.js";
|
|
17
20
|
export { Dataset } from "./dataset/index.js";
|
|
18
21
|
export * from "./dataset/types.js";
|
|
19
22
|
export * from "./types.js";
|
package/dist/index.js
CHANGED
|
@@ -14,6 +14,7 @@ export * as models from './models/index.js';
|
|
|
14
14
|
export * from './task/index.js';
|
|
15
15
|
export * as defaultTasks from './default_tasks/index.js';
|
|
16
16
|
export * as async_iterator from "./utils/async_iterator.js";
|
|
17
|
+
export { EventEmitter } from "./utils/event_emitter.js";
|
|
17
18
|
export { Dataset } from "./dataset/index.js";
|
|
18
19
|
export * from "./dataset/types.js"; // TODO merge with above
|
|
19
20
|
export * from "./types.js";
|
package/dist/logging/logger.d.ts
CHANGED
package/dist/task/index.d.ts
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
export { isTask, type Task, isTaskID, type TaskID } from './task.js';
|
|
2
2
|
export { type TaskProvider } from './task_provider.js';
|
|
3
|
-
export { isDigest, type Digest } from './digest.js';
|
|
4
3
|
export { isDisplayInformation, type DisplayInformation } from './display_information.js';
|
|
5
4
|
export type { TrainingInformation } from './training_information.js';
|
|
6
5
|
export { pushTask, fetchTasks } from './task_handler.js';
|
package/dist/task/index.js
CHANGED
package/dist/task/task.d.ts
CHANGED
|
@@ -1,10 +1,8 @@
|
|
|
1
1
|
import { type DisplayInformation } from './display_information.js';
|
|
2
2
|
import { type TrainingInformation } from './training_information.js';
|
|
3
|
-
import { type Digest } from './digest.js';
|
|
4
3
|
export type TaskID = string;
|
|
5
4
|
export interface Task {
|
|
6
5
|
id: TaskID;
|
|
7
|
-
digest?: Digest;
|
|
8
6
|
displayInformation: DisplayInformation;
|
|
9
7
|
trainingInformation: TrainingInformation;
|
|
10
8
|
}
|
package/dist/task/task.js
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import { isDisplayInformation } from './display_information.js';
|
|
2
2
|
import { isTrainingInformation } from './training_information.js';
|
|
3
|
-
import { isDigest } from './digest.js';
|
|
4
3
|
export function isTaskID(obj) {
|
|
5
4
|
return typeof obj === 'string';
|
|
6
5
|
}
|
|
@@ -8,14 +7,13 @@ export function isTask(raw) {
|
|
|
8
7
|
if (typeof raw !== 'object' || raw === null) {
|
|
9
8
|
return false;
|
|
10
9
|
}
|
|
11
|
-
const { id,
|
|
10
|
+
const { id, displayInformation, trainingInformation } = raw;
|
|
12
11
|
if (!isTaskID(id) ||
|
|
13
|
-
(digest !== undefined && !isDigest(digest)) ||
|
|
14
12
|
!isDisplayInformation(displayInformation) ||
|
|
15
13
|
!isTrainingInformation(trainingInformation)) {
|
|
16
14
|
return false;
|
|
17
15
|
}
|
|
18
|
-
const repack = { id,
|
|
16
|
+
const repack = { id, displayInformation, trainingInformation };
|
|
19
17
|
const _correct = repack;
|
|
20
18
|
const _total = repack;
|
|
21
19
|
return true;
|
|
@@ -21,7 +21,7 @@ export interface TrainingInformation {
|
|
|
21
21
|
privacy?: Privacy;
|
|
22
22
|
decentralizedSecure?: boolean;
|
|
23
23
|
maxShareValue?: number;
|
|
24
|
-
|
|
24
|
+
minNbOfParticipants: number;
|
|
25
25
|
aggregator?: 'mean' | 'secure';
|
|
26
26
|
tokenizer?: string | PreTrainedTokenizer;
|
|
27
27
|
maxSequenceLength?: number;
|
|
@@ -24,20 +24,20 @@ export function isTrainingInformation(raw) {
|
|
|
24
24
|
if (typeof raw !== 'object' || raw === null) {
|
|
25
25
|
return false;
|
|
26
26
|
}
|
|
27
|
-
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue,
|
|
27
|
+
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, modelID, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
|
|
28
28
|
if (typeof dataType !== 'string' ||
|
|
29
29
|
typeof modelID !== 'string' ||
|
|
30
30
|
typeof epochs !== 'number' ||
|
|
31
31
|
typeof batchSize !== 'number' ||
|
|
32
32
|
typeof roundDuration !== 'number' ||
|
|
33
33
|
typeof validationSplit !== 'number' ||
|
|
34
|
+
typeof minNbOfParticipants !== 'number' ||
|
|
34
35
|
(tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
|
|
35
36
|
(maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
|
|
36
37
|
(aggregator !== undefined && typeof aggregator !== 'string') ||
|
|
37
38
|
(decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
|
|
38
39
|
(privacy !== undefined && !isPrivacy(privacy)) ||
|
|
39
40
|
(maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
|
|
40
|
-
(minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
|
|
41
41
|
(IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
|
|
42
42
|
(IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
|
|
43
43
|
(LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
|
|
@@ -96,7 +96,7 @@ export function isTrainingInformation(raw) {
|
|
|
96
96
|
epochs,
|
|
97
97
|
inputColumns,
|
|
98
98
|
maxShareValue,
|
|
99
|
-
|
|
99
|
+
minNbOfParticipants,
|
|
100
100
|
modelID,
|
|
101
101
|
outputColumns,
|
|
102
102
|
preprocessingFunctions,
|