@epfml/discojs 2.1.2-p20240718132634.0 → 2.1.2-p20240723143623.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -20,13 +20,14 @@ export const mnist = {
20
20
  trainingInformation: {
21
21
  modelID: 'mnist-model',
22
22
  epochs: 20,
23
- roundDuration: 10,
23
+ roundDuration: 2,
24
24
  validationSplit: 0.2,
25
- batchSize: 30,
25
+ batchSize: 64,
26
26
  dataType: 'image',
27
27
  IMAGE_H: 28,
28
28
  IMAGE_W: 28,
29
- preprocessingFunctions: [data.ImagePreprocessing.Normalize],
29
+ // Images should already be at the right size but resizing just in case
30
+ preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
30
31
  LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
31
32
  scheme: 'decentralized',
32
33
  noiseScale: undefined,
@@ -39,22 +40,24 @@ export const mnist = {
39
40
  };
40
41
  },
41
42
  getModel() {
43
+ // Architecture from the PyTorch MNIST example (I made it slightly smaller, 650kB instead of 5MB)
44
+ // https://github.com/pytorch/examples/blob/main/mnist/main.py
42
45
  const model = tf.sequential();
43
46
  model.add(tf.layers.conv2d({
44
47
  inputShape: [28, 28, 3],
45
- kernelSize: 3,
46
- filters: 16,
47
- activation: 'relu'
48
+ kernelSize: 5,
49
+ filters: 8,
50
+ activation: 'relu',
48
51
  }));
52
+ model.add(tf.layers.conv2d({ kernelSize: 5, filters: 16, activation: 'relu' }));
49
53
  model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
50
- model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
51
- model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
52
- model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
53
- model.add(tf.layers.flatten({}));
54
- model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
54
+ model.add(tf.layers.dropout({ rate: 0.25 }));
55
+ model.add(tf.layers.flatten());
56
+ model.add(tf.layers.dense({ units: 32, activation: 'relu' }));
57
+ model.add(tf.layers.dropout({ rate: 0.25 }));
55
58
  model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
56
59
  model.compile({
57
- optimizer: 'rmsprop',
60
+ optimizer: 'adam',
58
61
  loss: 'categoricalCrossentropy',
59
62
  metrics: ['accuracy']
60
63
  });
@@ -13,8 +13,16 @@ export async function pushTask(url, task, model) {
13
13
  export async function fetchTasks(url) {
14
14
  const response = await axios.get(new URL(TASK_ENDPOINT, url).href);
15
15
  const tasks = response.data;
16
- if (!(Array.isArray(tasks) && tasks.every(isTask))) {
17
- throw new Error('invalid tasks response');
16
+ if (!Array.isArray(tasks)) {
17
+ throw new Error('Expected to receive an array of Tasks when fetching tasks');
18
+ }
19
+ else if (!tasks.every(isTask)) {
20
+ for (const task of tasks) {
21
+ if (!isTask(task)) {
22
+ console.error("task has invalid format:", task);
23
+ }
24
+ }
25
+ throw new Error('invalid tasks response, the task object received is not well formatted');
18
26
  }
19
27
  return Map(tasks.map((t) => [t.id, t]));
20
28
  }
@@ -1,4 +1,3 @@
1
- import type { AggregatorChoice } from '../aggregator/get.js';
2
1
  import type { Preprocessing } from '../dataset/data/preprocessing/index.js';
3
2
  import { PreTrainedTokenizer } from '@xenova/transformers';
4
3
  export interface TrainingInformation {
@@ -20,7 +19,7 @@ export interface TrainingInformation {
20
19
  decentralizedSecure?: boolean;
21
20
  maxShareValue?: number;
22
21
  minimumReadyPeers?: number;
23
- aggregator?: AggregatorChoice;
22
+ aggregator?: 'mean' | 'secure';
24
23
  tokenizer?: string | PreTrainedTokenizer;
25
24
  maxSequenceLength?: number;
26
25
  tensorBackend: 'tfjs' | 'gpt';
@@ -19,7 +19,7 @@ export function isTrainingInformation(raw) {
19
19
  typeof validationSplit !== 'number' ||
20
20
  (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
21
21
  (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
22
- (aggregator !== undefined && typeof aggregator !== 'number') ||
22
+ (aggregator !== undefined && typeof aggregator !== 'string') ||
23
23
  (clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
24
24
  (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
25
25
  (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
@@ -33,6 +33,13 @@ export function isTrainingInformation(raw) {
33
33
  (preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
34
34
  return false;
35
35
  }
36
+ if (aggregator !== undefined) {
37
+ switch (aggregator) {
38
+ case 'mean': break;
39
+ case 'secure': break;
40
+ default: return false;
41
+ }
42
+ }
36
43
  switch (dataType) {
37
44
  case 'image': break;
38
45
  case 'tabular': break;
@@ -1,6 +1,6 @@
1
1
  import { BatchLogs, data, EpochLogs, Logger, Memory, Task, TrainingInformation } from '../index.js';
2
2
  import { client as clients } from '../index.js';
3
- import type { Aggregator } from '../aggregator/index.js';
3
+ import { type Aggregator } from '../aggregator/index.js';
4
4
  import type { RoundLogs } from './trainer/trainer.js';
5
5
  export interface DiscoOptions {
6
6
  client?: clients.Client;
@@ -20,7 +20,7 @@ export declare class Disco {
20
20
  readonly logger: Logger;
21
21
  readonly memory: Memory;
22
22
  private readonly client;
23
- private readonly trainer;
23
+ private readonly trainerPromise;
24
24
  constructor(task: Task, options: DiscoOptions);
25
25
  /** Train on dataset, yielding logs of every round. */
26
26
  trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & {
@@ -1,7 +1,7 @@
1
1
  import { List } from 'immutable';
2
2
  import { async_iterator } from '../index.js';
3
3
  import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js';
4
- import { MeanAggregator } from '../aggregator/mean.js';
4
+ import { getAggregator } from '../aggregator/index.js';
5
5
  import { enumerate, split } from '../utils/async_iterator.js';
6
6
  import { TrainerBuilder } from './trainer/trainer_builder.js';
7
7
  /**
@@ -14,36 +14,23 @@ export class Disco {
14
14
  logger;
15
15
  memory;
16
16
  client;
17
- trainer;
17
+ trainerPromise;
18
18
  constructor(task, options) {
19
+ // Fill undefined options with default values
19
20
  if (options.scheme === undefined) {
20
21
  options.scheme = task.trainingInformation.scheme;
21
22
  }
22
- if (options.aggregator === undefined) {
23
- options.aggregator = new MeanAggregator();
24
- }
25
23
  if (options.client === undefined) {
26
24
  if (options.url === undefined) {
27
25
  throw new Error('could not determine client from given parameters');
28
26
  }
27
+ if (options.aggregator === undefined) {
28
+ options.aggregator = getAggregator(task, { scheme: options.scheme });
29
+ }
29
30
  if (typeof options.url === 'string') {
30
31
  options.url = new URL(options.url);
31
32
  }
32
- switch (options.scheme) {
33
- case 'federated':
34
- options.client = new clients.federated.FederatedClient(options.url, task, options.aggregator);
35
- break;
36
- case 'decentralized':
37
- options.client = new clients.decentralized.DecentralizedClient(options.url, task, options.aggregator);
38
- break;
39
- case 'local':
40
- options.client = new clients.Local(options.url, task, options.aggregator);
41
- break;
42
- default: {
43
- const _ = options.scheme;
44
- throw new Error('should never happen');
45
- }
46
- }
33
+ options.client = clients.getClient(options.scheme, options.url, task, options.aggregator);
47
34
  }
48
35
  if (options.logger === undefined) {
49
36
  options.logger = new ConsoleLogger();
@@ -59,7 +46,7 @@ export class Disco {
59
46
  this.memory = options.memory;
60
47
  this.logger = options.logger;
61
48
  const trainerBuilder = new TrainerBuilder(this.memory, this.task);
62
- this.trainer = trainerBuilder.build(this.client, options.scheme !== 'local');
49
+ this.trainerPromise = trainerBuilder.build(this.client, options.scheme !== 'local');
63
50
  }
64
51
  /** Train on dataset, yielding logs of every round. */
65
52
  async *trainByRound(dataTuple) {
@@ -107,7 +94,7 @@ export class Disco {
107
94
  const trainData = dataTuple.train.preprocess().batch();
108
95
  const validationData = dataTuple.validation?.preprocess().batch() ?? trainData;
109
96
  await this.client.connect();
110
- const trainer = await this.trainer;
97
+ const trainer = await this.trainerPromise;
111
98
  for await (const [round, epochs] of enumerate(trainer.fitModel(trainData.dataset, validationData.dataset))) {
112
99
  yield async function* () {
113
100
  let epochsLogs = List();
@@ -131,7 +118,7 @@ export class Disco {
131
118
  }
132
119
  return {
133
120
  epochs: epochsLogs,
134
- participants: this.client.nodes.size + 1, // add ourself
121
+ participants: this.client.nbOfParticipants, // already includes ourselves
135
122
  };
136
123
  }.bind(this)();
137
124
  }
@@ -141,7 +128,7 @@ export class Disco {
141
128
  * Stops the ongoing training instance without disconnecting the client.
142
129
  */
143
130
  async pause() {
144
- const trainer = await this.trainer;
131
+ const trainer = await this.trainerPromise;
145
132
  await trainer.stopTraining();
146
133
  }
147
134
  /**
@@ -18,7 +18,7 @@ export class Trainer {
18
18
  this.#roundDuration = task.trainingInformation.roundDuration;
19
19
  this.#epochs = task.trainingInformation.epochs;
20
20
  if (!Number.isInteger(this.#epochs / this.#roundDuration))
21
- throw new Error(`round duration doesn't divide epochs`);
21
+ throw new Error(`round duration ${this.#roundDuration} doesn't divide number of epochs ${this.#epochs}`);
22
22
  }
23
23
  /**
24
24
  * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
package/dist/types.d.ts CHANGED
@@ -2,7 +2,5 @@ import type { Map } from 'immutable';
2
2
  import type { WeightsContainer } from './index.js';
3
3
  import type { NodeID } from './client/index.js';
4
4
  export type Path = string;
5
- export type MetadataKey = string;
6
- export type MetadataValue = string;
7
5
  export type Features = number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][];
8
6
  export type Contributions = Map<NodeID, WeightsContainer>;
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "2.1.2-p20240718132634.0",
3
+ "version": "2.1.2-p20240723143623.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",