@epfml/discojs 2.1.2-p20240722093114.0 → 2.1.2-p20240723143623.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +3 -2
- package/dist/aggregator/base.js +4 -3
- package/dist/aggregator/get.d.ts +25 -11
- package/dist/aggregator/get.js +40 -23
- package/dist/aggregator/index.d.ts +1 -1
- package/dist/aggregator/index.js +1 -1
- package/dist/aggregator/mean.d.ts +24 -5
- package/dist/aggregator/mean.js +60 -12
- package/dist/client/base.d.ts +1 -2
- package/dist/client/base.js +6 -2
- package/dist/client/decentralized/base.d.ts +26 -9
- package/dist/client/decentralized/base.js +115 -82
- package/dist/client/decentralized/peer.js +7 -12
- package/dist/client/decentralized/peer_pool.js +6 -2
- package/dist/client/event_connection.js +1 -1
- package/dist/client/federated/base.d.ts +5 -21
- package/dist/client/federated/base.js +37 -61
- package/dist/client/federated/messages.d.ts +2 -10
- package/dist/client/federated/messages.js +0 -1
- package/dist/client/index.d.ts +1 -1
- package/dist/client/index.js +1 -1
- package/dist/client/messages.d.ts +1 -2
- package/dist/client/messages.js +8 -3
- package/dist/client/utils.d.ts +3 -1
- package/dist/client/utils.js +16 -1
- package/dist/default_tasks/mnist.js +15 -12
- package/dist/task/task_handler.js +10 -2
- package/dist/task/training_information.d.ts +1 -2
- package/dist/task/training_information.js +8 -1
- package/dist/training/disco.d.ts +2 -2
- package/dist/training/disco.js +11 -24
- package/dist/training/trainer/trainer.js +1 -1
- package/dist/types.d.ts +0 -2
- package/package.json +1 -1
|
@@ -20,13 +20,14 @@ export const mnist = {
|
|
|
20
20
|
trainingInformation: {
|
|
21
21
|
modelID: 'mnist-model',
|
|
22
22
|
epochs: 20,
|
|
23
|
-
roundDuration:
|
|
23
|
+
roundDuration: 2,
|
|
24
24
|
validationSplit: 0.2,
|
|
25
|
-
batchSize:
|
|
25
|
+
batchSize: 64,
|
|
26
26
|
dataType: 'image',
|
|
27
27
|
IMAGE_H: 28,
|
|
28
28
|
IMAGE_W: 28,
|
|
29
|
-
|
|
29
|
+
// Images should already be at the right size but resizing just in case
|
|
30
|
+
preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
|
|
30
31
|
LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
|
|
31
32
|
scheme: 'decentralized',
|
|
32
33
|
noiseScale: undefined,
|
|
@@ -39,22 +40,24 @@ export const mnist = {
|
|
|
39
40
|
};
|
|
40
41
|
},
|
|
41
42
|
getModel() {
|
|
43
|
+
// Architecture from the PyTorch MNIST example (I made it slightly smaller, 650kB instead of 5MB)
|
|
44
|
+
// https://github.com/pytorch/examples/blob/main/mnist/main.py
|
|
42
45
|
const model = tf.sequential();
|
|
43
46
|
model.add(tf.layers.conv2d({
|
|
44
47
|
inputShape: [28, 28, 3],
|
|
45
|
-
kernelSize:
|
|
46
|
-
filters:
|
|
47
|
-
activation: 'relu'
|
|
48
|
+
kernelSize: 5,
|
|
49
|
+
filters: 8,
|
|
50
|
+
activation: 'relu',
|
|
48
51
|
}));
|
|
52
|
+
model.add(tf.layers.conv2d({ kernelSize: 5, filters: 16, activation: 'relu' }));
|
|
49
53
|
model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
|
|
50
|
-
model.add(tf.layers.
|
|
51
|
-
model.add(tf.layers.
|
|
52
|
-
model.add(tf.layers.
|
|
53
|
-
model.add(tf.layers.
|
|
54
|
-
model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
|
|
54
|
+
model.add(tf.layers.dropout({ rate: 0.25 }));
|
|
55
|
+
model.add(tf.layers.flatten());
|
|
56
|
+
model.add(tf.layers.dense({ units: 32, activation: 'relu' }));
|
|
57
|
+
model.add(tf.layers.dropout({ rate: 0.25 }));
|
|
55
58
|
model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
|
|
56
59
|
model.compile({
|
|
57
|
-
optimizer: '
|
|
60
|
+
optimizer: 'adam',
|
|
58
61
|
loss: 'categoricalCrossentropy',
|
|
59
62
|
metrics: ['accuracy']
|
|
60
63
|
});
|
|
@@ -13,8 +13,16 @@ export async function pushTask(url, task, model) {
|
|
|
13
13
|
export async function fetchTasks(url) {
|
|
14
14
|
const response = await axios.get(new URL(TASK_ENDPOINT, url).href);
|
|
15
15
|
const tasks = response.data;
|
|
16
|
-
if (!
|
|
17
|
-
throw new Error('
|
|
16
|
+
if (!Array.isArray(tasks)) {
|
|
17
|
+
throw new Error('Expected to receive an array of Tasks when fetching tasks');
|
|
18
|
+
}
|
|
19
|
+
else if (!tasks.every(isTask)) {
|
|
20
|
+
for (const task of tasks) {
|
|
21
|
+
if (!isTask(task)) {
|
|
22
|
+
console.error("task has invalid format:", task);
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
throw new Error('invalid tasks response, the task object received is not well formatted');
|
|
18
26
|
}
|
|
19
27
|
return Map(tasks.map((t) => [t.id, t]));
|
|
20
28
|
}
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import type { AggregatorChoice } from '../aggregator/get.js';
|
|
2
1
|
import type { Preprocessing } from '../dataset/data/preprocessing/index.js';
|
|
3
2
|
import { PreTrainedTokenizer } from '@xenova/transformers';
|
|
4
3
|
export interface TrainingInformation {
|
|
@@ -20,7 +19,7 @@ export interface TrainingInformation {
|
|
|
20
19
|
decentralizedSecure?: boolean;
|
|
21
20
|
maxShareValue?: number;
|
|
22
21
|
minimumReadyPeers?: number;
|
|
23
|
-
aggregator?:
|
|
22
|
+
aggregator?: 'mean' | 'secure';
|
|
24
23
|
tokenizer?: string | PreTrainedTokenizer;
|
|
25
24
|
maxSequenceLength?: number;
|
|
26
25
|
tensorBackend: 'tfjs' | 'gpt';
|
|
@@ -19,7 +19,7 @@ export function isTrainingInformation(raw) {
|
|
|
19
19
|
typeof validationSplit !== 'number' ||
|
|
20
20
|
(tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
|
|
21
21
|
(maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
|
|
22
|
-
(aggregator !== undefined && typeof aggregator !== '
|
|
22
|
+
(aggregator !== undefined && typeof aggregator !== 'string') ||
|
|
23
23
|
(clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
|
|
24
24
|
(decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
|
|
25
25
|
(maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
|
|
@@ -33,6 +33,13 @@ export function isTrainingInformation(raw) {
|
|
|
33
33
|
(preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
|
|
34
34
|
return false;
|
|
35
35
|
}
|
|
36
|
+
if (aggregator !== undefined) {
|
|
37
|
+
switch (aggregator) {
|
|
38
|
+
case 'mean': break;
|
|
39
|
+
case 'secure': break;
|
|
40
|
+
default: return false;
|
|
41
|
+
}
|
|
42
|
+
}
|
|
36
43
|
switch (dataType) {
|
|
37
44
|
case 'image': break;
|
|
38
45
|
case 'tabular': break;
|
package/dist/training/disco.d.ts
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { BatchLogs, data, EpochLogs, Logger, Memory, Task, TrainingInformation } from '../index.js';
|
|
2
2
|
import { client as clients } from '../index.js';
|
|
3
|
-
import type
|
|
3
|
+
import { type Aggregator } from '../aggregator/index.js';
|
|
4
4
|
import type { RoundLogs } from './trainer/trainer.js';
|
|
5
5
|
export interface DiscoOptions {
|
|
6
6
|
client?: clients.Client;
|
|
@@ -20,7 +20,7 @@ export declare class Disco {
|
|
|
20
20
|
readonly logger: Logger;
|
|
21
21
|
readonly memory: Memory;
|
|
22
22
|
private readonly client;
|
|
23
|
-
private readonly
|
|
23
|
+
private readonly trainerPromise;
|
|
24
24
|
constructor(task: Task, options: DiscoOptions);
|
|
25
25
|
/** Train on dataset, yielding logs of every round. */
|
|
26
26
|
trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & {
|
package/dist/training/disco.js
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { List } from 'immutable';
|
|
2
2
|
import { async_iterator } from '../index.js';
|
|
3
3
|
import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js';
|
|
4
|
-
import {
|
|
4
|
+
import { getAggregator } from '../aggregator/index.js';
|
|
5
5
|
import { enumerate, split } from '../utils/async_iterator.js';
|
|
6
6
|
import { TrainerBuilder } from './trainer/trainer_builder.js';
|
|
7
7
|
/**
|
|
@@ -14,36 +14,23 @@ export class Disco {
|
|
|
14
14
|
logger;
|
|
15
15
|
memory;
|
|
16
16
|
client;
|
|
17
|
-
|
|
17
|
+
trainerPromise;
|
|
18
18
|
constructor(task, options) {
|
|
19
|
+
// Fill undefined options with default values
|
|
19
20
|
if (options.scheme === undefined) {
|
|
20
21
|
options.scheme = task.trainingInformation.scheme;
|
|
21
22
|
}
|
|
22
|
-
if (options.aggregator === undefined) {
|
|
23
|
-
options.aggregator = new MeanAggregator();
|
|
24
|
-
}
|
|
25
23
|
if (options.client === undefined) {
|
|
26
24
|
if (options.url === undefined) {
|
|
27
25
|
throw new Error('could not determine client from given parameters');
|
|
28
26
|
}
|
|
27
|
+
if (options.aggregator === undefined) {
|
|
28
|
+
options.aggregator = getAggregator(task, { scheme: options.scheme });
|
|
29
|
+
}
|
|
29
30
|
if (typeof options.url === 'string') {
|
|
30
31
|
options.url = new URL(options.url);
|
|
31
32
|
}
|
|
32
|
-
|
|
33
|
-
case 'federated':
|
|
34
|
-
options.client = new clients.federated.FederatedClient(options.url, task, options.aggregator);
|
|
35
|
-
break;
|
|
36
|
-
case 'decentralized':
|
|
37
|
-
options.client = new clients.decentralized.DecentralizedClient(options.url, task, options.aggregator);
|
|
38
|
-
break;
|
|
39
|
-
case 'local':
|
|
40
|
-
options.client = new clients.Local(options.url, task, options.aggregator);
|
|
41
|
-
break;
|
|
42
|
-
default: {
|
|
43
|
-
const _ = options.scheme;
|
|
44
|
-
throw new Error('should never happen');
|
|
45
|
-
}
|
|
46
|
-
}
|
|
33
|
+
options.client = clients.getClient(options.scheme, options.url, task, options.aggregator);
|
|
47
34
|
}
|
|
48
35
|
if (options.logger === undefined) {
|
|
49
36
|
options.logger = new ConsoleLogger();
|
|
@@ -59,7 +46,7 @@ export class Disco {
|
|
|
59
46
|
this.memory = options.memory;
|
|
60
47
|
this.logger = options.logger;
|
|
61
48
|
const trainerBuilder = new TrainerBuilder(this.memory, this.task);
|
|
62
|
-
this.
|
|
49
|
+
this.trainerPromise = trainerBuilder.build(this.client, options.scheme !== 'local');
|
|
63
50
|
}
|
|
64
51
|
/** Train on dataset, yielding logs of every round. */
|
|
65
52
|
async *trainByRound(dataTuple) {
|
|
@@ -107,7 +94,7 @@ export class Disco {
|
|
|
107
94
|
const trainData = dataTuple.train.preprocess().batch();
|
|
108
95
|
const validationData = dataTuple.validation?.preprocess().batch() ?? trainData;
|
|
109
96
|
await this.client.connect();
|
|
110
|
-
const trainer = await this.
|
|
97
|
+
const trainer = await this.trainerPromise;
|
|
111
98
|
for await (const [round, epochs] of enumerate(trainer.fitModel(trainData.dataset, validationData.dataset))) {
|
|
112
99
|
yield async function* () {
|
|
113
100
|
let epochsLogs = List();
|
|
@@ -131,7 +118,7 @@ export class Disco {
|
|
|
131
118
|
}
|
|
132
119
|
return {
|
|
133
120
|
epochs: epochsLogs,
|
|
134
|
-
participants: this.client.
|
|
121
|
+
participants: this.client.nbOfParticipants, // already includes ourselves
|
|
135
122
|
};
|
|
136
123
|
}.bind(this)();
|
|
137
124
|
}
|
|
@@ -141,7 +128,7 @@ export class Disco {
|
|
|
141
128
|
* Stops the ongoing training instance without disconnecting the client.
|
|
142
129
|
*/
|
|
143
130
|
async pause() {
|
|
144
|
-
const trainer = await this.
|
|
131
|
+
const trainer = await this.trainerPromise;
|
|
145
132
|
await trainer.stopTraining();
|
|
146
133
|
}
|
|
147
134
|
/**
|
|
@@ -18,7 +18,7 @@ export class Trainer {
|
|
|
18
18
|
this.#roundDuration = task.trainingInformation.roundDuration;
|
|
19
19
|
this.#epochs = task.trainingInformation.epochs;
|
|
20
20
|
if (!Number.isInteger(this.#epochs / this.#roundDuration))
|
|
21
|
-
throw new Error(`round duration doesn't divide epochs`);
|
|
21
|
+
throw new Error(`round duration ${this.#roundDuration} doesn't divide number of epochs ${this.#epochs}`);
|
|
22
22
|
}
|
|
23
23
|
/**
|
|
24
24
|
* Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
|
package/dist/types.d.ts
CHANGED
|
@@ -2,7 +2,5 @@ import type { Map } from 'immutable';
|
|
|
2
2
|
import type { WeightsContainer } from './index.js';
|
|
3
3
|
import type { NodeID } from './client/index.js';
|
|
4
4
|
export type Path = string;
|
|
5
|
-
export type MetadataKey = string;
|
|
6
|
-
export type MetadataValue = string;
|
|
7
5
|
export type Features = number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][];
|
|
8
6
|
export type Contributions = Map<NodeID, WeightsContainer>;
|