@epfml/discojs 2.1.2-p20240723143623.0 → 2.1.2-p20240723160120.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +8 -48
- package/dist/aggregator/base.js +6 -68
- package/dist/aggregator/get.d.ts +0 -2
- package/dist/aggregator/get.js +4 -4
- package/dist/aggregator/mean.d.ts +2 -2
- package/dist/aggregator/mean.js +3 -6
- package/dist/aggregator/secure.d.ts +2 -2
- package/dist/aggregator/secure.js +4 -7
- package/dist/client/base.d.ts +2 -1
- package/dist/client/base.js +0 -6
- package/dist/client/decentralized/base.d.ts +2 -2
- package/dist/client/decentralized/base.js +9 -8
- package/dist/client/federated/base.d.ts +1 -1
- package/dist/client/federated/base.js +2 -1
- package/dist/client/local.d.ts +3 -1
- package/dist/client/local.js +4 -1
- package/dist/default_tasks/cifar10.js +1 -2
- package/dist/default_tasks/mnist.js +0 -2
- package/dist/default_tasks/simple_face.js +0 -2
- package/dist/default_tasks/titanic.js +0 -2
- package/dist/index.d.ts +0 -1
- package/dist/index.js +0 -1
- package/dist/privacy.d.ts +8 -10
- package/dist/privacy.js +25 -40
- package/dist/task/training_information.d.ts +6 -2
- package/dist/task/training_information.js +17 -5
- package/dist/training/disco.d.ts +30 -28
- package/dist/training/disco.js +76 -61
- package/dist/training/index.d.ts +1 -1
- package/dist/training/index.js +1 -0
- package/dist/training/trainer.d.ts +16 -0
- package/dist/training/trainer.js +72 -0
- package/dist/weights/weights_container.d.ts +0 -5
- package/dist/weights/weights_container.js +0 -7
- package/package.json +1 -1
- package/dist/async_informant.d.ts +0 -15
- package/dist/async_informant.js +0 -42
- package/dist/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/training/trainer/distributed_trainer.js +0 -41
- package/dist/training/trainer/local_trainer.d.ts +0 -12
- package/dist/training/trainer/local_trainer.js +0 -24
- package/dist/training/trainer/trainer.d.ts +0 -32
- package/dist/training/trainer/trainer.js +0 -61
- package/dist/training/trainer/trainer_builder.d.ts +0 -23
- package/dist/training/trainer/trainer_builder.js +0 -47
package/dist/privacy.js
CHANGED
|
@@ -1,42 +1,27 @@
|
|
|
1
|
-
import * as tf from
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
async function frobeniusNorm(weights) {
|
|
3
|
+
const squared = await weights
|
|
4
|
+
.map((w) => w.square().sum())
|
|
5
|
+
.reduce((a, b) => a.add(b))
|
|
6
|
+
.data();
|
|
7
|
+
if (squared.length !== 1)
|
|
8
|
+
throw new Error("unexcepted weights shape");
|
|
9
|
+
return Math.sqrt(squared[0]);
|
|
10
|
+
}
|
|
11
|
+
/** Scramble weights */
|
|
12
|
+
export function addNoise(weights, deviation) {
|
|
13
|
+
const variance = Math.pow(deviation, 2);
|
|
14
|
+
return weights.map((w) => w.add(tf.randomNormal(w.shape, 0, variance)));
|
|
15
|
+
}
|
|
2
16
|
/**
|
|
3
|
-
*
|
|
4
|
-
*
|
|
5
|
-
*
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
const clippingRadius = task.trainingInformation?.clippingRadius;
|
|
14
|
-
const weightsDiff = updatedWeights.sub(staleWeights);
|
|
15
|
-
let newWeightsDiff;
|
|
16
|
-
if (clippingRadius !== undefined) {
|
|
17
|
-
// Frobenius norm
|
|
18
|
-
const norm = weightsDiff.frobeniusNorm();
|
|
19
|
-
newWeightsDiff = weightsDiff.map((w) => {
|
|
20
|
-
const clipped = w.div(Math.max(1, norm / clippingRadius));
|
|
21
|
-
if (noiseScale !== undefined) {
|
|
22
|
-
// Add clipping and noise
|
|
23
|
-
const noise = tf.randomNormal(w.shape, 0, (noiseScale * noiseScale) * (clippingRadius * clippingRadius));
|
|
24
|
-
return clipped.add(noise);
|
|
25
|
-
}
|
|
26
|
-
else {
|
|
27
|
-
// Add clipping without any noise
|
|
28
|
-
return clipped;
|
|
29
|
-
}
|
|
30
|
-
});
|
|
31
|
-
}
|
|
32
|
-
else {
|
|
33
|
-
if (noiseScale !== undefined) {
|
|
34
|
-
// Add noise without any clipping
|
|
35
|
-
newWeightsDiff = weightsDiff.map((w) => tf.randomNormal(w.shape, 0, (noiseScale * noiseScale)));
|
|
36
|
-
}
|
|
37
|
-
else {
|
|
38
|
-
return updatedWeights;
|
|
39
|
-
}
|
|
40
|
-
}
|
|
41
|
-
return staleWeights.add(newWeightsDiff);
|
|
17
|
+
* Keep weights' norm within radius
|
|
18
|
+
*
|
|
19
|
+
* @param radius maximum norm
|
|
20
|
+
**/
|
|
21
|
+
export async function clipNorm(weights, radius) {
|
|
22
|
+
if (radius <= 0)
|
|
23
|
+
throw new Error("invalid radius");
|
|
24
|
+
const norm = await frobeniusNorm(weights);
|
|
25
|
+
const scaling = Math.max(1, norm / radius);
|
|
26
|
+
return weights.map((w) => w.div(scaling));
|
|
42
27
|
}
|
|
@@ -1,5 +1,9 @@
|
|
|
1
1
|
import type { Preprocessing } from '../dataset/data/preprocessing/index.js';
|
|
2
2
|
import { PreTrainedTokenizer } from '@xenova/transformers';
|
|
3
|
+
interface Privacy {
|
|
4
|
+
clippingRadius?: number;
|
|
5
|
+
noiseScale?: number;
|
|
6
|
+
}
|
|
3
7
|
export interface TrainingInformation {
|
|
4
8
|
modelID: string;
|
|
5
9
|
epochs: number;
|
|
@@ -14,8 +18,7 @@ export interface TrainingInformation {
|
|
|
14
18
|
IMAGE_W?: number;
|
|
15
19
|
LABEL_LIST?: string[];
|
|
16
20
|
scheme: 'decentralized' | 'federated' | 'local';
|
|
17
|
-
|
|
18
|
-
clippingRadius?: number;
|
|
21
|
+
privacy?: Privacy;
|
|
19
22
|
decentralizedSecure?: boolean;
|
|
20
23
|
maxShareValue?: number;
|
|
21
24
|
minimumReadyPeers?: number;
|
|
@@ -25,3 +28,4 @@ export interface TrainingInformation {
|
|
|
25
28
|
tensorBackend: 'tfjs' | 'gpt';
|
|
26
29
|
}
|
|
27
30
|
export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
|
|
31
|
+
export {};
|
|
@@ -6,11 +6,25 @@ function isStringArray(raw) {
|
|
|
6
6
|
const arr = raw; // isArray is unsafely guarding with any[]
|
|
7
7
|
return arr.every((e) => typeof e === 'string');
|
|
8
8
|
}
|
|
9
|
+
function isPrivacy(raw) {
|
|
10
|
+
if (typeof raw !== "object" || raw === null) {
|
|
11
|
+
return false;
|
|
12
|
+
}
|
|
13
|
+
const { clippingRadius, noiseScale, } = raw;
|
|
14
|
+
if ((clippingRadius !== undefined && typeof clippingRadius !== "number") ||
|
|
15
|
+
(noiseScale !== undefined && typeof noiseScale !== "number"))
|
|
16
|
+
return false;
|
|
17
|
+
const _ = {
|
|
18
|
+
clippingRadius,
|
|
19
|
+
noiseScale,
|
|
20
|
+
};
|
|
21
|
+
return true;
|
|
22
|
+
}
|
|
9
23
|
export function isTrainingInformation(raw) {
|
|
10
24
|
if (typeof raw !== 'object' || raw === null) {
|
|
11
25
|
return false;
|
|
12
26
|
}
|
|
13
|
-
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize,
|
|
27
|
+
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minimumReadyPeers, modelID, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
|
|
14
28
|
if (typeof dataType !== 'string' ||
|
|
15
29
|
typeof modelID !== 'string' ||
|
|
16
30
|
typeof epochs !== 'number' ||
|
|
@@ -20,11 +34,10 @@ export function isTrainingInformation(raw) {
|
|
|
20
34
|
(tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
|
|
21
35
|
(maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
|
|
22
36
|
(aggregator !== undefined && typeof aggregator !== 'string') ||
|
|
23
|
-
(clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
|
|
24
37
|
(decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
|
|
38
|
+
(privacy !== undefined && !isPrivacy(privacy)) ||
|
|
25
39
|
(maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
|
|
26
40
|
(minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
|
|
27
|
-
(noiseScale !== undefined && typeof noiseScale !== 'number') ||
|
|
28
41
|
(IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
|
|
29
42
|
(IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
|
|
30
43
|
(LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
|
|
@@ -77,15 +90,14 @@ export function isTrainingInformation(raw) {
|
|
|
77
90
|
LABEL_LIST,
|
|
78
91
|
aggregator,
|
|
79
92
|
batchSize,
|
|
80
|
-
clippingRadius,
|
|
81
93
|
dataType,
|
|
82
94
|
decentralizedSecure,
|
|
95
|
+
privacy,
|
|
83
96
|
epochs,
|
|
84
97
|
inputColumns,
|
|
85
98
|
maxShareValue,
|
|
86
99
|
minimumReadyPeers,
|
|
87
100
|
modelID,
|
|
88
|
-
noiseScale,
|
|
89
101
|
outputColumns,
|
|
90
102
|
preprocessingFunctions,
|
|
91
103
|
roundDuration,
|
package/dist/training/disco.d.ts
CHANGED
|
@@ -1,14 +1,11 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { client as clients } from
|
|
3
|
-
import {
|
|
4
|
-
import
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
scheme?: TrainingInformation['scheme'];
|
|
10
|
-
logger?: Logger;
|
|
11
|
-
memory?: Memory;
|
|
1
|
+
import { data, BatchLogs, EpochLogs, Logger, Memory, Task, TrainingInformation } from "../index.js";
|
|
2
|
+
import { client as clients } from "../index.js";
|
|
3
|
+
import type { Aggregator } from "../aggregator/index.js";
|
|
4
|
+
import { RoundLogs, Trainer } from "./trainer.js";
|
|
5
|
+
interface Config {
|
|
6
|
+
scheme: TrainingInformation["scheme"];
|
|
7
|
+
logger: Logger;
|
|
8
|
+
memory: Memory;
|
|
12
9
|
}
|
|
13
10
|
/**
|
|
14
11
|
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
@@ -16,16 +13,22 @@ export interface DiscoOptions {
|
|
|
16
13
|
* communication with nodes, logs and model memory.
|
|
17
14
|
*/
|
|
18
15
|
export declare class Disco {
|
|
19
|
-
|
|
20
|
-
readonly
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
16
|
+
#private;
|
|
17
|
+
readonly trainer: Trainer;
|
|
18
|
+
private constructor();
|
|
19
|
+
/**
|
|
20
|
+
* Connect to the given task and get ready to train.
|
|
21
|
+
*
|
|
22
|
+
* Will load the model from memory if available or fetch it from the server.
|
|
23
|
+
*
|
|
24
|
+
* @param clientConfig client to connect with or parameters on how to create one.
|
|
25
|
+
**/
|
|
26
|
+
static fromTask(task: Task, clientConfig: clients.Client | URL | {
|
|
27
|
+
aggregator: Aggregator;
|
|
28
|
+
url: URL;
|
|
29
|
+
}, config: Partial<Config>): Promise<Disco>;
|
|
25
30
|
/** Train on dataset, yielding logs of every round. */
|
|
26
|
-
trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs
|
|
27
|
-
participants: number;
|
|
28
|
-
}>;
|
|
31
|
+
trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs>;
|
|
29
32
|
/** Train on dataset, yielding logs of every epoch. */
|
|
30
33
|
trainByEpoch(dataTuple: data.DataSplit): AsyncGenerator<EpochLogs>;
|
|
31
34
|
/** Train on dataset, yielding logs of every batch. */
|
|
@@ -33,14 +36,12 @@ export declare class Disco {
|
|
|
33
36
|
/** Run whole train on dataset. */
|
|
34
37
|
trainFully(dataTuple: data.DataSplit): Promise<void>;
|
|
35
38
|
/**
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
train(dataTuple: data.DataSplit): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs
|
|
42
|
-
participants: number;
|
|
43
|
-
}>>;
|
|
39
|
+
* Train on dataset, yield the nested steps.
|
|
40
|
+
*
|
|
41
|
+
* Don't forget to await the yielded generator otherwise nothing will progress.
|
|
42
|
+
* If you don't care about the whole process, use one of the other train methods.
|
|
43
|
+
**/
|
|
44
|
+
train(dataTuple: data.DataSplit): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>>;
|
|
44
45
|
/**
|
|
45
46
|
* Stops the ongoing training instance without disconnecting the client.
|
|
46
47
|
*/
|
|
@@ -50,3 +51,4 @@ export declare class Disco {
|
|
|
50
51
|
*/
|
|
51
52
|
close(): Promise<void>;
|
|
52
53
|
}
|
|
54
|
+
export {};
|
package/dist/training/disco.js
CHANGED
|
@@ -1,52 +1,73 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import {
|
|
3
|
-
import {
|
|
4
|
-
import {
|
|
5
|
-
import {
|
|
6
|
-
import { TrainerBuilder } from './trainer/trainer_builder.js';
|
|
1
|
+
import { async_iterator, } from "../index.js";
|
|
2
|
+
import { client as clients, ConsoleLogger, EmptyMemory } from "../index.js";
|
|
3
|
+
import { getAggregator } from "../aggregator/index.js";
|
|
4
|
+
import { enumerate, split } from "../utils/async_iterator.js";
|
|
5
|
+
import { Trainer } from "./trainer.js";
|
|
7
6
|
/**
|
|
8
7
|
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
9
8
|
* a convenient object providing a reduced yet complete API that wraps model training,
|
|
10
9
|
* communication with nodes, logs and model memory.
|
|
11
10
|
*/
|
|
12
11
|
export class Disco {
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
constructor(task,
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
12
|
+
trainer;
|
|
13
|
+
#client;
|
|
14
|
+
#logger;
|
|
15
|
+
// small helper to avoid keeping Task & Memory around
|
|
16
|
+
#updateWorkingModel;
|
|
17
|
+
constructor(trainer, task, client, memory, logger) {
|
|
18
|
+
this.trainer = trainer;
|
|
19
|
+
this.#client = client;
|
|
20
|
+
this.#logger = logger;
|
|
21
|
+
this.#updateWorkingModel = () => memory.updateWorkingModel({
|
|
22
|
+
type: "working",
|
|
23
|
+
taskID: task.id,
|
|
24
|
+
name: task.trainingInformation.modelID,
|
|
25
|
+
tensorBackend: task.trainingInformation.tensorBackend,
|
|
26
|
+
}, this.trainer.model);
|
|
27
|
+
}
|
|
28
|
+
/**
|
|
29
|
+
* Connect to the given task and get ready to train.
|
|
30
|
+
*
|
|
31
|
+
* Will load the model from memory if available or fetch it from the server.
|
|
32
|
+
*
|
|
33
|
+
* @param clientConfig client to connect with or parameters on how to create one.
|
|
34
|
+
**/
|
|
35
|
+
static async fromTask(task, clientConfig, config) {
|
|
36
|
+
const { scheme, logger, memory } = {
|
|
37
|
+
scheme: task.trainingInformation.scheme,
|
|
38
|
+
logger: new ConsoleLogger(),
|
|
39
|
+
memory: new EmptyMemory(),
|
|
40
|
+
...config,
|
|
41
|
+
};
|
|
42
|
+
let client;
|
|
43
|
+
if (clientConfig instanceof clients.Client) {
|
|
44
|
+
client = clientConfig;
|
|
22
45
|
}
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
options.aggregator = getAggregator(task, { scheme: options.scheme });
|
|
46
|
+
else {
|
|
47
|
+
let url, aggregator;
|
|
48
|
+
if (clientConfig instanceof URL) {
|
|
49
|
+
url = clientConfig;
|
|
50
|
+
aggregator = getAggregator(task, { scheme });
|
|
29
51
|
}
|
|
30
|
-
|
|
31
|
-
|
|
52
|
+
else {
|
|
53
|
+
({ url, aggregator } = clientConfig);
|
|
32
54
|
}
|
|
33
|
-
|
|
34
|
-
}
|
|
35
|
-
if (options.logger === undefined) {
|
|
36
|
-
options.logger = new ConsoleLogger();
|
|
37
|
-
}
|
|
38
|
-
if (options.memory === undefined) {
|
|
39
|
-
options.memory = new EmptyMemory();
|
|
40
|
-
}
|
|
41
|
-
if (options.client.task !== task) {
|
|
42
|
-
throw new Error('client not setup for given task');
|
|
55
|
+
client = clients.getClient(scheme, url, task, aggregator);
|
|
43
56
|
}
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
57
|
+
if (client.task !== task)
|
|
58
|
+
throw new Error("client not setup for given task");
|
|
59
|
+
let model;
|
|
60
|
+
const memoryInfo = {
|
|
61
|
+
type: "working",
|
|
62
|
+
taskID: task.id,
|
|
63
|
+
name: task.trainingInformation.modelID,
|
|
64
|
+
tensorBackend: task.trainingInformation.tensorBackend,
|
|
65
|
+
};
|
|
66
|
+
if (await memory.contains(memoryInfo))
|
|
67
|
+
model = await memory.getModel(memoryInfo);
|
|
68
|
+
else
|
|
69
|
+
model = await client.getLatestModel();
|
|
70
|
+
return new Disco(new Trainer(task, model, client), task, client, memory, logger);
|
|
50
71
|
}
|
|
51
72
|
/** Train on dataset, yielding logs of every round. */
|
|
52
73
|
async *trainByRound(dataTuple) {
|
|
@@ -83,27 +104,24 @@ export class Disco {
|
|
|
83
104
|
;
|
|
84
105
|
}
|
|
85
106
|
/**
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
// TODO RoundLogs should contain number of participants but Trainer doesn't need client
|
|
107
|
+
* Train on dataset, yield the nested steps.
|
|
108
|
+
*
|
|
109
|
+
* Don't forget to await the yielded generator otherwise nothing will progress.
|
|
110
|
+
* If you don't care about the whole process, use one of the other train methods.
|
|
111
|
+
**/
|
|
92
112
|
async *train(dataTuple) {
|
|
93
|
-
this
|
|
113
|
+
this.#logger.success("Training started.");
|
|
94
114
|
const trainData = dataTuple.train.preprocess().batch();
|
|
95
115
|
const validationData = dataTuple.validation?.preprocess().batch() ?? trainData;
|
|
96
|
-
await this
|
|
97
|
-
const
|
|
98
|
-
for await (const [round, epochs] of enumerate(trainer.fitModel(trainData.dataset, validationData.dataset))) {
|
|
116
|
+
await this.#client.connect();
|
|
117
|
+
for await (const [round, epochs] of enumerate(this.trainer.train(trainData.dataset, validationData.dataset))) {
|
|
99
118
|
yield async function* () {
|
|
100
|
-
|
|
101
|
-
for await (const [epoch, batches] of enumerate(
|
|
119
|
+
const [gen, returnedRoundLogs] = split(epochs);
|
|
120
|
+
for await (const [epoch, batches] of enumerate(gen)) {
|
|
102
121
|
const [gen, returnedEpochLogs] = split(batches);
|
|
103
122
|
yield gen;
|
|
104
123
|
const epochLogs = await returnedEpochLogs;
|
|
105
|
-
|
|
106
|
-
this.logger.success([
|
|
124
|
+
this.#logger.success([
|
|
107
125
|
`Round: ${round}`,
|
|
108
126
|
` Epoch: ${epoch}`,
|
|
109
127
|
` Training loss: ${epochLogs.training.loss}`,
|
|
@@ -116,26 +134,23 @@ export class Disco {
|
|
|
116
134
|
: "",
|
|
117
135
|
].join("\n"));
|
|
118
136
|
}
|
|
119
|
-
return
|
|
120
|
-
epochs: epochsLogs,
|
|
121
|
-
participants: this.client.nbOfParticipants, // already includes ourselves
|
|
122
|
-
};
|
|
137
|
+
return await returnedRoundLogs;
|
|
123
138
|
}.bind(this)();
|
|
139
|
+
await this.#updateWorkingModel(this.trainer.model);
|
|
124
140
|
}
|
|
125
|
-
this
|
|
141
|
+
this.#logger.success("Training finished.");
|
|
126
142
|
}
|
|
127
143
|
/**
|
|
128
144
|
* Stops the ongoing training instance without disconnecting the client.
|
|
129
145
|
*/
|
|
130
146
|
async pause() {
|
|
131
|
-
|
|
132
|
-
await trainer.stopTraining();
|
|
147
|
+
await this.trainer.stopTraining();
|
|
133
148
|
}
|
|
134
149
|
/**
|
|
135
150
|
* Completely stops the ongoing training instance.
|
|
136
151
|
*/
|
|
137
152
|
async close() {
|
|
138
153
|
await this.pause();
|
|
139
|
-
await this
|
|
154
|
+
await this.#client.disconnect();
|
|
140
155
|
}
|
|
141
156
|
}
|
package/dist/training/index.d.ts
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
export { Disco } from './disco.js';
|
|
2
|
-
export { RoundLogs } from './trainer
|
|
2
|
+
export { RoundLogs, Trainer } from './trainer.js';
|
package/dist/training/index.js
CHANGED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { List } from "immutable";
|
|
3
|
+
import type { BatchLogs, EpochLogs, Model, Task } from "../index.js";
|
|
4
|
+
import { Client } from "../client/index.js";
|
|
5
|
+
export interface RoundLogs {
|
|
6
|
+
epochs: List<EpochLogs>;
|
|
7
|
+
participants: number;
|
|
8
|
+
}
|
|
9
|
+
/** Train a model and exchange with others **/
|
|
10
|
+
export declare class Trainer {
|
|
11
|
+
#private;
|
|
12
|
+
readonly model: Model;
|
|
13
|
+
constructor(task: Task, model: Model, client: Client);
|
|
14
|
+
stopTraining(): Promise<void>;
|
|
15
|
+
train(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
|
|
16
|
+
}
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { List } from "immutable";
|
|
3
|
+
import { privacy } from "../index.js";
|
|
4
|
+
import * as async_iterator from "../utils/async_iterator.js";
|
|
5
|
+
/** Train a model and exchange with others **/
|
|
6
|
+
export class Trainer {
|
|
7
|
+
model;
|
|
8
|
+
#client;
|
|
9
|
+
#roundDuration;
|
|
10
|
+
#epochs;
|
|
11
|
+
#privacy;
|
|
12
|
+
#training;
|
|
13
|
+
constructor(task, model, client) {
|
|
14
|
+
this.model = model;
|
|
15
|
+
this.#client = client;
|
|
16
|
+
this.#roundDuration = task.trainingInformation.roundDuration;
|
|
17
|
+
this.#epochs = task.trainingInformation.epochs;
|
|
18
|
+
this.#privacy = task.trainingInformation.privacy;
|
|
19
|
+
if (!Number.isInteger(this.#epochs / this.#roundDuration))
|
|
20
|
+
throw new Error(`round duration ${this.#roundDuration} doesn't divide number of epochs ${this.#epochs}`);
|
|
21
|
+
}
|
|
22
|
+
async stopTraining() {
|
|
23
|
+
await this.#training?.return();
|
|
24
|
+
}
|
|
25
|
+
async *train(dataset, valDataset) {
|
|
26
|
+
if (this.#training !== undefined)
|
|
27
|
+
throw new Error("training already running, stop it before launching a new one");
|
|
28
|
+
try {
|
|
29
|
+
this.#training = this.#runRounds(dataset, valDataset);
|
|
30
|
+
yield* this.#training;
|
|
31
|
+
}
|
|
32
|
+
finally {
|
|
33
|
+
this.#training = undefined;
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
async *#runRounds(dataset, valDataset) {
|
|
37
|
+
const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
|
|
38
|
+
let previousRoundWeights;
|
|
39
|
+
for (let round = 0; round < totalRound; round++) {
|
|
40
|
+
await this.#client.onRoundBeginCommunication(this.model.weights, round);
|
|
41
|
+
yield this.#runRound(dataset, valDataset);
|
|
42
|
+
let localWeights = this.model.weights;
|
|
43
|
+
if (this.#privacy !== undefined)
|
|
44
|
+
localWeights = await applyPrivacy(previousRoundWeights, localWeights, this.#privacy);
|
|
45
|
+
const networkWeights = await this.#client.onRoundEndCommunication(localWeights, round);
|
|
46
|
+
this.model.weights = previousRoundWeights = networkWeights;
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
async *#runRound(dataset, valDataset) {
|
|
50
|
+
let epochsLogs = List();
|
|
51
|
+
for (let epoch = 0; epoch < this.#roundDuration; epoch++) {
|
|
52
|
+
const [gen, epochLogs] = async_iterator.split(this.model.train(dataset, valDataset));
|
|
53
|
+
yield gen;
|
|
54
|
+
epochsLogs = epochsLogs.push(await epochLogs);
|
|
55
|
+
}
|
|
56
|
+
return {
|
|
57
|
+
epochs: epochsLogs,
|
|
58
|
+
participants: this.#client.nbOfParticipants,
|
|
59
|
+
};
|
|
60
|
+
}
|
|
61
|
+
}
|
|
62
|
+
async function applyPrivacy(previous, current, options) {
|
|
63
|
+
let ret = current;
|
|
64
|
+
if (options.clippingRadius !== undefined) {
|
|
65
|
+
const previousRoundWeights = previous ?? current.map((w) => tf.zerosLike(w));
|
|
66
|
+
const weightsProgress = current.sub(previousRoundWeights);
|
|
67
|
+
ret = previousRoundWeights.add(await privacy.clipNorm(weightsProgress, options.clippingRadius));
|
|
68
|
+
}
|
|
69
|
+
if (options.noiseScale !== undefined)
|
|
70
|
+
ret = privacy.addNoise(ret, options.noiseScale);
|
|
71
|
+
return ret;
|
|
72
|
+
}
|
|
@@ -51,11 +51,6 @@ export declare class WeightsContainer {
|
|
|
51
51
|
* @returns The tensor located at the index
|
|
52
52
|
*/
|
|
53
53
|
get(index: number): tf.Tensor | undefined;
|
|
54
|
-
/**
|
|
55
|
-
* Computes the weights container's Frobenius norm
|
|
56
|
-
* @returns The Frobenius norm
|
|
57
|
-
*/
|
|
58
|
-
frobeniusNorm(): number;
|
|
59
54
|
concat(other: WeightsContainer): WeightsContainer;
|
|
60
55
|
equals(other: WeightsContainer, margin?: number): boolean;
|
|
61
56
|
/**
|
|
@@ -70,13 +70,6 @@ export class WeightsContainer {
|
|
|
70
70
|
get(index) {
|
|
71
71
|
return this._weights.get(index);
|
|
72
72
|
}
|
|
73
|
-
/**
|
|
74
|
-
* Computes the weights container's Frobenius norm
|
|
75
|
-
* @returns The Frobenius norm
|
|
76
|
-
*/
|
|
77
|
-
frobeniusNorm() {
|
|
78
|
-
return Math.sqrt(this.map((w) => w.square().sum()).reduce((a, b) => a.add(b)).dataSync()[0]);
|
|
79
|
-
}
|
|
80
73
|
concat(other) {
|
|
81
74
|
return WeightsContainer.of(...this.weights, ...other.weights);
|
|
82
75
|
}
|
package/package.json
CHANGED
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
import type { AggregatorBase } from './aggregator/index.js';
|
|
2
|
-
export declare class AsyncInformant<T> {
|
|
3
|
-
private readonly aggregator;
|
|
4
|
-
private _round;
|
|
5
|
-
private _currentNumberOfParticipants;
|
|
6
|
-
private _totalNumberOfParticipants;
|
|
7
|
-
private _averageNumberOfParticipants;
|
|
8
|
-
constructor(aggregator: AggregatorBase<T>);
|
|
9
|
-
update(): void;
|
|
10
|
-
get round(): number;
|
|
11
|
-
get currentNumberOfParticipants(): number;
|
|
12
|
-
get totalNumberOfParticipants(): number;
|
|
13
|
-
get averageNumberOfParticipants(): number;
|
|
14
|
-
getAllStatistics(): Record<'round' | 'currentNumberOfParticipants' | 'totalNumberOfParticipants' | 'averageNumberOfParticipants', number>;
|
|
15
|
-
}
|
package/dist/async_informant.js
DELETED
|
@@ -1,42 +0,0 @@
|
|
|
1
|
-
export class AsyncInformant {
|
|
2
|
-
aggregator;
|
|
3
|
-
_round = 0;
|
|
4
|
-
_currentNumberOfParticipants = 0;
|
|
5
|
-
_totalNumberOfParticipants = 0;
|
|
6
|
-
_averageNumberOfParticipants = 0;
|
|
7
|
-
constructor(aggregator) {
|
|
8
|
-
this.aggregator = aggregator;
|
|
9
|
-
}
|
|
10
|
-
update() {
|
|
11
|
-
if (this.round === 0 || this.round < this.aggregator.round) {
|
|
12
|
-
this._round = this.aggregator.round;
|
|
13
|
-
this._currentNumberOfParticipants = this.aggregator.size;
|
|
14
|
-
this._averageNumberOfParticipants = this.totalNumberOfParticipants / this.round;
|
|
15
|
-
this._totalNumberOfParticipants += this.currentNumberOfParticipants;
|
|
16
|
-
}
|
|
17
|
-
else {
|
|
18
|
-
this._round = this.aggregator.round;
|
|
19
|
-
}
|
|
20
|
-
}
|
|
21
|
-
// Getter functions
|
|
22
|
-
get round() {
|
|
23
|
-
return this._round;
|
|
24
|
-
}
|
|
25
|
-
get currentNumberOfParticipants() {
|
|
26
|
-
return this._currentNumberOfParticipants;
|
|
27
|
-
}
|
|
28
|
-
get totalNumberOfParticipants() {
|
|
29
|
-
return this._totalNumberOfParticipants;
|
|
30
|
-
}
|
|
31
|
-
get averageNumberOfParticipants() {
|
|
32
|
-
return this._averageNumberOfParticipants;
|
|
33
|
-
}
|
|
34
|
-
getAllStatistics() {
|
|
35
|
-
return {
|
|
36
|
-
round: this.round,
|
|
37
|
-
currentNumberOfParticipants: this.currentNumberOfParticipants,
|
|
38
|
-
totalNumberOfParticipants: this.totalNumberOfParticipants,
|
|
39
|
-
averageNumberOfParticipants: this.averageNumberOfParticipants
|
|
40
|
-
};
|
|
41
|
-
}
|
|
42
|
-
}
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
import type { Model, Memory, Task, client as clients } from "../../index.js";
|
|
2
|
-
import { Trainer } from "./trainer.js";
|
|
3
|
-
/**
|
|
4
|
-
* Class whose role is to train a model in a distributed way with a given dataset.
|
|
5
|
-
*/
|
|
6
|
-
export declare class DistributedTrainer extends Trainer {
|
|
7
|
-
private readonly task;
|
|
8
|
-
private readonly memory;
|
|
9
|
-
private readonly client;
|
|
10
|
-
private readonly aggregator;
|
|
11
|
-
/**
|
|
12
|
-
* DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights.
|
|
13
|
-
*/
|
|
14
|
-
constructor(task: Task, memory: Memory, model: Model, client: clients.Client);
|
|
15
|
-
onRoundBegin(round: number): Promise<void>;
|
|
16
|
-
/**
|
|
17
|
-
* Callback called every time a round is over
|
|
18
|
-
*/
|
|
19
|
-
onRoundEnd(round: number): Promise<void>;
|
|
20
|
-
}
|