@epfml/discojs 2.1.2-p20240723143623.0 → 2.1.2-p20240723160018.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. package/dist/aggregator/base.d.ts +8 -48
  2. package/dist/aggregator/base.js +6 -68
  3. package/dist/aggregator/get.d.ts +0 -2
  4. package/dist/aggregator/get.js +4 -4
  5. package/dist/aggregator/mean.d.ts +2 -2
  6. package/dist/aggregator/mean.js +3 -6
  7. package/dist/aggregator/secure.d.ts +2 -2
  8. package/dist/aggregator/secure.js +4 -7
  9. package/dist/client/base.d.ts +2 -1
  10. package/dist/client/base.js +0 -6
  11. package/dist/client/decentralized/base.d.ts +2 -2
  12. package/dist/client/decentralized/base.js +9 -8
  13. package/dist/client/federated/base.d.ts +1 -1
  14. package/dist/client/federated/base.js +2 -1
  15. package/dist/client/local.d.ts +3 -1
  16. package/dist/client/local.js +4 -1
  17. package/dist/default_tasks/cifar10.js +1 -2
  18. package/dist/default_tasks/mnist.js +0 -2
  19. package/dist/default_tasks/simple_face.js +0 -2
  20. package/dist/default_tasks/titanic.js +0 -2
  21. package/dist/index.d.ts +0 -1
  22. package/dist/index.js +0 -1
  23. package/dist/privacy.d.ts +8 -10
  24. package/dist/privacy.js +25 -40
  25. package/dist/task/training_information.d.ts +6 -2
  26. package/dist/task/training_information.js +17 -5
  27. package/dist/training/disco.d.ts +30 -28
  28. package/dist/training/disco.js +76 -61
  29. package/dist/training/index.d.ts +1 -1
  30. package/dist/training/index.js +1 -0
  31. package/dist/training/trainer.d.ts +16 -0
  32. package/dist/training/trainer.js +72 -0
  33. package/dist/weights/weights_container.d.ts +0 -5
  34. package/dist/weights/weights_container.js +0 -7
  35. package/package.json +1 -1
  36. package/dist/async_informant.d.ts +0 -15
  37. package/dist/async_informant.js +0 -42
  38. package/dist/training/trainer/distributed_trainer.d.ts +0 -20
  39. package/dist/training/trainer/distributed_trainer.js +0 -41
  40. package/dist/training/trainer/local_trainer.d.ts +0 -12
  41. package/dist/training/trainer/local_trainer.js +0 -24
  42. package/dist/training/trainer/trainer.d.ts +0 -32
  43. package/dist/training/trainer/trainer.js +0 -61
  44. package/dist/training/trainer/trainer_builder.d.ts +0 -23
  45. package/dist/training/trainer/trainer_builder.js +0 -47
package/dist/privacy.js CHANGED
@@ -1,42 +1,27 @@
1
- import * as tf from '@tensorflow/tfjs';
1
+ import * as tf from "@tensorflow/tfjs";
2
+ async function frobeniusNorm(weights) {
3
+ const squared = await weights
4
+ .map((w) => w.square().sum())
5
+ .reduce((a, b) => a.add(b))
6
+ .data();
7
+ if (squared.length !== 1)
8
+ throw new Error("unexcepted weights shape");
9
+ return Math.sqrt(squared[0]);
10
+ }
11
+ /** Scramble weights */
12
+ export function addNoise(weights, deviation) {
13
+ const variance = Math.pow(deviation, 2);
14
+ return weights.map((w) => w.add(tf.randomNormal(w.shape, 0, variance)));
15
+ }
2
16
  /**
3
- * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
4
- * The previous round's weights are the last weights pulled from server/peers.
5
- * The current round's weights are obtained after a single round of training, from the previous round's weights.
6
- * @param updatedWeights weights from the current round
7
- * @param staleWeights weights from the previous round
8
- * @param task the task
9
- * @returns the noised weights for the current round
10
- */
11
- export function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
12
- const noiseScale = task.trainingInformation?.noiseScale;
13
- const clippingRadius = task.trainingInformation?.clippingRadius;
14
- const weightsDiff = updatedWeights.sub(staleWeights);
15
- let newWeightsDiff;
16
- if (clippingRadius !== undefined) {
17
- // Frobenius norm
18
- const norm = weightsDiff.frobeniusNorm();
19
- newWeightsDiff = weightsDiff.map((w) => {
20
- const clipped = w.div(Math.max(1, norm / clippingRadius));
21
- if (noiseScale !== undefined) {
22
- // Add clipping and noise
23
- const noise = tf.randomNormal(w.shape, 0, (noiseScale * noiseScale) * (clippingRadius * clippingRadius));
24
- return clipped.add(noise);
25
- }
26
- else {
27
- // Add clipping without any noise
28
- return clipped;
29
- }
30
- });
31
- }
32
- else {
33
- if (noiseScale !== undefined) {
34
- // Add noise without any clipping
35
- newWeightsDiff = weightsDiff.map((w) => tf.randomNormal(w.shape, 0, (noiseScale * noiseScale)));
36
- }
37
- else {
38
- return updatedWeights;
39
- }
40
- }
41
- return staleWeights.add(newWeightsDiff);
17
+ * Keep weights' norm within radius
18
+ *
19
+ * @param radius maximum norm
20
+ **/
21
+ export async function clipNorm(weights, radius) {
22
+ if (radius <= 0)
23
+ throw new Error("invalid radius");
24
+ const norm = await frobeniusNorm(weights);
25
+ const scaling = Math.max(1, norm / radius);
26
+ return weights.map((w) => w.div(scaling));
42
27
  }
package/dist/task/training_information.d.ts CHANGED
@@ -1,5 +1,9 @@
1
1
  import type { Preprocessing } from '../dataset/data/preprocessing/index.js';
2
2
  import { PreTrainedTokenizer } from '@xenova/transformers';
3
+ interface Privacy {
4
+ clippingRadius?: number;
5
+ noiseScale?: number;
6
+ }
3
7
  export interface TrainingInformation {
4
8
  modelID: string;
5
9
  epochs: number;
@@ -14,8 +18,7 @@ export interface TrainingInformation {
14
18
  IMAGE_W?: number;
15
19
  LABEL_LIST?: string[];
16
20
  scheme: 'decentralized' | 'federated' | 'local';
17
- noiseScale?: number;
18
- clippingRadius?: number;
21
+ privacy?: Privacy;
19
22
  decentralizedSecure?: boolean;
20
23
  maxShareValue?: number;
21
24
  minimumReadyPeers?: number;
@@ -25,3 +28,4 @@ export interface TrainingInformation {
25
28
  tensorBackend: 'tfjs' | 'gpt';
26
29
  }
27
30
  export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
31
+ export {};
package/dist/task/training_information.js CHANGED
@@ -6,11 +6,25 @@ function isStringArray(raw) {
6
6
  const arr = raw; // isArray is unsafely guarding with any[]
7
7
  return arr.every((e) => typeof e === 'string');
8
8
  }
9
+ function isPrivacy(raw) {
10
+ if (typeof raw !== "object" || raw === null) {
11
+ return false;
12
+ }
13
+ const { clippingRadius, noiseScale, } = raw;
14
+ if ((clippingRadius !== undefined && typeof clippingRadius !== "number") ||
15
+ (noiseScale !== undefined && typeof noiseScale !== "number"))
16
+ return false;
17
+ const _ = {
18
+ clippingRadius,
19
+ noiseScale,
20
+ };
21
+ return true;
22
+ }
9
23
  export function isTrainingInformation(raw) {
10
24
  if (typeof raw !== 'object' || raw === null) {
11
25
  return false;
12
26
  }
13
- const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, clippingRadius, dataType, decentralizedSecure, epochs, inputColumns, maxShareValue, minimumReadyPeers, modelID, noiseScale, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
27
+ const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minimumReadyPeers, modelID, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
14
28
  if (typeof dataType !== 'string' ||
15
29
  typeof modelID !== 'string' ||
16
30
  typeof epochs !== 'number' ||
@@ -20,11 +34,10 @@ export function isTrainingInformation(raw) {
20
34
  (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
21
35
  (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
22
36
  (aggregator !== undefined && typeof aggregator !== 'string') ||
23
- (clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
24
37
  (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
38
+ (privacy !== undefined && !isPrivacy(privacy)) ||
25
39
  (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
26
40
  (minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
27
- (noiseScale !== undefined && typeof noiseScale !== 'number') ||
28
41
  (IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
29
42
  (IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
30
43
  (LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
@@ -77,15 +90,14 @@ export function isTrainingInformation(raw) {
77
90
  LABEL_LIST,
78
91
  aggregator,
79
92
  batchSize,
80
- clippingRadius,
81
93
  dataType,
82
94
  decentralizedSecure,
95
+ privacy,
83
96
  epochs,
84
97
  inputColumns,
85
98
  maxShareValue,
86
99
  minimumReadyPeers,
87
100
  modelID,
88
- noiseScale,
89
101
  outputColumns,
90
102
  preprocessingFunctions,
91
103
  roundDuration,
package/dist/training/disco.d.ts CHANGED
@@ -1,14 +1,11 @@
1
- import { BatchLogs, data, EpochLogs, Logger, Memory, Task, TrainingInformation } from '../index.js';
2
- import { client as clients } from '../index.js';
3
- import { type Aggregator } from '../aggregator/index.js';
4
- import type { RoundLogs } from './trainer/trainer.js';
5
- export interface DiscoOptions {
6
- client?: clients.Client;
7
- aggregator?: Aggregator;
8
- url?: string | URL;
9
- scheme?: TrainingInformation['scheme'];
10
- logger?: Logger;
11
- memory?: Memory;
1
+ import { data, BatchLogs, EpochLogs, Logger, Memory, Task, TrainingInformation } from "../index.js";
2
+ import { client as clients } from "../index.js";
3
+ import type { Aggregator } from "../aggregator/index.js";
4
+ import { RoundLogs, Trainer } from "./trainer.js";
5
+ interface Config {
6
+ scheme: TrainingInformation["scheme"];
7
+ logger: Logger;
8
+ memory: Memory;
12
9
  }
13
10
  /**
14
11
  * Top-level class handling distributed training from a client's perspective. It is meant to be
@@ -16,16 +13,22 @@ export interface DiscoOptions {
16
13
  * communication with nodes, logs and model memory.
17
14
  */
18
15
  export declare class Disco {
19
- readonly task: Task;
20
- readonly logger: Logger;
21
- readonly memory: Memory;
22
- private readonly client;
23
- private readonly trainerPromise;
24
- constructor(task: Task, options: DiscoOptions);
16
+ #private;
17
+ readonly trainer: Trainer;
18
+ private constructor();
19
+ /**
20
+ * Connect to the given task and get ready to train.
21
+ *
22
+ * Will load the model from memory if available or fetch it from the server.
23
+ *
24
+ * @param clientConfig client to connect with or parameters on how to create one.
25
+ **/
26
+ static fromTask(task: Task, clientConfig: clients.Client | URL | {
27
+ aggregator: Aggregator;
28
+ url: URL;
29
+ }, config: Partial<Config>): Promise<Disco>;
25
30
  /** Train on dataset, yielding logs of every round. */
26
- trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & {
27
- participants: number;
28
- }>;
31
+ trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs>;
29
32
  /** Train on dataset, yielding logs of every epoch. */
30
33
  trainByEpoch(dataTuple: data.DataSplit): AsyncGenerator<EpochLogs>;
31
34
  /** Train on dataset, yielding logs of every batch. */
@@ -33,14 +36,12 @@ export declare class Disco {
33
36
  /** Run whole train on dataset. */
34
37
  trainFully(dataTuple: data.DataSplit): Promise<void>;
35
38
  /**
36
- * Train on dataset, yield the nested steps.
37
- *
38
- * Don't forget to await the yielded generator otherwise nothing will progress.
39
- * If you don't care about the whole process, use one of the other train methods.
40
- **/
41
- train(dataTuple: data.DataSplit): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs & {
42
- participants: number;
43
- }>>;
39
+ * Train on dataset, yield the nested steps.
40
+ *
41
+ * Don't forget to await the yielded generator otherwise nothing will progress.
42
+ * If you don't care about the whole process, use one of the other train methods.
43
+ **/
44
+ train(dataTuple: data.DataSplit): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>>;
44
45
  /**
45
46
  * Stops the ongoing training instance without disconnecting the client.
46
47
  */
@@ -50,3 +51,4 @@ export declare class Disco {
50
51
  */
51
52
  close(): Promise<void>;
52
53
  }
54
+ export {};
package/dist/training/disco.js CHANGED
@@ -1,52 +1,73 @@
1
- import { List } from 'immutable';
2
- import { async_iterator } from '../index.js';
3
- import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js';
4
- import { getAggregator } from '../aggregator/index.js';
5
- import { enumerate, split } from '../utils/async_iterator.js';
6
- import { TrainerBuilder } from './trainer/trainer_builder.js';
1
+ import { async_iterator, } from "../index.js";
2
+ import { client as clients, ConsoleLogger, EmptyMemory } from "../index.js";
3
+ import { getAggregator } from "../aggregator/index.js";
4
+ import { enumerate, split } from "../utils/async_iterator.js";
5
+ import { Trainer } from "./trainer.js";
7
6
  /**
8
7
  * Top-level class handling distributed training from a client's perspective. It is meant to be
9
8
  * a convenient object providing a reduced yet complete API that wraps model training,
10
9
  * communication with nodes, logs and model memory.
11
10
  */
12
11
  export class Disco {
13
- task;
14
- logger;
15
- memory;
16
- client;
17
- trainerPromise;
18
- constructor(task, options) {
19
- // Fill undefined options with default values
20
- if (options.scheme === undefined) {
21
- options.scheme = task.trainingInformation.scheme;
12
+ trainer;
13
+ #client;
14
+ #logger;
15
+ // small helper to avoid keeping Task & Memory around
16
+ #updateWorkingModel;
17
+ constructor(trainer, task, client, memory, logger) {
18
+ this.trainer = trainer;
19
+ this.#client = client;
20
+ this.#logger = logger;
21
+ this.#updateWorkingModel = () => memory.updateWorkingModel({
22
+ type: "working",
23
+ taskID: task.id,
24
+ name: task.trainingInformation.modelID,
25
+ tensorBackend: task.trainingInformation.tensorBackend,
26
+ }, this.trainer.model);
27
+ }
28
+ /**
29
+ * Connect to the given task and get ready to train.
30
+ *
31
+ * Will load the model from memory if available or fetch it from the server.
32
+ *
33
+ * @param clientConfig client to connect with or parameters on how to create one.
34
+ **/
35
+ static async fromTask(task, clientConfig, config) {
36
+ const { scheme, logger, memory } = {
37
+ scheme: task.trainingInformation.scheme,
38
+ logger: new ConsoleLogger(),
39
+ memory: new EmptyMemory(),
40
+ ...config,
41
+ };
42
+ let client;
43
+ if (clientConfig instanceof clients.Client) {
44
+ client = clientConfig;
22
45
  }
23
- if (options.client === undefined) {
24
- if (options.url === undefined) {
25
- throw new Error('could not determine client from given parameters');
26
- }
27
- if (options.aggregator === undefined) {
28
- options.aggregator = getAggregator(task, { scheme: options.scheme });
46
+ else {
47
+ let url, aggregator;
48
+ if (clientConfig instanceof URL) {
49
+ url = clientConfig;
50
+ aggregator = getAggregator(task, { scheme });
29
51
  }
30
- if (typeof options.url === 'string') {
31
- options.url = new URL(options.url);
52
+ else {
53
+ ({ url, aggregator } = clientConfig);
32
54
  }
33
- options.client = clients.getClient(options.scheme, options.url, task, options.aggregator);
34
- }
35
- if (options.logger === undefined) {
36
- options.logger = new ConsoleLogger();
37
- }
38
- if (options.memory === undefined) {
39
- options.memory = new EmptyMemory();
40
- }
41
- if (options.client.task !== task) {
42
- throw new Error('client not setup for given task');
55
+ client = clients.getClient(scheme, url, task, aggregator);
43
56
  }
44
- this.task = task;
45
- this.client = options.client;
46
- this.memory = options.memory;
47
- this.logger = options.logger;
48
- const trainerBuilder = new TrainerBuilder(this.memory, this.task);
49
- this.trainerPromise = trainerBuilder.build(this.client, options.scheme !== 'local');
57
+ if (client.task !== task)
58
+ throw new Error("client not setup for given task");
59
+ let model;
60
+ const memoryInfo = {
61
+ type: "working",
62
+ taskID: task.id,
63
+ name: task.trainingInformation.modelID,
64
+ tensorBackend: task.trainingInformation.tensorBackend,
65
+ };
66
+ if (await memory.contains(memoryInfo))
67
+ model = await memory.getModel(memoryInfo);
68
+ else
69
+ model = await client.getLatestModel();
70
+ return new Disco(new Trainer(task, model, client), task, client, memory, logger);
50
71
  }
51
72
  /** Train on dataset, yielding logs of every round. */
52
73
  async *trainByRound(dataTuple) {
@@ -83,27 +104,24 @@ export class Disco {
83
104
  ;
84
105
  }
85
106
  /**
86
- * Train on dataset, yield the nested steps.
87
- *
88
- * Don't forget to await the yielded generator otherwise nothing will progress.
89
- * If you don't care about the whole process, use one of the other train methods.
90
- **/
91
- // TODO RoundLogs should contain number of participants but Trainer doesn't need client
107
+ * Train on dataset, yield the nested steps.
108
+ *
109
+ * Don't forget to await the yielded generator otherwise nothing will progress.
110
+ * If you don't care about the whole process, use one of the other train methods.
111
+ **/
92
112
  async *train(dataTuple) {
93
- this.logger.success("Training started.");
113
+ this.#logger.success("Training started.");
94
114
  const trainData = dataTuple.train.preprocess().batch();
95
115
  const validationData = dataTuple.validation?.preprocess().batch() ?? trainData;
96
- await this.client.connect();
97
- const trainer = await this.trainerPromise;
98
- for await (const [round, epochs] of enumerate(trainer.fitModel(trainData.dataset, validationData.dataset))) {
116
+ await this.#client.connect();
117
+ for await (const [round, epochs] of enumerate(this.trainer.train(trainData.dataset, validationData.dataset))) {
99
118
  yield async function* () {
100
- let epochsLogs = List();
101
- for await (const [epoch, batches] of enumerate(epochs)) {
119
+ const [gen, returnedRoundLogs] = split(epochs);
120
+ for await (const [epoch, batches] of enumerate(gen)) {
102
121
  const [gen, returnedEpochLogs] = split(batches);
103
122
  yield gen;
104
123
  const epochLogs = await returnedEpochLogs;
105
- epochsLogs = epochsLogs.push(epochLogs);
106
- this.logger.success([
124
+ this.#logger.success([
107
125
  `Round: ${round}`,
108
126
  ` Epoch: ${epoch}`,
109
127
  ` Training loss: ${epochLogs.training.loss}`,
@@ -116,26 +134,23 @@ export class Disco {
116
134
  : "",
117
135
  ].join("\n"));
118
136
  }
119
- return {
120
- epochs: epochsLogs,
121
- participants: this.client.nbOfParticipants, // already includes ourselves
122
- };
137
+ return await returnedRoundLogs;
123
138
  }.bind(this)();
139
+ await this.#updateWorkingModel(this.trainer.model);
124
140
  }
125
- this.logger.success("Training finished.");
141
+ this.#logger.success("Training finished.");
126
142
  }
127
143
  /**
128
144
  * Stops the ongoing training instance without disconnecting the client.
129
145
  */
130
146
  async pause() {
131
- const trainer = await this.trainerPromise;
132
- await trainer.stopTraining();
147
+ await this.trainer.stopTraining();
133
148
  }
134
149
  /**
135
150
  * Completely stops the ongoing training instance.
136
151
  */
137
152
  async close() {
138
153
  await this.pause();
139
- await this.client.disconnect();
154
+ await this.#client.disconnect();
140
155
  }
141
156
  }
package/dist/training/index.d.ts CHANGED
@@ -1,2 +1,2 @@
1
1
  export { Disco } from './disco.js';
2
- export { RoundLogs } from './trainer/trainer.js';
2
+ export { RoundLogs, Trainer } from './trainer.js';
package/dist/training/index.js CHANGED
@@ -1 +1,2 @@
1
1
  export { Disco } from './disco.js';
2
+ export { Trainer } from './trainer.js';
package/dist/training/trainer.d.ts ADDED
@@ -0,0 +1,16 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { List } from "immutable";
3
+ import type { BatchLogs, EpochLogs, Model, Task } from "../index.js";
4
+ import { Client } from "../client/index.js";
5
+ export interface RoundLogs {
6
+ epochs: List<EpochLogs>;
7
+ participants: number;
8
+ }
9
+ /** Train a model and exchange with others **/
10
+ export declare class Trainer {
11
+ #private;
12
+ readonly model: Model;
13
+ constructor(task: Task, model: Model, client: Client);
14
+ stopTraining(): Promise<void>;
15
+ train(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
16
+ }
package/dist/training/trainer.js ADDED
@@ -0,0 +1,72 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { List } from "immutable";
3
+ import { privacy } from "../index.js";
4
+ import * as async_iterator from "../utils/async_iterator.js";
5
+ /** Train a model and exchange with others **/
6
+ export class Trainer {
7
+ model;
8
+ #client;
9
+ #roundDuration;
10
+ #epochs;
11
+ #privacy;
12
+ #training;
13
+ constructor(task, model, client) {
14
+ this.model = model;
15
+ this.#client = client;
16
+ this.#roundDuration = task.trainingInformation.roundDuration;
17
+ this.#epochs = task.trainingInformation.epochs;
18
+ this.#privacy = task.trainingInformation.privacy;
19
+ if (!Number.isInteger(this.#epochs / this.#roundDuration))
20
+ throw new Error(`round duration ${this.#roundDuration} doesn't divide number of epochs ${this.#epochs}`);
21
+ }
22
+ async stopTraining() {
23
+ await this.#training?.return();
24
+ }
25
+ async *train(dataset, valDataset) {
26
+ if (this.#training !== undefined)
27
+ throw new Error("training already running, stop it before launching a new one");
28
+ try {
29
+ this.#training = this.#runRounds(dataset, valDataset);
30
+ yield* this.#training;
31
+ }
32
+ finally {
33
+ this.#training = undefined;
34
+ }
35
+ }
36
+ async *#runRounds(dataset, valDataset) {
37
+ const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
38
+ let previousRoundWeights;
39
+ for (let round = 0; round < totalRound; round++) {
40
+ await this.#client.onRoundBeginCommunication(this.model.weights, round);
41
+ yield this.#runRound(dataset, valDataset);
42
+ let localWeights = this.model.weights;
43
+ if (this.#privacy !== undefined)
44
+ localWeights = await applyPrivacy(previousRoundWeights, localWeights, this.#privacy);
45
+ const networkWeights = await this.#client.onRoundEndCommunication(localWeights, round);
46
+ this.model.weights = previousRoundWeights = networkWeights;
47
+ }
48
+ }
49
+ async *#runRound(dataset, valDataset) {
50
+ let epochsLogs = List();
51
+ for (let epoch = 0; epoch < this.#roundDuration; epoch++) {
52
+ const [gen, epochLogs] = async_iterator.split(this.model.train(dataset, valDataset));
53
+ yield gen;
54
+ epochsLogs = epochsLogs.push(await epochLogs);
55
+ }
56
+ return {
57
+ epochs: epochsLogs,
58
+ participants: this.#client.nbOfParticipants,
59
+ };
60
+ }
61
+ }
62
+ async function applyPrivacy(previous, current, options) {
63
+ let ret = current;
64
+ if (options.clippingRadius !== undefined) {
65
+ const previousRoundWeights = previous ?? current.map((w) => tf.zerosLike(w));
66
+ const weightsProgress = current.sub(previousRoundWeights);
67
+ ret = previousRoundWeights.add(await privacy.clipNorm(weightsProgress, options.clippingRadius));
68
+ }
69
+ if (options.noiseScale !== undefined)
70
+ ret = privacy.addNoise(ret, options.noiseScale);
71
+ return ret;
72
+ }
package/dist/weights/weights_container.d.ts CHANGED
@@ -51,11 +51,6 @@ export declare class WeightsContainer {
51
51
  * @returns The tensor located at the index
52
52
  */
53
53
  get(index: number): tf.Tensor | undefined;
54
- /**
55
- * Computes the weights container's Frobenius norm
56
- * @returns The Frobenius norm
57
- */
58
- frobeniusNorm(): number;
59
54
  concat(other: WeightsContainer): WeightsContainer;
60
55
  equals(other: WeightsContainer, margin?: number): boolean;
61
56
  /**
package/dist/weights/weights_container.js CHANGED
@@ -70,13 +70,6 @@ export class WeightsContainer {
70
70
  get(index) {
71
71
  return this._weights.get(index);
72
72
  }
73
- /**
74
- * Computes the weights container's Frobenius norm
75
- * @returns The Frobenius norm
76
- */
77
- frobeniusNorm() {
78
- return Math.sqrt(this.map((w) => w.square().sum()).reduce((a, b) => a.add(b)).dataSync()[0]);
79
- }
80
73
  concat(other) {
81
74
  return WeightsContainer.of(...this.weights, ...other.weights);
82
75
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "2.1.2-p20240723143623.0",
3
+ "version": "2.1.2-p20240723160018.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",
package/dist/async_informant.d.ts DELETED
@@ -1,15 +0,0 @@
1
- import type { AggregatorBase } from './aggregator/index.js';
2
- export declare class AsyncInformant<T> {
3
- private readonly aggregator;
4
- private _round;
5
- private _currentNumberOfParticipants;
6
- private _totalNumberOfParticipants;
7
- private _averageNumberOfParticipants;
8
- constructor(aggregator: AggregatorBase<T>);
9
- update(): void;
10
- get round(): number;
11
- get currentNumberOfParticipants(): number;
12
- get totalNumberOfParticipants(): number;
13
- get averageNumberOfParticipants(): number;
14
- getAllStatistics(): Record<'round' | 'currentNumberOfParticipants' | 'totalNumberOfParticipants' | 'averageNumberOfParticipants', number>;
15
- }
package/dist/async_informant.js DELETED
@@ -1,42 +0,0 @@
1
- export class AsyncInformant {
2
- aggregator;
3
- _round = 0;
4
- _currentNumberOfParticipants = 0;
5
- _totalNumberOfParticipants = 0;
6
- _averageNumberOfParticipants = 0;
7
- constructor(aggregator) {
8
- this.aggregator = aggregator;
9
- }
10
- update() {
11
- if (this.round === 0 || this.round < this.aggregator.round) {
12
- this._round = this.aggregator.round;
13
- this._currentNumberOfParticipants = this.aggregator.size;
14
- this._averageNumberOfParticipants = this.totalNumberOfParticipants / this.round;
15
- this._totalNumberOfParticipants += this.currentNumberOfParticipants;
16
- }
17
- else {
18
- this._round = this.aggregator.round;
19
- }
20
- }
21
- // Getter functions
22
- get round() {
23
- return this._round;
24
- }
25
- get currentNumberOfParticipants() {
26
- return this._currentNumberOfParticipants;
27
- }
28
- get totalNumberOfParticipants() {
29
- return this._totalNumberOfParticipants;
30
- }
31
- get averageNumberOfParticipants() {
32
- return this._averageNumberOfParticipants;
33
- }
34
- getAllStatistics() {
35
- return {
36
- round: this.round,
37
- currentNumberOfParticipants: this.currentNumberOfParticipants,
38
- totalNumberOfParticipants: this.totalNumberOfParticipants,
39
- averageNumberOfParticipants: this.averageNumberOfParticipants
40
- };
41
- }
42
- }
package/dist/training/trainer/distributed_trainer.d.ts DELETED
@@ -1,20 +0,0 @@
1
- import type { Model, Memory, Task, client as clients } from "../../index.js";
2
- import { Trainer } from "./trainer.js";
3
- /**
4
- * Class whose role is to train a model in a distributed way with a given dataset.
5
- */
6
- export declare class DistributedTrainer extends Trainer {
7
- private readonly task;
8
- private readonly memory;
9
- private readonly client;
10
- private readonly aggregator;
11
- /**
12
- * DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights.
13
- */
14
- constructor(task: Task, memory: Memory, model: Model, client: clients.Client);
15
- onRoundBegin(round: number): Promise<void>;
16
- /**
17
- * Callback called every time a round is over
18
- */
19
- onRoundEnd(round: number): Promise<void>;
20
- }