@epfml/discojs 2.2.2-p20240703101552.0 → 3.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +9 -48
- package/dist/aggregator/base.js +8 -69
- package/dist/aggregator/get.d.ts +23 -11
- package/dist/aggregator/get.js +40 -23
- package/dist/aggregator/index.d.ts +1 -1
- package/dist/aggregator/index.js +1 -1
- package/dist/aggregator/mean.d.ts +25 -6
- package/dist/aggregator/mean.js +62 -17
- package/dist/aggregator/secure.d.ts +2 -2
- package/dist/aggregator/secure.js +4 -7
- package/dist/client/base.d.ts +3 -3
- package/dist/client/base.js +6 -8
- package/dist/client/decentralized/base.d.ts +27 -10
- package/dist/client/decentralized/base.js +123 -86
- package/dist/client/decentralized/peer.js +7 -12
- package/dist/client/decentralized/peer_pool.js +6 -2
- package/dist/client/event_connection.d.ts +1 -1
- package/dist/client/event_connection.js +3 -3
- package/dist/client/federated/base.d.ts +5 -21
- package/dist/client/federated/base.js +38 -61
- package/dist/client/federated/messages.d.ts +2 -10
- package/dist/client/federated/messages.js +0 -1
- package/dist/client/index.d.ts +1 -1
- package/dist/client/index.js +1 -1
- package/dist/client/local.d.ts +3 -1
- package/dist/client/local.js +4 -1
- package/dist/client/messages.d.ts +1 -2
- package/dist/client/messages.js +8 -3
- package/dist/client/utils.d.ts +4 -2
- package/dist/client/utils.js +18 -3
- package/dist/dataset/data/data.d.ts +1 -1
- package/dist/dataset/data/data.js +13 -2
- package/dist/dataset/data/preprocessing/image_preprocessing.js +6 -4
- package/dist/default_tasks/cifar10.js +1 -2
- package/dist/default_tasks/lus_covid.js +0 -5
- package/dist/default_tasks/mnist.js +15 -14
- package/dist/default_tasks/simple_face.js +0 -2
- package/dist/default_tasks/titanic.js +2 -4
- package/dist/default_tasks/wikitext.js +7 -1
- package/dist/index.d.ts +0 -1
- package/dist/index.js +0 -1
- package/dist/models/gpt/config.js +1 -1
- package/dist/privacy.d.ts +8 -10
- package/dist/privacy.js +25 -40
- package/dist/task/task_handler.js +10 -2
- package/dist/task/training_information.d.ts +7 -4
- package/dist/task/training_information.js +25 -6
- package/dist/training/disco.d.ts +30 -28
- package/dist/training/disco.js +75 -73
- package/dist/training/index.d.ts +1 -1
- package/dist/training/index.js +1 -0
- package/dist/training/trainer.d.ts +16 -0
- package/dist/training/trainer.js +72 -0
- package/dist/types.d.ts +0 -2
- package/dist/weights/weights_container.d.ts +0 -5
- package/dist/weights/weights_container.js +0 -7
- package/package.json +1 -1
- package/dist/async_informant.d.ts +0 -15
- package/dist/async_informant.js +0 -42
- package/dist/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/training/trainer/distributed_trainer.js +0 -41
- package/dist/training/trainer/local_trainer.d.ts +0 -12
- package/dist/training/trainer/local_trainer.js +0 -24
- package/dist/training/trainer/trainer.d.ts +0 -32
- package/dist/training/trainer/trainer.js +0 -61
- package/dist/training/trainer/trainer_builder.d.ts +0 -23
- package/dist/training/trainer/trainer_builder.js +0 -47
package/dist/client/index.js
CHANGED
|
@@ -4,5 +4,5 @@ export * as aggregator from '../aggregator/index.js';
|
|
|
4
4
|
export * as decentralized from './decentralized/index.js';
|
|
5
5
|
export * as federated from './federated/index.js';
|
|
6
6
|
export * as messages from './messages.js';
|
|
7
|
-
export
|
|
7
|
+
export { getClient, timeout } from './utils.js';
|
|
8
8
|
export { Local } from './local.js';
|
package/dist/client/local.d.ts
CHANGED
package/dist/client/local.js
CHANGED
package/dist/client/messages.js
CHANGED
|
@@ -1,16 +1,21 @@
|
|
|
1
1
|
export var type;
|
|
2
2
|
(function (type) {
|
|
3
|
+
// Sent from client to server as first point of contact to join a task.
|
|
4
|
+
// The server answers with a node id in an AssignNodeID message
|
|
3
5
|
type[type["ClientConnected"] = 0] = "ClientConnected";
|
|
6
|
+
// When a user joins a task with a ClientConnected message, the server
|
|
7
|
+
// answers with an AssignNodeID message with its peer id.
|
|
4
8
|
type[type["AssignNodeID"] = 1] = "AssignNodeID";
|
|
5
|
-
|
|
9
|
+
/* Decentralized */
|
|
10
|
+
// Message forwarded by the server from a client to another client
|
|
11
|
+
// to establish a peer-to-peer (WebRTC) connection
|
|
6
12
|
type[type["SignalForPeer"] = 2] = "SignalForPeer";
|
|
7
13
|
type[type["PeerIsReady"] = 3] = "PeerIsReady";
|
|
8
14
|
type[type["PeersForRound"] = 4] = "PeersForRound";
|
|
9
15
|
type[type["Payload"] = 5] = "Payload";
|
|
10
16
|
// Federated
|
|
11
17
|
type[type["SendPayload"] = 6] = "SendPayload";
|
|
12
|
-
type[type["
|
|
13
|
-
type[type["ReceiveServerPayload"] = 8] = "ReceiveServerPayload";
|
|
18
|
+
type[type["ReceiveServerPayload"] = 7] = "ReceiveServerPayload";
|
|
14
19
|
})(type || (type = {}));
|
|
15
20
|
export function hasMessageType(raw) {
|
|
16
21
|
if (typeof raw !== 'object' || raw === null) {
|
package/dist/client/utils.d.ts
CHANGED
|
@@ -1,2 +1,4 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
import type { Task } from '../index.js';
|
|
2
|
+
import { client as clients, type aggregator } from '../index.js';
|
|
3
|
+
export declare function timeout(ms?: number, errorMsg?: string): Promise<never>;
|
|
4
|
+
export declare function getClient(trainingScheme: Required<Task['trainingInformation']['scheme']>, serverURL: URL, task: Task, aggregator: aggregator.Aggregator): clients.Client;
|
package/dist/client/utils.js
CHANGED
|
@@ -1,7 +1,22 @@
|
|
|
1
|
+
import { client as clients } from '../index.js';
|
|
1
2
|
// Time to wait for the others in milliseconds.
|
|
2
|
-
|
|
3
|
-
export async function timeout(ms = MAX_WAIT_PER_ROUND) {
|
|
3
|
+
const MAX_WAIT_PER_ROUND = 15_000;
|
|
4
|
+
export async function timeout(ms = MAX_WAIT_PER_ROUND, errorMsg = 'timeout') {
|
|
4
5
|
return await new Promise((_, reject) => {
|
|
5
|
-
setTimeout(() => { reject(new Error(
|
|
6
|
+
setTimeout(() => { reject(new Error(errorMsg)); }, ms);
|
|
6
7
|
});
|
|
7
8
|
}
|
|
9
|
+
export function getClient(trainingScheme, serverURL, task, aggregator) {
|
|
10
|
+
switch (trainingScheme) {
|
|
11
|
+
case 'decentralized':
|
|
12
|
+
return new clients.decentralized.DecentralizedClient(serverURL, task, aggregator);
|
|
13
|
+
case 'federated':
|
|
14
|
+
return new clients.federated.FederatedClient(serverURL, task, aggregator);
|
|
15
|
+
case 'local':
|
|
16
|
+
return new clients.Local(serverURL, task, aggregator);
|
|
17
|
+
default: {
|
|
18
|
+
const _ = trainingScheme;
|
|
19
|
+
throw new Error('should never happen');
|
|
20
|
+
}
|
|
21
|
+
}
|
|
22
|
+
}
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
1
2
|
/**
|
|
2
3
|
* Abstract class representing an immutable Disco dataset, including a TF.js dataset,
|
|
3
4
|
* Disco task and set of preprocessing functions.
|
|
@@ -59,8 +60,18 @@ export class Data {
|
|
|
59
60
|
if (applyPreprocessing.size === 0) {
|
|
60
61
|
return x => Promise.resolve(x);
|
|
61
62
|
}
|
|
62
|
-
const preprocessingChain =
|
|
63
|
-
|
|
63
|
+
const preprocessingChain = async (input) => {
|
|
64
|
+
let currentContainer = await input; // Start with the initial tensor container
|
|
65
|
+
for (const fn of applyPreprocessing) {
|
|
66
|
+
const newContainer = await fn(Promise.resolve(currentContainer), this.task);
|
|
67
|
+
if (currentContainer !== newContainer) {
|
|
68
|
+
tf.dispose(currentContainer); // Dispose of the old container
|
|
69
|
+
}
|
|
70
|
+
currentContainer = newContainer;
|
|
71
|
+
}
|
|
72
|
+
return currentContainer; // Return the final tensor container
|
|
73
|
+
};
|
|
74
|
+
return async (entry) => await preprocessingChain(Promise.resolve(entry));
|
|
64
75
|
}
|
|
65
76
|
/**
|
|
66
77
|
* The TF.js dataset preprocessing according to the set of preprocessing functions and the task's
|
|
@@ -25,10 +25,12 @@ const normalize = {
|
|
|
25
25
|
type: ImagePreprocessing.Normalize,
|
|
26
26
|
apply: async (entry) => {
|
|
27
27
|
const { xs, ys } = await entry;
|
|
28
|
-
return {
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
28
|
+
return tf.tidy(() => {
|
|
29
|
+
return {
|
|
30
|
+
xs: xs.div(tf.scalar(255)),
|
|
31
|
+
ys
|
|
32
|
+
};
|
|
33
|
+
});
|
|
32
34
|
}
|
|
33
35
|
};
|
|
34
36
|
/**
|
|
@@ -30,8 +30,7 @@ export const cifar10 = {
|
|
|
30
30
|
IMAGE_W: 224,
|
|
31
31
|
LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
|
|
32
32
|
scheme: 'decentralized',
|
|
33
|
-
noiseScale:
|
|
34
|
-
clippingRadius: 20,
|
|
33
|
+
privacy: { clippingRadius: 20, noiseScale: 1 },
|
|
35
34
|
decentralizedSecure: true,
|
|
36
35
|
minimumReadyPeers: 3,
|
|
37
36
|
maxShareValue: 100,
|
|
@@ -29,11 +29,6 @@ export const lusCovid = {
|
|
|
29
29
|
LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
|
|
30
30
|
dataType: 'image',
|
|
31
31
|
scheme: 'federated',
|
|
32
|
-
noiseScale: undefined,
|
|
33
|
-
clippingRadius: 20,
|
|
34
|
-
decentralizedSecure: true,
|
|
35
|
-
minimumReadyPeers: 2,
|
|
36
|
-
maxShareValue: 100,
|
|
37
32
|
tensorBackend: 'tfjs'
|
|
38
33
|
}
|
|
39
34
|
};
|
|
@@ -20,17 +20,16 @@ export const mnist = {
|
|
|
20
20
|
trainingInformation: {
|
|
21
21
|
modelID: 'mnist-model',
|
|
22
22
|
epochs: 20,
|
|
23
|
-
roundDuration:
|
|
23
|
+
roundDuration: 2,
|
|
24
24
|
validationSplit: 0.2,
|
|
25
|
-
batchSize:
|
|
25
|
+
batchSize: 64,
|
|
26
26
|
dataType: 'image',
|
|
27
27
|
IMAGE_H: 28,
|
|
28
28
|
IMAGE_W: 28,
|
|
29
|
-
|
|
29
|
+
// Images should already be at the right size but resizing just in case
|
|
30
|
+
preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
|
|
30
31
|
LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
|
|
31
32
|
scheme: 'decentralized',
|
|
32
|
-
noiseScale: undefined,
|
|
33
|
-
clippingRadius: 20,
|
|
34
33
|
decentralizedSecure: true,
|
|
35
34
|
minimumReadyPeers: 3,
|
|
36
35
|
maxShareValue: 100,
|
|
@@ -39,22 +38,24 @@ export const mnist = {
|
|
|
39
38
|
};
|
|
40
39
|
},
|
|
41
40
|
getModel() {
|
|
41
|
+
// Architecture from the PyTorch MNIST example (I made it slightly smaller, 650kB instead of 5MB)
|
|
42
|
+
// https://github.com/pytorch/examples/blob/main/mnist/main.py
|
|
42
43
|
const model = tf.sequential();
|
|
43
44
|
model.add(tf.layers.conv2d({
|
|
44
45
|
inputShape: [28, 28, 3],
|
|
45
|
-
kernelSize:
|
|
46
|
-
filters:
|
|
47
|
-
activation: 'relu'
|
|
46
|
+
kernelSize: 5,
|
|
47
|
+
filters: 8,
|
|
48
|
+
activation: 'relu',
|
|
48
49
|
}));
|
|
50
|
+
model.add(tf.layers.conv2d({ kernelSize: 5, filters: 16, activation: 'relu' }));
|
|
49
51
|
model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
|
|
50
|
-
model.add(tf.layers.
|
|
51
|
-
model.add(tf.layers.
|
|
52
|
-
model.add(tf.layers.
|
|
53
|
-
model.add(tf.layers.
|
|
54
|
-
model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
|
|
52
|
+
model.add(tf.layers.dropout({ rate: 0.25 }));
|
|
53
|
+
model.add(tf.layers.flatten());
|
|
54
|
+
model.add(tf.layers.dense({ units: 32, activation: 'relu' }));
|
|
55
|
+
model.add(tf.layers.dropout({ rate: 0.25 }));
|
|
55
56
|
model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
|
|
56
57
|
model.compile({
|
|
57
|
-
optimizer: '
|
|
58
|
+
optimizer: 'adam',
|
|
58
59
|
loss: 'categoricalCrossentropy',
|
|
59
60
|
metrics: ['accuracy']
|
|
60
61
|
});
|
|
@@ -46,8 +46,8 @@ export const titanic = {
|
|
|
46
46
|
},
|
|
47
47
|
trainingInformation: {
|
|
48
48
|
modelID: 'titanic-model',
|
|
49
|
-
epochs:
|
|
50
|
-
roundDuration:
|
|
49
|
+
epochs: 10,
|
|
50
|
+
roundDuration: 2,
|
|
51
51
|
validationSplit: 0.2,
|
|
52
52
|
batchSize: 30,
|
|
53
53
|
preprocessingFunctions: [data.TabularPreprocessing.Sanitize],
|
|
@@ -63,8 +63,6 @@ export const titanic = {
|
|
|
63
63
|
'Survived'
|
|
64
64
|
],
|
|
65
65
|
scheme: 'federated', // secure aggregation not yet implemented for FeAI
|
|
66
|
-
noiseScale: undefined,
|
|
67
|
-
clippingRadius: undefined,
|
|
68
66
|
tensorBackend: 'tfjs'
|
|
69
67
|
}
|
|
70
68
|
};
|
|
@@ -9,7 +9,13 @@ export const wikitext = {
|
|
|
9
9
|
preview: 'Train a language model (L)LM in your browser, collaboratively and from scratch.',
|
|
10
10
|
overview: "You can train a GPT-2 model in your browser and in a collaborative manner on any textual dataset. As an example, you can try the Wikitext-103 dataset, composed of Wikipedia articles, widely used in natural language modeling, which you can download <a class='underline text-blue-400' target='_blank' href='https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz'>here</a>. More information on how to connect the dataset at the next step."
|
|
11
11
|
},
|
|
12
|
-
model:
|
|
12
|
+
model: [
|
|
13
|
+
"The model follows the exact GPT-2 architecture and is implemented in TensorFlow.js.",
|
|
14
|
+
"The tokenizer used for preprocessing is the GPT-2 Byte-Pair encoding tokenizer.",
|
|
15
|
+
"The model is trained via an Adam optimizer with unit gradient clipping and softmax cross-entropy loss.",
|
|
16
|
+
"It has around 5M parameters.",
|
|
17
|
+
"To accommodate all devices, the context length is currently kept at 128 and the batch size at 1.",
|
|
18
|
+
].join(" "),
|
|
13
19
|
dataFormatInformation: 'You can use any natural language (text) dataset you like. For example the Wikitext-103 dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.',
|
|
14
20
|
dataExampleText: 'An example excerpt from the dataset is: <i>"For the first twenty years of its existence , the only staged performances of Parsifal took place in the Bayreuth Festspielhaus , the venue for which Wagner conceived the work ( except eight private performances for Ludwig II at Munich in 1884 and 1885 ) ."</i>',
|
|
15
21
|
sampleDatasetLink: 'https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz',
|
package/dist/index.d.ts
CHANGED
|
@@ -5,7 +5,6 @@ export * as privacy from './privacy.js';
|
|
|
5
5
|
export * as client from './client/index.js';
|
|
6
6
|
export * as aggregator from './aggregator/index.js';
|
|
7
7
|
export { WeightsContainer, aggregation } from './weights/index.js';
|
|
8
|
-
export { AsyncInformant } from './async_informant.js';
|
|
9
8
|
export { Logger, ConsoleLogger } from './logging/index.js';
|
|
10
9
|
export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
|
|
11
10
|
export { Disco, RoundLogs } from './training/index.js';
|
package/dist/index.js
CHANGED
|
@@ -5,7 +5,6 @@ export * as privacy from './privacy.js';
|
|
|
5
5
|
export * as client from './client/index.js';
|
|
6
6
|
export * as aggregator from './aggregator/index.js';
|
|
7
7
|
export { WeightsContainer, aggregation } from './weights/index.js';
|
|
8
|
-
export { AsyncInformant } from './async_informant.js';
|
|
9
8
|
export { ConsoleLogger } from './logging/index.js';
|
|
10
9
|
export { Memory, Empty as EmptyMemory } from './memory/index.js';
|
|
11
10
|
export { Disco } from './training/index.js';
|
package/dist/privacy.d.ts
CHANGED
|
@@ -1,11 +1,9 @@
|
|
|
1
|
-
import type {
|
|
1
|
+
import type { WeightsContainer } from "./index.js";
|
|
2
|
+
/** Scramble weights */
|
|
3
|
+
export declare function addNoise(weights: WeightsContainer, deviation: number): WeightsContainer;
|
|
2
4
|
/**
|
|
3
|
-
*
|
|
4
|
-
*
|
|
5
|
-
*
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
* @param task the task
|
|
9
|
-
* @returns the noised weights for the current round
|
|
10
|
-
*/
|
|
11
|
-
export declare function addDifferentialPrivacy(updatedWeights: WeightsContainer, staleWeights: WeightsContainer, task: Task): WeightsContainer;
|
|
5
|
+
* Keep weights' norm within radius
|
|
6
|
+
*
|
|
7
|
+
* @param radius maximum norm
|
|
8
|
+
**/
|
|
9
|
+
export declare function clipNorm(weights: WeightsContainer, radius: number): Promise<WeightsContainer>;
|
package/dist/privacy.js
CHANGED
|
@@ -1,42 +1,27 @@
|
|
|
1
|
-
import * as tf from
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
async function frobeniusNorm(weights) {
|
|
3
|
+
const squared = await weights
|
|
4
|
+
.map((w) => w.square().sum())
|
|
5
|
+
.reduce((a, b) => a.add(b))
|
|
6
|
+
.data();
|
|
7
|
+
if (squared.length !== 1)
|
|
8
|
+
throw new Error("unexcepted weights shape");
|
|
9
|
+
return Math.sqrt(squared[0]);
|
|
10
|
+
}
|
|
11
|
+
/** Scramble weights */
|
|
12
|
+
export function addNoise(weights, deviation) {
|
|
13
|
+
const variance = Math.pow(deviation, 2);
|
|
14
|
+
return weights.map((w) => w.add(tf.randomNormal(w.shape, 0, variance)));
|
|
15
|
+
}
|
|
2
16
|
/**
|
|
3
|
-
*
|
|
4
|
-
*
|
|
5
|
-
*
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
const clippingRadius = task.trainingInformation?.clippingRadius;
|
|
14
|
-
const weightsDiff = updatedWeights.sub(staleWeights);
|
|
15
|
-
let newWeightsDiff;
|
|
16
|
-
if (clippingRadius !== undefined) {
|
|
17
|
-
// Frobenius norm
|
|
18
|
-
const norm = weightsDiff.frobeniusNorm();
|
|
19
|
-
newWeightsDiff = weightsDiff.map((w) => {
|
|
20
|
-
const clipped = w.div(Math.max(1, norm / clippingRadius));
|
|
21
|
-
if (noiseScale !== undefined) {
|
|
22
|
-
// Add clipping and noise
|
|
23
|
-
const noise = tf.randomNormal(w.shape, 0, (noiseScale * noiseScale) * (clippingRadius * clippingRadius));
|
|
24
|
-
return clipped.add(noise);
|
|
25
|
-
}
|
|
26
|
-
else {
|
|
27
|
-
// Add clipping without any noise
|
|
28
|
-
return clipped;
|
|
29
|
-
}
|
|
30
|
-
});
|
|
31
|
-
}
|
|
32
|
-
else {
|
|
33
|
-
if (noiseScale !== undefined) {
|
|
34
|
-
// Add noise without any clipping
|
|
35
|
-
newWeightsDiff = weightsDiff.map((w) => tf.randomNormal(w.shape, 0, (noiseScale * noiseScale)));
|
|
36
|
-
}
|
|
37
|
-
else {
|
|
38
|
-
return updatedWeights;
|
|
39
|
-
}
|
|
40
|
-
}
|
|
41
|
-
return staleWeights.add(newWeightsDiff);
|
|
17
|
+
* Keep weights' norm within radius
|
|
18
|
+
*
|
|
19
|
+
* @param radius maximum norm
|
|
20
|
+
**/
|
|
21
|
+
export async function clipNorm(weights, radius) {
|
|
22
|
+
if (radius <= 0)
|
|
23
|
+
throw new Error("invalid radius");
|
|
24
|
+
const norm = await frobeniusNorm(weights);
|
|
25
|
+
const scaling = Math.max(1, norm / radius);
|
|
26
|
+
return weights.map((w) => w.div(scaling));
|
|
42
27
|
}
|
|
@@ -13,8 +13,16 @@ export async function pushTask(url, task, model) {
|
|
|
13
13
|
export async function fetchTasks(url) {
|
|
14
14
|
const response = await axios.get(new URL(TASK_ENDPOINT, url).href);
|
|
15
15
|
const tasks = response.data;
|
|
16
|
-
if (!
|
|
17
|
-
throw new Error('
|
|
16
|
+
if (!Array.isArray(tasks)) {
|
|
17
|
+
throw new Error('Expected to receive an array of Tasks when fetching tasks');
|
|
18
|
+
}
|
|
19
|
+
else if (!tasks.every(isTask)) {
|
|
20
|
+
for (const task of tasks) {
|
|
21
|
+
if (!isTask(task)) {
|
|
22
|
+
console.error("task has invalid format:", task);
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
throw new Error('invalid tasks response, the task object received is not well formatted');
|
|
18
26
|
}
|
|
19
27
|
return Map(tasks.map((t) => [t.id, t]));
|
|
20
28
|
}
|
|
@@ -1,6 +1,9 @@
|
|
|
1
|
-
import type { AggregatorChoice } from '../aggregator/get.js';
|
|
2
1
|
import type { Preprocessing } from '../dataset/data/preprocessing/index.js';
|
|
3
2
|
import { PreTrainedTokenizer } from '@xenova/transformers';
|
|
3
|
+
interface Privacy {
|
|
4
|
+
clippingRadius?: number;
|
|
5
|
+
noiseScale?: number;
|
|
6
|
+
}
|
|
4
7
|
export interface TrainingInformation {
|
|
5
8
|
modelID: string;
|
|
6
9
|
epochs: number;
|
|
@@ -15,14 +18,14 @@ export interface TrainingInformation {
|
|
|
15
18
|
IMAGE_W?: number;
|
|
16
19
|
LABEL_LIST?: string[];
|
|
17
20
|
scheme: 'decentralized' | 'federated' | 'local';
|
|
18
|
-
|
|
19
|
-
clippingRadius?: number;
|
|
21
|
+
privacy?: Privacy;
|
|
20
22
|
decentralizedSecure?: boolean;
|
|
21
23
|
maxShareValue?: number;
|
|
22
24
|
minimumReadyPeers?: number;
|
|
23
|
-
aggregator?:
|
|
25
|
+
aggregator?: 'mean' | 'secure';
|
|
24
26
|
tokenizer?: string | PreTrainedTokenizer;
|
|
25
27
|
maxSequenceLength?: number;
|
|
26
28
|
tensorBackend: 'tfjs' | 'gpt';
|
|
27
29
|
}
|
|
28
30
|
export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
|
|
31
|
+
export {};
|
|
@@ -6,11 +6,25 @@ function isStringArray(raw) {
|
|
|
6
6
|
const arr = raw; // isArray is unsafely guarding with any[]
|
|
7
7
|
return arr.every((e) => typeof e === 'string');
|
|
8
8
|
}
|
|
9
|
+
function isPrivacy(raw) {
|
|
10
|
+
if (typeof raw !== "object" || raw === null) {
|
|
11
|
+
return false;
|
|
12
|
+
}
|
|
13
|
+
const { clippingRadius, noiseScale, } = raw;
|
|
14
|
+
if ((clippingRadius !== undefined && typeof clippingRadius !== "number") ||
|
|
15
|
+
(noiseScale !== undefined && typeof noiseScale !== "number"))
|
|
16
|
+
return false;
|
|
17
|
+
const _ = {
|
|
18
|
+
clippingRadius,
|
|
19
|
+
noiseScale,
|
|
20
|
+
};
|
|
21
|
+
return true;
|
|
22
|
+
}
|
|
9
23
|
export function isTrainingInformation(raw) {
|
|
10
24
|
if (typeof raw !== 'object' || raw === null) {
|
|
11
25
|
return false;
|
|
12
26
|
}
|
|
13
|
-
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize,
|
|
27
|
+
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minimumReadyPeers, modelID, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
|
|
14
28
|
if (typeof dataType !== 'string' ||
|
|
15
29
|
typeof modelID !== 'string' ||
|
|
16
30
|
typeof epochs !== 'number' ||
|
|
@@ -19,12 +33,11 @@ export function isTrainingInformation(raw) {
|
|
|
19
33
|
typeof validationSplit !== 'number' ||
|
|
20
34
|
(tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
|
|
21
35
|
(maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
|
|
22
|
-
(aggregator !== undefined && typeof aggregator !== '
|
|
23
|
-
(clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
|
|
36
|
+
(aggregator !== undefined && typeof aggregator !== 'string') ||
|
|
24
37
|
(decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
|
|
38
|
+
(privacy !== undefined && !isPrivacy(privacy)) ||
|
|
25
39
|
(maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
|
|
26
40
|
(minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
|
|
27
|
-
(noiseScale !== undefined && typeof noiseScale !== 'number') ||
|
|
28
41
|
(IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
|
|
29
42
|
(IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
|
|
30
43
|
(LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
|
|
@@ -33,6 +46,13 @@ export function isTrainingInformation(raw) {
|
|
|
33
46
|
(preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
|
|
34
47
|
return false;
|
|
35
48
|
}
|
|
49
|
+
if (aggregator !== undefined) {
|
|
50
|
+
switch (aggregator) {
|
|
51
|
+
case 'mean': break;
|
|
52
|
+
case 'secure': break;
|
|
53
|
+
default: return false;
|
|
54
|
+
}
|
|
55
|
+
}
|
|
36
56
|
switch (dataType) {
|
|
37
57
|
case 'image': break;
|
|
38
58
|
case 'tabular': break;
|
|
@@ -70,15 +90,14 @@ export function isTrainingInformation(raw) {
|
|
|
70
90
|
LABEL_LIST,
|
|
71
91
|
aggregator,
|
|
72
92
|
batchSize,
|
|
73
|
-
clippingRadius,
|
|
74
93
|
dataType,
|
|
75
94
|
decentralizedSecure,
|
|
95
|
+
privacy,
|
|
76
96
|
epochs,
|
|
77
97
|
inputColumns,
|
|
78
98
|
maxShareValue,
|
|
79
99
|
minimumReadyPeers,
|
|
80
100
|
modelID,
|
|
81
|
-
noiseScale,
|
|
82
101
|
outputColumns,
|
|
83
102
|
preprocessingFunctions,
|
|
84
103
|
roundDuration,
|
package/dist/training/disco.d.ts
CHANGED
|
@@ -1,14 +1,11 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { client as clients } from
|
|
3
|
-
import type { Aggregator } from
|
|
4
|
-
import
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
scheme?: TrainingInformation['scheme'];
|
|
10
|
-
logger?: Logger;
|
|
11
|
-
memory?: Memory;
|
|
1
|
+
import { data, BatchLogs, EpochLogs, Logger, Memory, Task, TrainingInformation } from "../index.js";
|
|
2
|
+
import { client as clients } from "../index.js";
|
|
3
|
+
import type { Aggregator } from "../aggregator/index.js";
|
|
4
|
+
import { RoundLogs, Trainer } from "./trainer.js";
|
|
5
|
+
interface Config {
|
|
6
|
+
scheme: TrainingInformation["scheme"];
|
|
7
|
+
logger: Logger;
|
|
8
|
+
memory: Memory;
|
|
12
9
|
}
|
|
13
10
|
/**
|
|
14
11
|
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
@@ -16,16 +13,22 @@ export interface DiscoOptions {
|
|
|
16
13
|
* communication with nodes, logs and model memory.
|
|
17
14
|
*/
|
|
18
15
|
export declare class Disco {
|
|
19
|
-
|
|
20
|
-
readonly
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
16
|
+
#private;
|
|
17
|
+
readonly trainer: Trainer;
|
|
18
|
+
private constructor();
|
|
19
|
+
/**
|
|
20
|
+
* Connect to the given task and get ready to train.
|
|
21
|
+
*
|
|
22
|
+
* Will load the model from memory if available or fetch it from the server.
|
|
23
|
+
*
|
|
24
|
+
* @param clientConfig client to connect with or parameters on how to create one.
|
|
25
|
+
**/
|
|
26
|
+
static fromTask(task: Task, clientConfig: clients.Client | URL | {
|
|
27
|
+
aggregator: Aggregator;
|
|
28
|
+
url: URL;
|
|
29
|
+
}, config: Partial<Config>): Promise<Disco>;
|
|
25
30
|
/** Train on dataset, yielding logs of every round. */
|
|
26
|
-
trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs
|
|
27
|
-
participants: number;
|
|
28
|
-
}>;
|
|
31
|
+
trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs>;
|
|
29
32
|
/** Train on dataset, yielding logs of every epoch. */
|
|
30
33
|
trainByEpoch(dataTuple: data.DataSplit): AsyncGenerator<EpochLogs>;
|
|
31
34
|
/** Train on dataset, yielding logs of every batch. */
|
|
@@ -33,14 +36,12 @@ export declare class Disco {
|
|
|
33
36
|
/** Run whole train on dataset. */
|
|
34
37
|
trainFully(dataTuple: data.DataSplit): Promise<void>;
|
|
35
38
|
/**
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
train(dataTuple: data.DataSplit): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs
|
|
42
|
-
participants: number;
|
|
43
|
-
}>>;
|
|
39
|
+
* Train on dataset, yield the nested steps.
|
|
40
|
+
*
|
|
41
|
+
* Don't forget to await the yielded generator otherwise nothing will progress.
|
|
42
|
+
* If you don't care about the whole process, use one of the other train methods.
|
|
43
|
+
**/
|
|
44
|
+
train(dataTuple: data.DataSplit): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>>;
|
|
44
45
|
/**
|
|
45
46
|
* Stops the ongoing training instance without disconnecting the client.
|
|
46
47
|
*/
|
|
@@ -50,3 +51,4 @@ export declare class Disco {
|
|
|
50
51
|
*/
|
|
51
52
|
close(): Promise<void>;
|
|
52
53
|
}
|
|
54
|
+
export {};
|