@epfml/discojs 3.0.1-p20240902100041.0 → 3.0.1-p20240904094219.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +16 -2
- package/dist/aggregator/base.js +25 -3
- package/dist/aggregator/mean.d.ts +1 -0
- package/dist/aggregator/mean.js +11 -6
- package/dist/aggregator/secure.js +1 -1
- package/dist/client/{base.d.ts → client.d.ts} +13 -30
- package/dist/client/{base.js → client.js} +10 -20
- package/dist/client/decentralized/{base.d.ts → decentralized_client.d.ts} +5 -5
- package/dist/client/decentralized/{base.js → decentralized_client.js} +20 -16
- package/dist/client/decentralized/index.d.ts +1 -1
- package/dist/client/decentralized/index.js +1 -1
- package/dist/client/decentralized/messages.d.ts +7 -2
- package/dist/client/decentralized/messages.js +4 -2
- package/dist/client/event_connection.js +2 -2
- package/dist/client/federated/federated_client.d.ts +44 -0
- package/dist/client/federated/federated_client.js +210 -0
- package/dist/client/federated/index.d.ts +1 -1
- package/dist/client/federated/index.js +1 -1
- package/dist/client/federated/messages.d.ts +17 -2
- package/dist/client/federated/messages.js +3 -1
- package/dist/client/index.d.ts +2 -2
- package/dist/client/index.js +2 -2
- package/dist/client/local_client.d.ts +10 -0
- package/dist/client/local_client.js +14 -0
- package/dist/client/messages.d.ts +6 -8
- package/dist/client/messages.js +23 -7
- package/dist/client/utils.js +1 -1
- package/dist/default_tasks/cifar10.js +1 -2
- package/dist/default_tasks/lus_covid.js +1 -1
- package/dist/default_tasks/mnist.js +1 -2
- package/dist/default_tasks/simple_face.js +2 -2
- package/dist/default_tasks/titanic.js +2 -2
- package/dist/default_tasks/wikitext.js +1 -1
- package/dist/index.d.ts +4 -2
- package/dist/index.js +1 -1
- package/dist/logging/logger.d.ts +1 -1
- package/dist/serialization/model.js +18 -9
- package/dist/task/index.d.ts +0 -1
- package/dist/task/index.js +0 -1
- package/dist/task/task.d.ts +0 -2
- package/dist/task/task.js +2 -4
- package/dist/task/training_information.d.ts +1 -2
- package/dist/task/training_information.js +3 -5
- package/dist/training/disco.d.ts +14 -16
- package/dist/training/disco.js +22 -46
- package/dist/training/index.d.ts +1 -1
- package/dist/training/trainer.d.ts +3 -2
- package/dist/training/trainer.js +12 -5
- package/dist/utils/event_emitter.js +1 -3
- package/package.json +1 -1
- package/dist/client/federated/base.d.ts +0 -38
- package/dist/client/federated/base.js +0 -130
- package/dist/client/local.d.ts +0 -5
- package/dist/client/local.js +0 -6
- package/dist/memory/base.d.ts +0 -111
- package/dist/memory/base.js +0 -9
- package/dist/memory/empty.d.ts +0 -20
- package/dist/memory/empty.js +0 -43
- package/dist/memory/index.d.ts +0 -2
- package/dist/memory/index.js +0 -2
- package/dist/task/digest.d.ts +0 -5
- package/dist/task/digest.js +0 -14
|
package/dist/task/training_information.d.ts
CHANGED
|
@@ -5,7 +5,6 @@ interface Privacy {
|
|
|
5
5
|
noiseScale?: number;
|
|
6
6
|
}
|
|
7
7
|
export interface TrainingInformation {
|
|
8
|
-
modelID: string;
|
|
9
8
|
epochs: number;
|
|
10
9
|
roundDuration: number;
|
|
11
10
|
validationSplit: number;
|
|
@@ -21,7 +20,7 @@ export interface TrainingInformation {
|
|
|
21
20
|
privacy?: Privacy;
|
|
22
21
|
decentralizedSecure?: boolean;
|
|
23
22
|
maxShareValue?: number;
|
|
24
|
-
|
|
23
|
+
minNbOfParticipants: number;
|
|
25
24
|
aggregator?: 'mean' | 'secure';
|
|
26
25
|
tokenizer?: string | PreTrainedTokenizer;
|
|
27
26
|
maxSequenceLength?: number;
|
|
package/dist/task/training_information.js
CHANGED
|
@@ -24,20 +24,19 @@ export function isTrainingInformation(raw) {
|
|
|
24
24
|
if (typeof raw !== 'object' || raw === null) {
|
|
25
25
|
return false;
|
|
26
26
|
}
|
|
27
|
-
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue,
|
|
27
|
+
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
|
|
28
28
|
if (typeof dataType !== 'string' ||
|
|
29
|
-
typeof modelID !== 'string' ||
|
|
30
29
|
typeof epochs !== 'number' ||
|
|
31
30
|
typeof batchSize !== 'number' ||
|
|
32
31
|
typeof roundDuration !== 'number' ||
|
|
33
32
|
typeof validationSplit !== 'number' ||
|
|
33
|
+
typeof minNbOfParticipants !== 'number' ||
|
|
34
34
|
(tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
|
|
35
35
|
(maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
|
|
36
36
|
(aggregator !== undefined && typeof aggregator !== 'string') ||
|
|
37
37
|
(decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
|
|
38
38
|
(privacy !== undefined && !isPrivacy(privacy)) ||
|
|
39
39
|
(maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
|
|
40
|
-
(minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
|
|
41
40
|
(IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
|
|
42
41
|
(IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
|
|
43
42
|
(LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
|
|
@@ -96,8 +95,7 @@ export function isTrainingInformation(raw) {
|
|
|
96
95
|
epochs,
|
|
97
96
|
inputColumns,
|
|
98
97
|
maxShareValue,
|
|
99
|
-
|
|
100
|
-
modelID,
|
|
98
|
+
minNbOfParticipants,
|
|
101
99
|
outputColumns,
|
|
102
100
|
preprocessingFunctions,
|
|
103
101
|
roundDuration,
|
package/dist/training/disco.d.ts
CHANGED
|
@@ -1,32 +1,34 @@
|
|
|
1
|
-
import { client as clients, BatchLogs, EpochLogs, Logger,
|
|
1
|
+
import { client as clients, BatchLogs, EpochLogs, Logger, Task, TrainingInformation } from "../index.js";
|
|
2
2
|
import type { TypedLabeledDataset } from "../index.js";
|
|
3
3
|
import type { Aggregator } from "../aggregator/index.js";
|
|
4
|
+
import { EventEmitter } from "../utils/event_emitter.js";
|
|
4
5
|
import { RoundLogs, Trainer } from "./trainer.js";
|
|
5
|
-
interface
|
|
6
|
+
interface DiscoConfig {
|
|
6
7
|
scheme: TrainingInformation["scheme"];
|
|
7
8
|
logger: Logger;
|
|
8
|
-
memory: Memory;
|
|
9
9
|
}
|
|
10
|
+
export type RoundStatus = "Waiting for more participants" | "Retrieving peers' information" | "Updating the model with other participants' models" | "Training the model on the data you connected";
|
|
10
11
|
/**
|
|
11
12
|
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
12
|
-
* a convenient object providing a reduced yet complete API that wraps model training
|
|
13
|
-
* communication with nodes
|
|
13
|
+
* a convenient object providing a reduced yet complete API that wraps model training and
|
|
14
|
+
* communication with nodes.
|
|
14
15
|
*/
|
|
15
|
-
export declare class Disco {
|
|
16
|
+
export declare class Disco extends EventEmitter<{
|
|
17
|
+
'status': RoundStatus;
|
|
18
|
+
}> {
|
|
16
19
|
#private;
|
|
17
20
|
readonly trainer: Trainer;
|
|
18
|
-
private constructor();
|
|
19
21
|
/**
|
|
20
22
|
* Connect to the given task and get ready to train.
|
|
21
23
|
*
|
|
22
|
-
*
|
|
23
|
-
*
|
|
24
|
+
* @param task
|
|
24
25
|
* @param clientConfig client to connect with or parameters on how to create one.
|
|
25
|
-
|
|
26
|
-
|
|
26
|
+
* @param config the DiscoConfig
|
|
27
|
+
*/
|
|
28
|
+
constructor(task: Task, clientConfig: clients.Client | URL | {
|
|
27
29
|
aggregator: Aggregator;
|
|
28
30
|
url: URL;
|
|
29
|
-
}, config: Partial<
|
|
31
|
+
}, config: Partial<DiscoConfig>);
|
|
30
32
|
/** Train on dataset, yielding logs of every round. */
|
|
31
33
|
trainByRound(dataset: TypedLabeledDataset): AsyncGenerator<RoundLogs>;
|
|
32
34
|
/** Train on dataset, yielding logs of every epoch. */
|
|
@@ -42,10 +44,6 @@ export declare class Disco {
|
|
|
42
44
|
* If you don't care about the whole process, use one of the other train methods.
|
|
43
45
|
**/
|
|
44
46
|
train(dataset: TypedLabeledDataset): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>>;
|
|
45
|
-
/**
|
|
46
|
-
* Stops the ongoing training instance without disconnecting the client.
|
|
47
|
-
*/
|
|
48
|
-
pause(): Promise<void>;
|
|
49
47
|
/**
|
|
50
48
|
* Completely stops the ongoing training instance.
|
|
51
49
|
*/
|
package/dist/training/disco.js
CHANGED
|
@@ -1,38 +1,31 @@
|
|
|
1
|
-
import { async_iterator, client as clients, ConsoleLogger,
|
|
1
|
+
import { async_iterator, client as clients, ConsoleLogger, } from "../index.js";
|
|
2
2
|
import { getAggregator } from "../aggregator/index.js";
|
|
3
3
|
import { enumerate, split } from "../utils/async_iterator.js";
|
|
4
|
+
import { EventEmitter } from "../utils/event_emitter.js";
|
|
4
5
|
import { Trainer } from "./trainer.js";
|
|
5
6
|
import { labeledDatasetToDataSplit } from "../dataset/data/helpers.js";
|
|
6
7
|
/**
|
|
7
8
|
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
8
|
-
* a convenient object providing a reduced yet complete API that wraps model training
|
|
9
|
-
* communication with nodes
|
|
9
|
+
* a convenient object providing a reduced yet complete API that wraps model training and
|
|
10
|
+
* communication with nodes.
|
|
10
11
|
*/
|
|
11
|
-
export class Disco {
|
|
12
|
+
export class Disco extends EventEmitter {
|
|
12
13
|
trainer;
|
|
13
14
|
#client;
|
|
14
15
|
#logger;
|
|
15
|
-
#memory;
|
|
16
16
|
#task;
|
|
17
|
-
constructor(trainer, task, client, memory, logger) {
|
|
18
|
-
this.trainer = trainer;
|
|
19
|
-
this.#client = client;
|
|
20
|
-
this.#logger = logger;
|
|
21
|
-
this.#memory = memory;
|
|
22
|
-
this.#task = task;
|
|
23
|
-
}
|
|
24
17
|
/**
|
|
25
18
|
* Connect to the given task and get ready to train.
|
|
26
19
|
*
|
|
27
|
-
*
|
|
28
|
-
*
|
|
20
|
+
* @param task
|
|
29
21
|
* @param clientConfig client to connect with or parameters on how to create one.
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
22
|
+
* @param config the DiscoConfig
|
|
23
|
+
*/
|
|
24
|
+
constructor(task, clientConfig, config) {
|
|
25
|
+
super();
|
|
26
|
+
const { scheme, logger } = {
|
|
33
27
|
scheme: task.trainingInformation.scheme,
|
|
34
28
|
logger: new ConsoleLogger(),
|
|
35
|
-
memory: new EmptyMemory(),
|
|
36
29
|
...config,
|
|
37
30
|
};
|
|
38
31
|
let client;
|
|
@@ -52,18 +45,12 @@ export class Disco {
|
|
|
52
45
|
}
|
|
53
46
|
if (client.task !== task)
|
|
54
47
|
throw new Error("client not setup for given task");
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
};
|
|
62
|
-
if (await memory.contains(memoryInfo))
|
|
63
|
-
model = await memory.getModel(memoryInfo);
|
|
64
|
-
else
|
|
65
|
-
model = await client.getLatestModel();
|
|
66
|
-
return new Disco(new Trainer(task, model, client), task, client, memory, logger);
|
|
48
|
+
this.#logger = logger;
|
|
49
|
+
this.#client = client;
|
|
50
|
+
this.#task = task;
|
|
51
|
+
this.trainer = new Trainer(task, client);
|
|
52
|
+
// Simply propagate the training status events emitted by the client
|
|
53
|
+
this.#client.on('status', status => this.emit('status', status));
|
|
67
54
|
}
|
|
68
55
|
/** Train on dataset, yielding logs of every round. */
|
|
69
56
|
async *trainByRound(dataset) {
|
|
@@ -106,11 +93,12 @@ export class Disco {
|
|
|
106
93
|
* If you don't care about the whole process, use one of the other train methods.
|
|
107
94
|
**/
|
|
108
95
|
async *train(dataset) {
|
|
109
|
-
this.#logger.success("Training started
|
|
96
|
+
this.#logger.success("Training started");
|
|
110
97
|
const data = await labeledDatasetToDataSplit(this.#task, dataset);
|
|
111
98
|
const trainData = data.train.preprocess().batch().dataset;
|
|
112
99
|
const validationData = data.validation?.preprocess().batch().dataset ?? trainData;
|
|
113
|
-
|
|
100
|
+
// the client fetches the latest weights upon connection
|
|
101
|
+
this.trainer.model = await this.#client.connect();
|
|
114
102
|
for await (const [round, epochs] of enumerate(this.trainer.train(trainData, validationData))) {
|
|
115
103
|
yield async function* () {
|
|
116
104
|
const [gen, returnedRoundLogs] = split(epochs);
|
|
@@ -123,6 +111,7 @@ export class Disco {
|
|
|
123
111
|
` Epoch: ${epoch}`,
|
|
124
112
|
` Training loss: ${epochLogs.training.loss}`,
|
|
125
113
|
` Training accuracy: ${epochLogs.training.accuracy}`,
|
|
114
|
+
` Peak memory: ${epochLogs.peakMemory}`,
|
|
126
115
|
epochLogs.validation !== undefined
|
|
127
116
|
? ` Validation loss: ${epochLogs.validation.loss}`
|
|
128
117
|
: "",
|
|
@@ -133,26 +122,13 @@ export class Disco {
|
|
|
133
122
|
}
|
|
134
123
|
return await returnedRoundLogs;
|
|
135
124
|
}.bind(this)();
|
|
136
|
-
await this.#memory.updateWorkingModel({
|
|
137
|
-
type: "working",
|
|
138
|
-
taskID: this.#task.id,
|
|
139
|
-
name: this.#task.trainingInformation.modelID,
|
|
140
|
-
tensorBackend: this.#task.trainingInformation.tensorBackend,
|
|
141
|
-
}, this.trainer.model);
|
|
142
125
|
}
|
|
143
|
-
this.#logger.success("Training finished
|
|
144
|
-
}
|
|
145
|
-
/**
|
|
146
|
-
* Stops the ongoing training instance without disconnecting the client.
|
|
147
|
-
*/
|
|
148
|
-
async pause() {
|
|
149
|
-
await this.trainer.stopTraining();
|
|
126
|
+
this.#logger.success("Training finished");
|
|
150
127
|
}
|
|
151
128
|
/**
|
|
152
129
|
* Completely stops the ongoing training instance.
|
|
153
130
|
*/
|
|
154
131
|
async close() {
|
|
155
|
-
await this.pause();
|
|
156
132
|
await this.#client.disconnect();
|
|
157
133
|
}
|
|
158
134
|
}
|
package/dist/training/index.d.ts
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
export { Disco } from './disco.js';
|
|
1
|
+
export { Disco, RoundStatus } from './disco.js';
|
|
2
2
|
export { RoundLogs, Trainer } from './trainer.js';
|
|
package/dist/training/trainer.d.ts
CHANGED
|
@@ -9,8 +9,9 @@ export interface RoundLogs {
|
|
|
9
9
|
/** Train a model and exchange with others **/
|
|
10
10
|
export declare class Trainer {
|
|
11
11
|
#private;
|
|
12
|
-
|
|
13
|
-
|
|
12
|
+
get model(): Model;
|
|
13
|
+
set model(model: Model);
|
|
14
|
+
constructor(task: Task, client: Client);
|
|
14
15
|
stopTraining(): Promise<void>;
|
|
15
16
|
train(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
|
|
16
17
|
}
|
package/dist/training/trainer.js
CHANGED
|
@@ -4,14 +4,21 @@ import { privacy } from "../index.js";
|
|
|
4
4
|
import * as async_iterator from "../utils/async_iterator.js";
|
|
5
5
|
/** Train a model and exchange with others **/
|
|
6
6
|
export class Trainer {
|
|
7
|
-
model;
|
|
8
7
|
#client;
|
|
9
8
|
#roundDuration;
|
|
10
9
|
#epochs;
|
|
11
10
|
#privacy;
|
|
11
|
+
#model;
|
|
12
12
|
#training;
|
|
13
|
-
|
|
14
|
-
this
|
|
13
|
+
get model() {
|
|
14
|
+
if (this.#model === undefined)
|
|
15
|
+
throw new Error("trainer's model has not been set");
|
|
16
|
+
return this.#model;
|
|
17
|
+
}
|
|
18
|
+
set model(model) {
|
|
19
|
+
this.#model = model;
|
|
20
|
+
}
|
|
21
|
+
constructor(task, client) {
|
|
15
22
|
this.#client = client;
|
|
16
23
|
this.#roundDuration = task.trainingInformation.roundDuration;
|
|
17
24
|
this.#epochs = task.trainingInformation.epochs;
|
|
@@ -37,12 +44,12 @@ export class Trainer {
|
|
|
37
44
|
const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
|
|
38
45
|
let previousRoundWeights;
|
|
39
46
|
for (let round = 0; round < totalRound; round++) {
|
|
40
|
-
await this.#client.onRoundBeginCommunication(
|
|
47
|
+
await this.#client.onRoundBeginCommunication();
|
|
41
48
|
yield this.#runRound(dataset, valDataset);
|
|
42
49
|
let localWeights = this.model.weights;
|
|
43
50
|
if (this.#privacy !== undefined)
|
|
44
51
|
localWeights = await applyPrivacy(previousRoundWeights, localWeights, this.#privacy);
|
|
45
|
-
const networkWeights = await this.#client.onRoundEndCommunication(localWeights
|
|
52
|
+
const networkWeights = await this.#client.onRoundEndCommunication(localWeights);
|
|
46
53
|
this.model.weights = previousRoundWeights = networkWeights;
|
|
47
54
|
}
|
|
48
55
|
}
|
|
package/dist/utils/event_emitter.js
CHANGED
|
@@ -47,9 +47,7 @@ export class EventEmitter {
|
|
|
47
47
|
emit(event, value) {
|
|
48
48
|
const eventListeners = this.listeners[event] ?? List();
|
|
49
49
|
this.listeners[event] = eventListeners.filterNot(([once]) => once);
|
|
50
|
-
eventListeners.forEach(([_, listener]) => {
|
|
51
|
-
listener(value);
|
|
52
|
-
});
|
|
50
|
+
eventListeners.forEach(([_, listener]) => { listener(value); });
|
|
53
51
|
}
|
|
54
52
|
}
|
|
55
53
|
/** `EventEmitter` for all events */
|
package/package.json
CHANGED
|
package/dist/client/federated/base.d.ts
DELETED
|
@@ -1,38 +0,0 @@
|
|
|
1
|
-
import { type WeightsContainer } from "../../index.js";
|
|
2
|
-
import { Base as Client } from "../base.js";
|
|
3
|
-
/**
|
|
4
|
-
* Client class that communicates with a centralized, federated server, when training
|
|
5
|
-
* a specific task in the federated setting.
|
|
6
|
-
*/
|
|
7
|
-
export declare class Base extends Client {
|
|
8
|
-
#private;
|
|
9
|
-
/**
|
|
10
|
-
* Arbitrary node id assigned to the federated server which we are communicating with.
|
|
11
|
-
* Indeed, the server acts as a node within the network. In the federated setting described
|
|
12
|
-
* by this client class, the server is the only node which we are communicating with.
|
|
13
|
-
*/
|
|
14
|
-
static readonly SERVER_NODE_ID = "federated-server-node-id";
|
|
15
|
-
get nbOfParticipants(): number;
|
|
16
|
-
/**
|
|
17
|
-
* Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
18
|
-
*/
|
|
19
|
-
private connectServer;
|
|
20
|
-
/**
|
|
21
|
-
* Initializes the connection to the server and get our own node id.
|
|
22
|
-
* TODO: In the federated setting, should return the current server-side round
|
|
23
|
-
* for the task.
|
|
24
|
-
*/
|
|
25
|
-
connect(): Promise<void>;
|
|
26
|
-
/**
|
|
27
|
-
* Disconnection process when user quits the task.
|
|
28
|
-
*/
|
|
29
|
-
disconnect(): Promise<void>;
|
|
30
|
-
onRoundBeginCommunication(): Promise<void>;
|
|
31
|
-
onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
|
|
32
|
-
/**
|
|
33
|
-
* Send a message containing our local weight updates to the federated server.
|
|
34
|
-
* And waits for the server to reply with the most recent aggregated weights
|
|
35
|
-
* @param payload The weight updates to send
|
|
36
|
-
*/
|
|
37
|
-
private sendPayloadAndReceiveResult;
|
|
38
|
-
}
|
|
package/dist/client/federated/base.js
DELETED
|
@@ -1,130 +0,0 @@
|
|
|
1
|
-
import createDebug from "debug";
|
|
2
|
-
import { serialization, } from "../../index.js";
|
|
3
|
-
import { Base as Client } from "../base.js";
|
|
4
|
-
import { type } from "../messages.js";
|
|
5
|
-
import { waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
|
|
6
|
-
import * as messages from "./messages.js";
|
|
7
|
-
const debug = createDebug("discojs:client:federated");
|
|
8
|
-
/**
|
|
9
|
-
* Client class that communicates with a centralized, federated server, when training
|
|
10
|
-
* a specific task in the federated setting.
|
|
11
|
-
*/
|
|
12
|
-
export class Base extends Client {
|
|
13
|
-
/**
|
|
14
|
-
* Arbitrary node id assigned to the federated server which we are communicating with.
|
|
15
|
-
* Indeed, the server acts as a node within the network. In the federated setting described
|
|
16
|
-
* by this client class, the server is the only node which we are communicating with.
|
|
17
|
-
*/
|
|
18
|
-
static SERVER_NODE_ID = "federated-server-node-id";
|
|
19
|
-
// Total number of other federated contributors, including this client, excluding the server
|
|
20
|
-
// E.g., if 3 users are training a federated model, nbOfParticipants is 3
|
|
21
|
-
#nbOfParticipants = 1;
|
|
22
|
-
// the number of participants excluding the server
|
|
23
|
-
get nbOfParticipants() {
|
|
24
|
-
return this.#nbOfParticipants;
|
|
25
|
-
}
|
|
26
|
-
/**
|
|
27
|
-
* Opens a new WebSocket connection with the server and listens to new messages over the channel
|
|
28
|
-
*/
|
|
29
|
-
async connectServer(url) {
|
|
30
|
-
const server = await WebSocketServer.connect(url, messages.isMessageFederated, messages.isMessageFederated);
|
|
31
|
-
return server;
|
|
32
|
-
}
|
|
33
|
-
/**
|
|
34
|
-
* Initializes the connection to the server and get our own node id.
|
|
35
|
-
* TODO: In the federated setting, should return the current server-side round
|
|
36
|
-
* for the task.
|
|
37
|
-
*/
|
|
38
|
-
async connect() {
|
|
39
|
-
const serverURL = new URL("", this.url.href);
|
|
40
|
-
switch (this.url.protocol) {
|
|
41
|
-
case "http:":
|
|
42
|
-
serverURL.protocol = "ws:";
|
|
43
|
-
break;
|
|
44
|
-
case "https:":
|
|
45
|
-
serverURL.protocol = "wss:";
|
|
46
|
-
break;
|
|
47
|
-
default:
|
|
48
|
-
throw new Error(`unknown protocol: ${this.url.protocol}`);
|
|
49
|
-
}
|
|
50
|
-
serverURL.pathname += `feai/${this.task.id}`;
|
|
51
|
-
this._server = await this.connectServer(serverURL);
|
|
52
|
-
this.aggregator.registerNode(Base.SERVER_NODE_ID);
|
|
53
|
-
const msg = {
|
|
54
|
-
type: type.ClientConnected,
|
|
55
|
-
};
|
|
56
|
-
this.server.send(msg);
|
|
57
|
-
const received = await waitMessageWithTimeout(this.server, type.AssignNodeID);
|
|
58
|
-
debug(`[${received.id}] assign id generated by the server`);
|
|
59
|
-
this._ownId = received.id;
|
|
60
|
-
}
|
|
61
|
-
/**
|
|
62
|
-
* Disconnection process when user quits the task.
|
|
63
|
-
*/
|
|
64
|
-
async disconnect() {
|
|
65
|
-
await this.server.disconnect();
|
|
66
|
-
this._server = undefined;
|
|
67
|
-
this._ownId = undefined;
|
|
68
|
-
this.aggregator.setNodes(this.aggregator.nodes.delete(Base.SERVER_NODE_ID));
|
|
69
|
-
return Promise.resolve();
|
|
70
|
-
}
|
|
71
|
-
onRoundBeginCommunication() {
|
|
72
|
-
// Prepare the result promise for the incoming round
|
|
73
|
-
this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
|
|
74
|
-
return Promise.resolve();
|
|
75
|
-
}
|
|
76
|
-
async onRoundEndCommunication(weights, round) {
|
|
77
|
-
// NB: For now, we suppose a fully-federated setting.
|
|
78
|
-
if (this.aggregationResult === undefined) {
|
|
79
|
-
throw new Error("local aggregation result was not set");
|
|
80
|
-
}
|
|
81
|
-
// Send our local contribution to the server
|
|
82
|
-
// and receive the most recent weights as an answer to our contribution
|
|
83
|
-
const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first());
|
|
84
|
-
if (serverResult !== undefined &&
|
|
85
|
-
this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) {
|
|
86
|
-
// Regular case: the server sends us its aggregation result which will serve our
|
|
87
|
-
// own aggregation result.
|
|
88
|
-
}
|
|
89
|
-
else {
|
|
90
|
-
// Unexpected case: for some reason, the server result is stale.
|
|
91
|
-
// We proceed to the next round without its result.
|
|
92
|
-
debug(`[${this.ownId}] server result is either stale or not received`);
|
|
93
|
-
this.aggregator.nextRound();
|
|
94
|
-
}
|
|
95
|
-
return await this.aggregationResult;
|
|
96
|
-
}
|
|
97
|
-
/**
|
|
98
|
-
* Send a message containing our local weight updates to the federated server.
|
|
99
|
-
* And waits for the server to reply with the most recent aggregated weights
|
|
100
|
-
* @param payload The weight updates to send
|
|
101
|
-
*/
|
|
102
|
-
async sendPayloadAndReceiveResult(payload) {
|
|
103
|
-
const msg = {
|
|
104
|
-
type: type.SendPayload,
|
|
105
|
-
payload: await serialization.weights.encode(payload),
|
|
106
|
-
round: this.aggregator.round,
|
|
107
|
-
};
|
|
108
|
-
this.server.send(msg);
|
|
109
|
-
// Waits for the server's result for its current (most recent) round and add it to our aggregator.
|
|
110
|
-
// Updates the aggregator's round if it's behind the server's.
|
|
111
|
-
try {
|
|
112
|
-
// It is important than the client immediately awaits the server result or it may miss it
|
|
113
|
-
const { payload, round, nbOfParticipants } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload);
|
|
114
|
-
const serverRound = round;
|
|
115
|
-
this.#nbOfParticipants = nbOfParticipants; // Save the current participants
|
|
116
|
-
// Store the server result only if it is not stale
|
|
117
|
-
if (this.aggregator.round <= round) {
|
|
118
|
-
const serverResult = serialization.weights.decode(payload);
|
|
119
|
-
// Update the local round to match the server's
|
|
120
|
-
if (this.aggregator.round < serverRound) {
|
|
121
|
-
this.aggregator.setRound(serverRound);
|
|
122
|
-
}
|
|
123
|
-
return serverResult;
|
|
124
|
-
}
|
|
125
|
-
}
|
|
126
|
-
catch (e) {
|
|
127
|
-
debug(`[${this.ownId}] while receiving results: %o`, e);
|
|
128
|
-
}
|
|
129
|
-
}
|
|
130
|
-
}
|
package/dist/client/local.d.ts
DELETED
package/dist/client/local.js
DELETED
package/dist/memory/base.d.ts
DELETED
|
@@ -1,111 +0,0 @@
|
|
|
1
|
-
import type { Model, TaskID } from '../index.js';
|
|
2
|
-
/**
|
|
3
|
-
* Type of models stored in memory. Stored models can either be a model currently
|
|
4
|
-
* being trained ("working model") or a regular model saved in memory ("saved model").
|
|
5
|
-
* There can only be a single working model for a given task.
|
|
6
|
-
*/
|
|
7
|
-
type StoredModelType = 'saved' | 'working';
|
|
8
|
-
/**
|
|
9
|
-
* Model information which uniquely identifies a model in memory.
|
|
10
|
-
*/
|
|
11
|
-
export interface ModelInfo {
|
|
12
|
-
type: StoredModelType;
|
|
13
|
-
version?: number;
|
|
14
|
-
taskID: TaskID;
|
|
15
|
-
name: string;
|
|
16
|
-
tensorBackend: 'gpt' | 'tfjs';
|
|
17
|
-
}
|
|
18
|
-
/**
|
|
19
|
-
* A model source uniquely identifies a model stored in memory.
|
|
20
|
-
* It can be in the form of either a model info object or an ID
|
|
21
|
-
* (one-to-one mapping between the two)
|
|
22
|
-
*/
|
|
23
|
-
export type ModelSource = ModelInfo | string;
|
|
24
|
-
/**
|
|
25
|
-
* Represents a model memory system, providing functions to fetch, save, delete and update models.
|
|
26
|
-
* Stored models can either be a model currently being trained ("working model") or a regular model
|
|
27
|
-
* saved in memory ("saved model"). There can only be a single working model for a given task.
|
|
28
|
-
*/
|
|
29
|
-
export declare abstract class Memory {
|
|
30
|
-
/**
|
|
31
|
-
* Fetches the model identified by the given model source.
|
|
32
|
-
* @param source The model source
|
|
33
|
-
* @returns The model
|
|
34
|
-
*/
|
|
35
|
-
abstract getModel(source: ModelSource): Promise<Model>;
|
|
36
|
-
/**
|
|
37
|
-
* Removes the model identified by the given model source from memory.
|
|
38
|
-
* @param source The model source
|
|
39
|
-
* @returns The model
|
|
40
|
-
*/
|
|
41
|
-
abstract deleteModel(source: ModelSource): Promise<void>;
|
|
42
|
-
/**
|
|
43
|
-
* Replaces the corresponding working model with the saved model identified by the given model source.
|
|
44
|
-
* @param source The model source
|
|
45
|
-
*/
|
|
46
|
-
abstract loadModel(source: ModelSource): Promise<void>;
|
|
47
|
-
/**
|
|
48
|
-
* Fetches metadata for the model identified by the given model source.
|
|
49
|
-
* If the model does not exist in memory, returns undefined.
|
|
50
|
-
* @param source The model source
|
|
51
|
-
* @returns The model metadata or undefined
|
|
52
|
-
*/
|
|
53
|
-
abstract getModelMetadata(source: ModelSource): Promise<object | undefined>;
|
|
54
|
-
/**
|
|
55
|
-
* Replaces the working model identified by the given source with the newly provided model.
|
|
56
|
-
* @param source The model source
|
|
57
|
-
* @param model The new model
|
|
58
|
-
*/
|
|
59
|
-
abstract updateWorkingModel(source: ModelSource, model: Model): Promise<void>;
|
|
60
|
-
/**
|
|
61
|
-
* Creates a saved model copy from the working model identified by the given model source.
|
|
62
|
-
* Returns the saved model's path.
|
|
63
|
-
* @param source The model source
|
|
64
|
-
* @returns The saved model's path
|
|
65
|
-
*/
|
|
66
|
-
abstract saveWorkingModel(source: ModelSource): Promise<string | undefined>;
|
|
67
|
-
/**
|
|
68
|
-
* Saves the newly provided model to the given model source.
|
|
69
|
-
* Returns the saved model's path
|
|
70
|
-
* @param source The model source
|
|
71
|
-
* @param model The new model
|
|
72
|
-
* @returns The saved model's path
|
|
73
|
-
*/
|
|
74
|
-
abstract saveModel(source: ModelSource, model: Model): Promise<string | undefined>;
|
|
75
|
-
/**
|
|
76
|
-
* Moves the model identified by the model source to a file system. This is platform-dependent.
|
|
77
|
-
* @param source The model source
|
|
78
|
-
*/
|
|
79
|
-
abstract downloadModel(source: ModelSource): Promise<void>;
|
|
80
|
-
/**
|
|
81
|
-
* Checks whether the model memory contains the model identified by the given source.
|
|
82
|
-
* @param source The model source
|
|
83
|
-
* @returns True if the memory contains the model, false otherwise
|
|
84
|
-
*/
|
|
85
|
-
abstract contains(source: ModelSource): Promise<boolean>;
|
|
86
|
-
/**
|
|
87
|
-
* Computes the path in memory corresponding to the given model source, be it a path or model information.
|
|
88
|
-
* This is used to easily switch between model path and information, which are both unique model identifiers
|
|
89
|
-
* with a one-to-one equivalence. Returns undefined instead if no path could be inferred from the given
|
|
90
|
-
* model source.
|
|
91
|
-
* @param source The model source
|
|
92
|
-
* @returns The model path
|
|
93
|
-
*/
|
|
94
|
-
abstract getModelMemoryPath(source: ModelSource): string | undefined;
|
|
95
|
-
/**
|
|
96
|
-
* Computes the model information corresponding to the given model source, be it a path or model information.
|
|
97
|
-
* This is used to easily switch between model path and information, which are both unique model identifiers
|
|
98
|
-
* with a one-to-one equivalence. Returns undefined instead if no unique model information could be inferred
|
|
99
|
-
* from the given model source.
|
|
100
|
-
* @param source The model source
|
|
101
|
-
* @returns The model information
|
|
102
|
-
*/
|
|
103
|
-
abstract getModelInfo(source: ModelSource): ModelInfo | undefined;
|
|
104
|
-
/**
|
|
105
|
-
* Computes the lowest version a model source can have without conflicting with model versions currently in memory.
|
|
106
|
-
* @param source The model source
|
|
107
|
-
* @returns The duplicated model source
|
|
108
|
-
*/
|
|
109
|
-
abstract duplicateSource(source: ModelSource): Promise<ModelSource | undefined>;
|
|
110
|
-
}
|
|
111
|
-
export {};
|
package/dist/memory/base.js
DELETED
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
// only used browser-side
|
|
2
|
-
// TODO: replace IO type
|
|
3
|
-
/**
|
|
4
|
-
* Represents a model memory system, providing functions to fetch, save, delete and update models.
|
|
5
|
-
* Stored models can either be a model currently being trained ("working model") or a regular model
|
|
6
|
-
* saved in memory ("saved model"). There can only be a single working model for a given task.
|
|
7
|
-
*/
|
|
8
|
-
export class Memory {
|
|
9
|
-
}
|
package/dist/memory/empty.d.ts
DELETED
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
import type { Model } from '../index.js';
|
|
2
|
-
import type { ModelInfo } from './base.js';
|
|
3
|
-
import { Memory } from './base.js';
|
|
4
|
-
/**
|
|
5
|
-
* Represents an empty model memory.
|
|
6
|
-
*/
|
|
7
|
-
export declare class Empty extends Memory {
|
|
8
|
-
getModelMetadata(): Promise<undefined>;
|
|
9
|
-
contains(): Promise<boolean>;
|
|
10
|
-
getModel(): Promise<Model>;
|
|
11
|
-
loadModel(): Promise<void>;
|
|
12
|
-
updateWorkingModel(): Promise<void>;
|
|
13
|
-
saveWorkingModel(): Promise<undefined>;
|
|
14
|
-
saveModel(): Promise<undefined>;
|
|
15
|
-
deleteModel(): Promise<void>;
|
|
16
|
-
downloadModel(): Promise<void>;
|
|
17
|
-
getModelMemoryPath(): string;
|
|
18
|
-
getModelInfo(): ModelInfo;
|
|
19
|
-
duplicateSource(): Promise<undefined>;
|
|
20
|
-
}
|