@epfml/discojs 3.0.1-p20241007204240.0 → 3.0.1-p20241024094708.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
- package/dist/aggregator/{base.js → aggregator.js} +48 -36
- package/dist/aggregator/get.d.ts +2 -2
- package/dist/aggregator/get.js +4 -4
- package/dist/aggregator/index.d.ts +1 -4
- package/dist/aggregator/index.js +1 -1
- package/dist/aggregator/mean.d.ts +4 -4
- package/dist/aggregator/mean.js +5 -15
- package/dist/aggregator/secure.d.ts +4 -4
- package/dist/aggregator/secure.js +7 -17
- package/dist/client/client.d.ts +71 -17
- package/dist/client/client.js +118 -17
- package/dist/client/decentralized/decentralized_client.d.ts +11 -13
- package/dist/client/decentralized/decentralized_client.js +121 -84
- package/dist/client/decentralized/messages.d.ts +12 -6
- package/dist/client/decentralized/messages.js +9 -8
- package/dist/client/event_connection.js +2 -2
- package/dist/client/federated/federated_client.d.ts +1 -13
- package/dist/client/federated/federated_client.js +15 -94
- package/dist/client/federated/messages.d.ts +6 -11
- package/dist/client/local_client.d.ts +1 -0
- package/dist/client/local_client.js +3 -0
- package/dist/client/messages.d.ts +14 -7
- package/dist/client/messages.js +13 -11
- package/dist/default_tasks/cifar10.js +1 -1
- package/dist/default_tasks/lus_covid.js +1 -0
- package/dist/default_tasks/mnist.js +1 -1
- package/dist/default_tasks/simple_face.js +1 -0
- package/dist/default_tasks/titanic.js +1 -0
- package/dist/default_tasks/wikitext.js +1 -0
- package/dist/index.d.ts +0 -2
- package/dist/serialization/coder.d.ts +4 -0
- package/dist/serialization/coder.js +51 -0
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +1 -0
- package/dist/serialization/model.d.ts +1 -2
- package/dist/serialization/model.js +9 -24
- package/dist/serialization/weights.d.ts +2 -3
- package/dist/serialization/weights.js +15 -26
- package/dist/task/task_handler.d.ts +5 -5
- package/dist/task/task_handler.js +21 -15
- package/dist/task/training_information.d.ts +1 -2
- package/dist/task/training_information.js +6 -8
- package/dist/training/disco.d.ts +4 -1
- package/dist/training/trainer.js +1 -1
- package/dist/utils/event_emitter.d.ts +3 -3
- package/dist/utils/event_emitter.js +10 -9
- package/package.json +2 -3
|
@@ -6,25 +6,13 @@ import { Client } from "../client.js";
|
|
|
6
6
|
*/
|
|
7
7
|
export declare class FederatedClient extends Client {
|
|
8
8
|
#private;
|
|
9
|
-
|
|
10
|
-
/**
|
|
11
|
-
* Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
12
|
-
*/
|
|
13
|
-
private connectServer;
|
|
9
|
+
getNbOfParticipants(): number;
|
|
14
10
|
/**
|
|
15
11
|
* Initializes the connection to the server, gets our node ID
|
|
16
12
|
* as well as the latest training information: latest global model, current round and
|
|
17
13
|
* whether we are waiting for more participants.
|
|
18
14
|
*/
|
|
19
15
|
connect(): Promise<Model>;
|
|
20
|
-
/**
|
|
21
|
-
* Method called when the server notifies the client that there aren't enough
|
|
22
|
-
* participants (anymore) to start/continue training
|
|
23
|
-
* The method creates a promise that will resolve once the server notifies
|
|
24
|
-
* the client that the training can resume via a subsequent EnoughParticipants message
|
|
25
|
-
* @returns a promise which resolves when enough participants joined the session
|
|
26
|
-
*/
|
|
27
|
-
private waitForMoreParticipants;
|
|
28
16
|
/**
|
|
29
17
|
* Disconnection process when user quits the task.
|
|
30
18
|
*/
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import createDebug from "debug";
|
|
2
2
|
import { serialization } from "../../index.js";
|
|
3
|
-
import { Client } from "../client.js";
|
|
3
|
+
import { Client, shortenId } from "../client.js";
|
|
4
4
|
import { type } from "../messages.js";
|
|
5
5
|
import { waitMessage, waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
|
|
6
6
|
import * as messages from "./messages.js";
|
|
@@ -19,38 +19,10 @@ export class FederatedClient extends Client {
|
|
|
19
19
|
// Total number of other federated contributors, including this client, excluding the server
|
|
20
20
|
// E.g., if 3 users are training a federated model, nbOfParticipants is 3
|
|
21
21
|
#nbOfParticipants = 1;
|
|
22
|
-
/**
|
|
23
|
-
* When the server notifies clients to pause and wait until more
|
|
24
|
-
* participants join, we rely on this promise to wait
|
|
25
|
-
* until the server signals that the training can resume
|
|
26
|
-
*/
|
|
27
|
-
#promiseForMoreParticipants = undefined;
|
|
28
|
-
/**
|
|
29
|
-
* When the server notifies the client that they can resume training
|
|
30
|
-
* after waiting for more participants, we want to be able to display what
|
|
31
|
-
* we were doing before waiting (training locally or updating our model).
|
|
32
|
-
* We use this attribute to store the status to rollback to when we stop waiting
|
|
33
|
-
*/
|
|
34
|
-
#previousStatus = undefined;
|
|
35
|
-
/**
|
|
36
|
-
* Whether the client should wait until more
|
|
37
|
-
* participants join the session, i.e. a promise has been created
|
|
38
|
-
*/
|
|
39
|
-
get #waitingForMoreParticipants() {
|
|
40
|
-
return this.#promiseForMoreParticipants !== undefined;
|
|
41
|
-
}
|
|
42
22
|
// the number of participants excluding the server
|
|
43
|
-
|
|
23
|
+
getNbOfParticipants() {
|
|
44
24
|
return this.#nbOfParticipants;
|
|
45
25
|
}
|
|
46
|
-
/**
|
|
47
|
-
* Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
48
|
-
*/
|
|
49
|
-
async connectServer(url) {
|
|
50
|
-
const server = await WebSocketServer.connect(url, messages.isMessageFederated, // can only receive federated message types from the server
|
|
51
|
-
messages.isMessageFederated);
|
|
52
|
-
return server;
|
|
53
|
-
}
|
|
54
26
|
/**
|
|
55
27
|
* Initializes the connection to the server, gets our node ID
|
|
56
28
|
* as well as the latest training information: latest global model, current round and
|
|
@@ -70,31 +42,12 @@ export class FederatedClient extends Client {
|
|
|
70
42
|
throw new Error(`unknown protocol: ${this.url.protocol}`);
|
|
71
43
|
}
|
|
72
44
|
serverURL.pathname += `federated/${this.task.id}`;
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
debug(`[${id.slice(0, 4)}] received WaitingForMoreParticipants message from server`);
|
|
78
|
-
// Display the waiting status right away
|
|
79
|
-
this.emit("status", "Waiting for more participants");
|
|
80
|
-
// Upon receiving a WaitingForMoreParticipants message,
|
|
81
|
-
// the client will await for this promise to resolve before sending its
|
|
82
|
-
// local weight update
|
|
83
|
-
this.#promiseForMoreParticipants = this.waitForMoreParticipants();
|
|
84
|
-
});
|
|
85
|
-
// As an example assume we need at least 2 participants to train,
|
|
86
|
-
// When two participants join almost at the same time, the server
|
|
87
|
-
// sends a NewFederatedNodeInfo with waitForMoreParticipants=true to the first participant
|
|
88
|
-
// and directly follows with an EnoughParticipants message when the 2nd participant joins
|
|
89
|
-
// However, the EnoughParticipants can arrive before the NewFederatedNodeInfo (which is much bigger)
|
|
90
|
-
// so we check whether we received the EnoughParticipants before being assigned a node ID
|
|
45
|
+
// Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
46
|
+
this._server = await WebSocketServer.connect(serverURL, messages.isMessageFederated, // can only receive federated message types from the server
|
|
47
|
+
messages.isMessageFederated);
|
|
48
|
+
// c.f. setupServerCallbacks doc for explanation
|
|
91
49
|
let receivedEnoughParticipants = false;
|
|
92
|
-
this.
|
|
93
|
-
if (this._ownId === undefined) {
|
|
94
|
-
debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
|
|
95
|
-
receivedEnoughParticipants = true;
|
|
96
|
-
}
|
|
97
|
-
});
|
|
50
|
+
this.setupServerCallbacks(() => receivedEnoughParticipants = true);
|
|
98
51
|
this.aggregator.registerNode(SERVER_NODE_ID);
|
|
99
52
|
const msg = {
|
|
100
53
|
type: type.ClientConnected,
|
|
@@ -109,40 +62,21 @@ export class FederatedClient extends Client {
|
|
|
109
62
|
if (waitForMoreParticipants && !receivedEnoughParticipants) {
|
|
110
63
|
// Create a promise that resolves when enough participants join
|
|
111
64
|
// The client will await this promise before sending its local weight update
|
|
112
|
-
this
|
|
65
|
+
this.promiseForMoreParticipants = this.createPromiseForMoreParticipants();
|
|
113
66
|
}
|
|
114
67
|
if (this._ownId !== undefined) {
|
|
115
68
|
throw new Error('received id from server but was already received');
|
|
116
69
|
}
|
|
117
70
|
this._ownId = id;
|
|
118
|
-
debug(`[${id
|
|
71
|
+
debug(`[${shortenId(id)}] joined session at round ${round} `);
|
|
119
72
|
this.aggregator.setRound(round);
|
|
120
73
|
this.#nbOfParticipants = nbOfParticipants;
|
|
121
74
|
// Upon connecting, the server answers with a boolean
|
|
122
75
|
// which indicates whether there are enough participants or not
|
|
123
|
-
debug(`[${this.ownId
|
|
76
|
+
debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants);
|
|
124
77
|
model.weights = serialization.weights.decode(payload);
|
|
125
78
|
return model;
|
|
126
79
|
}
|
|
127
|
-
/**
|
|
128
|
-
* Method called when the server notifies the client that there aren't enough
|
|
129
|
-
* participants (anymore) to start/continue training
|
|
130
|
-
* The method creates a promise that will resolve once the server notifies
|
|
131
|
-
* the client that the training can resume via a subsequent EnoughParticipants message
|
|
132
|
-
* @returns a promise which resolves when enough participants joined the session
|
|
133
|
-
*/
|
|
134
|
-
async waitForMoreParticipants() {
|
|
135
|
-
return new Promise((resolve) => {
|
|
136
|
-
// "once" is important because we can't resolve the same promise multiple times
|
|
137
|
-
this.server.once(type.EnoughParticipants, () => {
|
|
138
|
-
debug(`[${this.ownId.slice(0, 4)}] received EnoughParticipants message from server`);
|
|
139
|
-
// Emit the last status emitted before waiting if defined
|
|
140
|
-
if (this.#previousStatus !== undefined)
|
|
141
|
-
this.emit("status", this.#previousStatus);
|
|
142
|
-
resolve();
|
|
143
|
-
});
|
|
144
|
-
});
|
|
145
|
-
}
|
|
146
80
|
/**
|
|
147
81
|
* Disconnection process when user quits the task.
|
|
148
82
|
*/
|
|
@@ -155,10 +89,7 @@ export class FederatedClient extends Client {
|
|
|
155
89
|
onRoundBeginCommunication() {
|
|
156
90
|
// Prepare the result promise for the incoming round
|
|
157
91
|
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
158
|
-
|
|
159
|
-
// Once enough new participants join we can display the previous status again
|
|
160
|
-
this.#previousStatus = "Training the model on the data you connected";
|
|
161
|
-
this.emit("status", this.#previousStatus);
|
|
92
|
+
this.saveAndEmit("local training");
|
|
162
93
|
return Promise.resolve();
|
|
163
94
|
}
|
|
164
95
|
/**
|
|
@@ -176,18 +107,8 @@ export class FederatedClient extends Client {
|
|
|
176
107
|
throw new Error("local aggregation result was not set");
|
|
177
108
|
}
|
|
178
109
|
// First we check if we are waiting for more participants before sending our weight update
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
debug(`[${this.ownId.slice(0, 4)}] is awaiting the promise for more participants`);
|
|
182
|
-
this.emit("status", "Waiting for more participants");
|
|
183
|
-
await this.#promiseForMoreParticipants;
|
|
184
|
-
// Make sure to set the promise back to undefined once resolved
|
|
185
|
-
this.#promiseForMoreParticipants = undefined;
|
|
186
|
-
}
|
|
187
|
-
// Save the status in case participants leave and we switch to waiting for more participants
|
|
188
|
-
// Once enough new participants join we can display the previous status again
|
|
189
|
-
this.#previousStatus = "Updating the model with other participants' models";
|
|
190
|
-
this.emit("status", this.#previousStatus);
|
|
110
|
+
await this.waitForParticipantsIfNeeded();
|
|
111
|
+
this.saveAndEmit("updating model");
|
|
191
112
|
// Send our local contribution to the server
|
|
192
113
|
// and receive the server global update for this round as an answer to our contribution
|
|
193
114
|
const payloadToServer = this.aggregator.makePayloads(weights).first();
|
|
@@ -198,9 +119,9 @@ export class FederatedClient extends Client {
|
|
|
198
119
|
};
|
|
199
120
|
// Need to await the resulting global model right after sending our local contribution
|
|
200
121
|
// to make sure we don't miss it
|
|
201
|
-
debug(`[${this.ownId
|
|
122
|
+
debug(`[${shortenId(this.ownId)}] sent its local update to the server for round ${this.aggregator.round}`);
|
|
202
123
|
this.server.send(msg);
|
|
203
|
-
debug(`[${this.ownId
|
|
124
|
+
debug(`[${shortenId(this.ownId)}] is waiting for server update for round ${this.aggregator.round + 1}`);
|
|
204
125
|
const { payload: payloadFromServer, round: serverRound, nbOfParticipants } = await waitMessage(this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update
|
|
205
126
|
this.#nbOfParticipants = nbOfParticipants; // Save the current participants
|
|
206
127
|
const serverResult = serialization.weights.decode(payloadFromServer);
|
|
@@ -1,30 +1,25 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import type { serialization } from "../../index.js";
|
|
2
2
|
import { type NodeID } from '..//types.js';
|
|
3
|
-
import { type
|
|
3
|
+
import { type } from '../messages.js';
|
|
4
|
+
import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js';
|
|
4
5
|
export type MessageFederated = ClientConnected | NewFederatedNodeInfo | SendPayload | ReceiveServerPayload | WaitingForMoreParticipants | EnoughParticipants;
|
|
5
6
|
export interface NewFederatedNodeInfo {
|
|
6
7
|
type: type.NewFederatedNodeInfo;
|
|
7
8
|
id: NodeID;
|
|
8
9
|
waitForMoreParticipants: boolean;
|
|
9
|
-
payload:
|
|
10
|
+
payload: serialization.Encoded;
|
|
10
11
|
round: number;
|
|
11
12
|
nbOfParticipants: number;
|
|
12
13
|
}
|
|
13
14
|
export interface SendPayload {
|
|
14
15
|
type: type.SendPayload;
|
|
15
|
-
payload:
|
|
16
|
+
payload: serialization.Encoded;
|
|
16
17
|
round: number;
|
|
17
18
|
}
|
|
18
19
|
export interface ReceiveServerPayload {
|
|
19
20
|
type: type.ReceiveServerPayload;
|
|
20
|
-
payload:
|
|
21
|
+
payload: serialization.Encoded;
|
|
21
22
|
round: number;
|
|
22
23
|
nbOfParticipants: number;
|
|
23
24
|
}
|
|
24
|
-
export interface EnoughParticipants {
|
|
25
|
-
type: type.EnoughParticipants;
|
|
26
|
-
}
|
|
27
|
-
export interface WaitingForMoreParticipants {
|
|
28
|
-
type: type.WaitingForMoreParticipants;
|
|
29
|
-
}
|
|
30
25
|
export declare function isMessageFederated(raw: unknown): raw is MessageFederated;
|
|
@@ -5,6 +5,7 @@ import { Client } from "./client.js";
|
|
|
5
5
|
* with anyone. Thus LocalClient doesn't do anything during communication
|
|
6
6
|
*/
|
|
7
7
|
export declare class LocalClient extends Client {
|
|
8
|
+
getNbOfParticipants(): number;
|
|
8
9
|
onRoundBeginCommunication(): Promise<void>;
|
|
9
10
|
onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
|
|
10
11
|
}
|
|
@@ -4,6 +4,9 @@ import { Client } from "./client.js";
|
|
|
4
4
|
* with anyone. Thus LocalClient doesn't do anything during communication
|
|
5
5
|
*/
|
|
6
6
|
export class LocalClient extends Client {
|
|
7
|
+
getNbOfParticipants() {
|
|
8
|
+
return 1;
|
|
9
|
+
}
|
|
7
10
|
onRoundBeginCommunication() {
|
|
8
11
|
return Promise.resolve();
|
|
9
12
|
}
|
|
@@ -3,19 +3,26 @@ import type * as federated from './federated/messages.js';
|
|
|
3
3
|
export declare enum type {
|
|
4
4
|
ClientConnected = 0,
|
|
5
5
|
NewDecentralizedNodeInfo = 1,
|
|
6
|
-
|
|
6
|
+
JoinRound = 2,
|
|
7
7
|
PeerIsReady = 3,
|
|
8
8
|
PeersForRound = 4,
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
9
|
+
SignalForPeer = 5,
|
|
10
|
+
Payload = 6,
|
|
11
|
+
NewFederatedNodeInfo = 7,
|
|
12
|
+
WaitingForMoreParticipants = 8,
|
|
13
|
+
EnoughParticipants = 9,
|
|
14
|
+
SendPayload = 10,
|
|
15
|
+
ReceiveServerPayload = 11
|
|
15
16
|
}
|
|
16
17
|
export interface ClientConnected {
|
|
17
18
|
type: type.ClientConnected;
|
|
18
19
|
}
|
|
20
|
+
export interface EnoughParticipants {
|
|
21
|
+
type: type.EnoughParticipants;
|
|
22
|
+
}
|
|
23
|
+
export interface WaitingForMoreParticipants {
|
|
24
|
+
type: type.WaitingForMoreParticipants;
|
|
25
|
+
}
|
|
19
26
|
export type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
|
|
20
27
|
export type NarrowMessage<D> = Extract<Message, {
|
|
21
28
|
type: D;
|
package/dist/client/messages.js
CHANGED
|
@@ -9,34 +9,36 @@ export var type;
|
|
|
9
9
|
// answers with its peer id and also tells the client whether we are waiting
|
|
10
10
|
// for more participants before starting training
|
|
11
11
|
type[type["NewDecentralizedNodeInfo"] = 1] = "NewDecentralizedNodeInfo";
|
|
12
|
-
// Message
|
|
13
|
-
//
|
|
14
|
-
type[type["
|
|
12
|
+
// Message sent by peers to the server to signal they want to
|
|
13
|
+
// join the next round
|
|
14
|
+
type[type["JoinRound"] = 2] = "JoinRound";
|
|
15
15
|
// Message sent by nodes to server signaling they are ready to
|
|
16
16
|
// start the next round
|
|
17
17
|
type[type["PeerIsReady"] = 3] = "PeerIsReady";
|
|
18
18
|
// Sent by the server to participating peers containing the list
|
|
19
19
|
// of peers for the round
|
|
20
20
|
type[type["PeersForRound"] = 4] = "PeersForRound";
|
|
21
|
+
// Message forwarded by the server from a client to another client
|
|
22
|
+
// to establish a peer-to-peer (WebRTC) connection
|
|
23
|
+
type[type["SignalForPeer"] = 5] = "SignalForPeer";
|
|
21
24
|
// The weight update
|
|
22
|
-
type[type["Payload"] =
|
|
25
|
+
type[type["Payload"] = 6] = "Payload";
|
|
23
26
|
/* Federated */
|
|
24
27
|
// The server answers the ClientConnected message with the necessary information
|
|
25
28
|
// to start training: node id, latest model global weights, current round etc
|
|
26
|
-
type[type["NewFederatedNodeInfo"] =
|
|
29
|
+
type[type["NewFederatedNodeInfo"] = 7] = "NewFederatedNodeInfo";
|
|
27
30
|
// Message sent by server to notify clients that there are not enough
|
|
28
31
|
// participants to continue training
|
|
29
|
-
type[type["WaitingForMoreParticipants"] =
|
|
32
|
+
type[type["WaitingForMoreParticipants"] = 8] = "WaitingForMoreParticipants";
|
|
30
33
|
// Message sent by server to notify clients that there are now enough
|
|
31
34
|
// participants to start training collaboratively
|
|
32
|
-
type[type["EnoughParticipants"] =
|
|
33
|
-
type[type["SendPayload"] =
|
|
34
|
-
type[type["ReceiveServerPayload"] =
|
|
35
|
+
type[type["EnoughParticipants"] = 9] = "EnoughParticipants";
|
|
36
|
+
type[type["SendPayload"] = 10] = "SendPayload";
|
|
37
|
+
type[type["ReceiveServerPayload"] = 11] = "ReceiveServerPayload";
|
|
35
38
|
})(type || (type = {}));
|
|
36
39
|
export function hasMessageType(raw) {
|
|
37
|
-
if (typeof raw !== 'object' || raw === null)
|
|
40
|
+
if (typeof raw !== 'object' || raw === null)
|
|
38
41
|
return false;
|
|
39
|
-
}
|
|
40
42
|
const o = raw;
|
|
41
43
|
if (!('type' in o && typeof o.type === 'number' && o.type in type)) {
|
|
42
44
|
return false;
|
|
@@ -29,8 +29,8 @@ export const cifar10 = {
|
|
|
29
29
|
IMAGE_W: 224,
|
|
30
30
|
LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
|
|
31
31
|
scheme: 'decentralized',
|
|
32
|
+
aggregationStrategy: 'mean',
|
|
32
33
|
privacy: { clippingRadius: 20, noiseScale: 1 },
|
|
33
|
-
decentralizedSecure: true,
|
|
34
34
|
minNbOfParticipants: 3,
|
|
35
35
|
maxShareValue: 100,
|
|
36
36
|
tensorBackend: 'tfjs'
|
|
@@ -29,7 +29,7 @@ export const mnist = {
|
|
|
29
29
|
preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
|
|
30
30
|
LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
|
|
31
31
|
scheme: 'decentralized',
|
|
32
|
-
|
|
32
|
+
aggregationStrategy: 'secure',
|
|
33
33
|
minNbOfParticipants: 3,
|
|
34
34
|
maxShareValue: 100,
|
|
35
35
|
tensorBackend: 'tfjs'
|
|
@@ -25,6 +25,7 @@ export const wikitext = {
|
|
|
25
25
|
dataType: 'text',
|
|
26
26
|
preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
|
|
27
27
|
scheme: 'federated',
|
|
28
|
+
aggregationStrategy: 'mean',
|
|
28
29
|
minNbOfParticipants: 2,
|
|
29
30
|
epochs: 6,
|
|
30
31
|
// Unused by wikitext because data already comes split
|
package/dist/index.d.ts
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
export * as data from './dataset/index.js';
|
|
2
2
|
export * as serialization from './serialization/index.js';
|
|
3
|
-
export { Encoded as EncodedModel } from './serialization/model.js';
|
|
4
|
-
export { Encoded as EncodedWeights } from './serialization/weights.js';
|
|
5
3
|
export * as training from './training/index.js';
|
|
6
4
|
export * as privacy from './privacy.js';
|
|
7
5
|
export * as client from './client/index.js';
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import * as msgpack from "@msgpack/msgpack";
|
|
2
|
+
export function isEncoded(raw) {
|
|
3
|
+
if (!(raw instanceof Uint8Array))
|
|
4
|
+
return false;
|
|
5
|
+
const _ = raw;
|
|
6
|
+
return true;
|
|
7
|
+
}
|
|
8
|
+
// create a new buffer instead of referencing the backing one
|
|
9
|
+
function copy(arr) {
|
|
10
|
+
// `Buffer.slice` (subclass of Uint8Array on Node) doesn't copy
|
|
11
|
+
// thus doesn't respect Liskov substitution principle
|
|
12
|
+
// https://nodejs.org/api/buffer.html#bufslicestart-end
|
|
13
|
+
// here we call the correct implementation
|
|
14
|
+
return Uint8Array.prototype.slice.call(arr);
|
|
15
|
+
}
|
|
16
|
+
// to avoid mapping every ArrayBuffer to Uint8Array,
|
|
17
|
+
// we register our own convertors for the type we know are needed
|
|
18
|
+
// type id are arbitrally taken from msgpack-lite
|
|
19
|
+
// https://www.npmjs.com/package/msgpack-lite#extension-types
|
|
20
|
+
const CODEC = new msgpack.ExtensionCodec();
|
|
21
|
+
// used by TFJS's weights
|
|
22
|
+
CODEC.register({
|
|
23
|
+
type: 0x17,
|
|
24
|
+
encode(obj) {
|
|
25
|
+
if (!(obj instanceof Float32Array))
|
|
26
|
+
return null;
|
|
27
|
+
return new Uint8Array(obj.buffer, obj.byteOffset, obj.byteLength);
|
|
28
|
+
},
|
|
29
|
+
decode: (raw) =>
|
|
30
|
+
// to reinterpred uint8 into float32, it needs to be 4-bytes aligned
|
|
31
|
+
// but the given buffer might not be so we need to copy it.
|
|
32
|
+
new Float32Array(copy(raw).buffer),
|
|
33
|
+
});
|
|
34
|
+
// used by TFJS's saved model
|
|
35
|
+
CODEC.register({
|
|
36
|
+
type: 0x1a,
|
|
37
|
+
encode(obj) {
|
|
38
|
+
if (!(obj instanceof ArrayBuffer))
|
|
39
|
+
return null;
|
|
40
|
+
return new Uint8Array(obj);
|
|
41
|
+
},
|
|
42
|
+
decode: (raw) =>
|
|
43
|
+
// need to copy as backing ArrayBuffer might be larger
|
|
44
|
+
copy(raw),
|
|
45
|
+
});
|
|
46
|
+
export function encode(serialized) {
|
|
47
|
+
return msgpack.encode(serialized, { extensionCodec: CODEC });
|
|
48
|
+
}
|
|
49
|
+
export function decode(encoded) {
|
|
50
|
+
return msgpack.decode(encoded, { extensionCodec: CODEC });
|
|
51
|
+
}
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import type { Model } from '../index.js';
|
|
2
|
-
|
|
3
|
-
export declare function isEncoded(raw: unknown): raw is Encoded;
|
|
2
|
+
import { Encoded } from "./coder.js";
|
|
4
3
|
export declare function encode(model: Model): Promise<Encoded>;
|
|
5
4
|
export declare function decode(encoded: unknown): Promise<Model>;
|
|
@@ -1,38 +1,29 @@
|
|
|
1
|
-
import msgpack from 'msgpack-lite';
|
|
2
1
|
import { models, serialization } from '../index.js';
|
|
2
|
+
import * as coder from "./coder.js";
|
|
3
|
+
import { isEncoded } from "./coder.js";
|
|
3
4
|
const Type = {
|
|
4
5
|
TFJS: 0,
|
|
5
6
|
GPT: 1
|
|
6
7
|
};
|
|
7
|
-
export function isEncoded(raw) {
|
|
8
|
-
return raw instanceof Uint8Array;
|
|
9
|
-
}
|
|
10
8
|
export async function encode(model) {
|
|
11
|
-
let encoded;
|
|
12
9
|
switch (true) {
|
|
13
10
|
case model instanceof models.TFJS: {
|
|
14
11
|
const serialized = await model.serialize();
|
|
15
|
-
|
|
16
|
-
break;
|
|
12
|
+
return coder.encode([Type.TFJS, serialized]);
|
|
17
13
|
}
|
|
18
14
|
case model instanceof models.GPT: {
|
|
19
15
|
const { weights, config } = model.serialize();
|
|
20
16
|
const serializedWeights = await serialization.weights.encode(weights);
|
|
21
|
-
|
|
22
|
-
break;
|
|
17
|
+
return coder.encode([Type.GPT, serializedWeights, config]);
|
|
23
18
|
}
|
|
24
19
|
default:
|
|
25
20
|
throw new Error("unknown model type");
|
|
26
21
|
}
|
|
27
|
-
// Node's Buffer extends Node's Uint8Array, which might not be the same
|
|
28
|
-
// as the browser's Uint8Array. we ensure here that it is.
|
|
29
|
-
return new Uint8Array(encoded);
|
|
30
22
|
}
|
|
31
23
|
export async function decode(encoded) {
|
|
32
|
-
if (!isEncoded(encoded))
|
|
24
|
+
if (!isEncoded(encoded))
|
|
33
25
|
throw new Error("Invalid encoding, raw encoding isn't an instance of Uint8Array");
|
|
34
|
-
|
|
35
|
-
const raw = msgpack.decode(encoded);
|
|
26
|
+
const raw = coder.decode(encoded);
|
|
36
27
|
if (!Array.isArray(raw) || raw.length < 2) {
|
|
37
28
|
throw new Error("invalid encoding, encoding isn't an array or doesn't contain enough values");
|
|
38
29
|
}
|
|
@@ -59,15 +50,9 @@ export async function decode(encoded) {
|
|
|
59
50
|
else {
|
|
60
51
|
throw new Error('invalid encoding, gpt-tfjs model encoding should be an array of length 2 or 3');
|
|
61
52
|
}
|
|
62
|
-
if (!
|
|
63
|
-
throw new Error(
|
|
64
|
-
|
|
65
|
-
const arr = rawModel;
|
|
66
|
-
if (arr.some((r) => typeof r !== 'number')) {
|
|
67
|
-
throw new Error("invalid encoding, gpt-tfjs weights should be numbers");
|
|
68
|
-
}
|
|
69
|
-
const nums = arr;
|
|
70
|
-
const weights = serialization.weights.decode(nums);
|
|
53
|
+
if (!isEncoded(rawModel))
|
|
54
|
+
throw new Error("invalid encoding, gpt-tfjs model weights should be an encoding of its weights");
|
|
55
|
+
const weights = serialization.weights.decode(rawModel);
|
|
71
56
|
return models.GPT.deserialize({ weights, config });
|
|
72
57
|
}
|
|
73
58
|
default:
|
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
import { WeightsContainer } from
|
|
2
|
-
|
|
3
|
-
export declare function isEncoded(raw: unknown): raw is Encoded;
|
|
1
|
+
import { WeightsContainer } from "../index.js";
|
|
2
|
+
import { Encoded } from "./coder.js";
|
|
4
3
|
export declare function encode(weights: WeightsContainer): Promise<Encoded>;
|
|
5
4
|
export declare function decode(encoded: Encoded): WeightsContainer;
|
|
@@ -1,37 +1,26 @@
|
|
|
1
|
-
import * as
|
|
2
|
-
import
|
|
3
|
-
import
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { WeightsContainer } from "../index.js";
|
|
3
|
+
import * as coder from "./coder.js";
|
|
4
4
|
function isSerialized(raw) {
|
|
5
|
-
if (typeof raw !==
|
|
5
|
+
if (typeof raw !== "object" || raw === null)
|
|
6
6
|
return false;
|
|
7
|
-
}
|
|
8
7
|
const { shape, data } = raw;
|
|
9
|
-
if (!(Array.isArray(shape) && shape.every((e) => typeof e ===
|
|
10
|
-
!(
|
|
8
|
+
if (!(Array.isArray(shape) && shape.every((e) => typeof e === "number")) ||
|
|
9
|
+
!(data instanceof Float32Array))
|
|
11
10
|
return false;
|
|
12
|
-
}
|
|
13
|
-
const _ = {
|
|
14
|
-
shape: shape,
|
|
15
|
-
data: data,
|
|
16
|
-
};
|
|
11
|
+
const _ = { shape, data };
|
|
17
12
|
return true;
|
|
18
13
|
}
|
|
19
|
-
export function isEncoded(raw) {
|
|
20
|
-
return Array.isArray(raw) && raw.every((e) => typeof e === 'number');
|
|
21
|
-
}
|
|
22
14
|
export async function encode(weights) {
|
|
23
|
-
const serialized = await Promise.all(weights.weights.map(async (t) => {
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
}));
|
|
29
|
-
return [...msgpack.encode(serialized).values()];
|
|
15
|
+
const serialized = await Promise.all(weights.weights.map(async (t) => ({
|
|
16
|
+
shape: t.shape,
|
|
17
|
+
data: await t.data(),
|
|
18
|
+
})));
|
|
19
|
+
return coder.encode(serialized);
|
|
30
20
|
}
|
|
31
21
|
export function decode(encoded) {
|
|
32
|
-
const raw =
|
|
33
|
-
if (!(Array.isArray(raw) && raw.every(isSerialized)))
|
|
34
|
-
throw new Error(
|
|
35
|
-
}
|
|
22
|
+
const raw = coder.decode(encoded);
|
|
23
|
+
if (!(Array.isArray(raw) && raw.every(isSerialized)))
|
|
24
|
+
throw new Error("expected to decode an array of serialized weights");
|
|
36
25
|
return new WeightsContainer(raw.map((w) => tf.tensor(w.data, w.shape)));
|
|
37
26
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { Map } from
|
|
2
|
-
import type { Model } from
|
|
3
|
-
import type { Task, TaskID } from
|
|
4
|
-
export declare function pushTask(
|
|
5
|
-
export declare function fetchTasks(
|
|
1
|
+
import { Map } from "immutable";
|
|
2
|
+
import type { Model } from "../index.js";
|
|
3
|
+
import type { Task, TaskID } from "./task.js";
|
|
4
|
+
export declare function pushTask(base: URL, task: Task, model: Model): Promise<void>;
|
|
5
|
+
export declare function fetchTasks(base: URL): Promise<Map<TaskID, Task>>;
|