@epfml/discojs 2.1.2-p20240723143623.0 → 2.1.2-p20240723160018.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. package/dist/aggregator/base.d.ts +8 -48
  2. package/dist/aggregator/base.js +6 -68
  3. package/dist/aggregator/get.d.ts +0 -2
  4. package/dist/aggregator/get.js +4 -4
  5. package/dist/aggregator/mean.d.ts +2 -2
  6. package/dist/aggregator/mean.js +3 -6
  7. package/dist/aggregator/secure.d.ts +2 -2
  8. package/dist/aggregator/secure.js +4 -7
  9. package/dist/client/base.d.ts +2 -1
  10. package/dist/client/base.js +0 -6
  11. package/dist/client/decentralized/base.d.ts +2 -2
  12. package/dist/client/decentralized/base.js +9 -8
  13. package/dist/client/federated/base.d.ts +1 -1
  14. package/dist/client/federated/base.js +2 -1
  15. package/dist/client/local.d.ts +3 -1
  16. package/dist/client/local.js +4 -1
  17. package/dist/default_tasks/cifar10.js +1 -2
  18. package/dist/default_tasks/mnist.js +0 -2
  19. package/dist/default_tasks/simple_face.js +0 -2
  20. package/dist/default_tasks/titanic.js +0 -2
  21. package/dist/index.d.ts +0 -1
  22. package/dist/index.js +0 -1
  23. package/dist/privacy.d.ts +8 -10
  24. package/dist/privacy.js +25 -40
  25. package/dist/task/training_information.d.ts +6 -2
  26. package/dist/task/training_information.js +17 -5
  27. package/dist/training/disco.d.ts +30 -28
  28. package/dist/training/disco.js +76 -61
  29. package/dist/training/index.d.ts +1 -1
  30. package/dist/training/index.js +1 -0
  31. package/dist/training/trainer.d.ts +16 -0
  32. package/dist/training/trainer.js +72 -0
  33. package/dist/weights/weights_container.d.ts +0 -5
  34. package/dist/weights/weights_container.js +0 -7
  35. package/package.json +1 -1
  36. package/dist/async_informant.d.ts +0 -15
  37. package/dist/async_informant.js +0 -42
  38. package/dist/training/trainer/distributed_trainer.d.ts +0 -20
  39. package/dist/training/trainer/distributed_trainer.js +0 -41
  40. package/dist/training/trainer/local_trainer.d.ts +0 -12
  41. package/dist/training/trainer/local_trainer.js +0 -24
  42. package/dist/training/trainer/trainer.d.ts +0 -32
  43. package/dist/training/trainer/trainer.js +0 -61
  44. package/dist/training/trainer/trainer_builder.d.ts +0 -23
  45. package/dist/training/trainer/trainer_builder.js +0 -47
@@ -1,41 +0,0 @@
1
- import { Trainer } from "./trainer.js";
2
- /**
3
- * Class whose role is to train a model in a distributed way with a given dataset.
4
- */
5
- export class DistributedTrainer extends Trainer {
6
- task;
7
- memory;
8
- client;
9
- aggregator;
10
- /**
11
- * DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights.
12
- */
13
- constructor(task, memory, model, client) {
14
- super(task, model);
15
- this.task = task;
16
- this.memory = memory;
17
- this.client = client;
18
- this.aggregator = this.client.aggregator;
19
- this.aggregator.setModel(model);
20
- }
21
- async onRoundBegin(round) {
22
- await this.client.onRoundBeginCommunication(this.model.weights, round);
23
- }
24
- /**
25
- * Callback called every time a round is over
26
- */
27
- async onRoundEnd(round) {
28
- await this.client.onRoundEndCommunication(this.model.weights, round);
29
- if (this.aggregator.model !== undefined) {
30
- // The aggregator's own aggregation is async. The trainer updates its model to match the aggregator's
31
- // after it has completed a round of training.
32
- this.model.weights = this.aggregator.model.weights;
33
- }
34
- await this.memory.updateWorkingModel({
35
- type: 'working',
36
- taskID: this.task.id,
37
- name: this.task.trainingInformation.modelID,
38
- tensorBackend: this.task.trainingInformation.tensorBackend
39
- }, this.model);
40
- }
41
- }
@@ -1,12 +0,0 @@
1
- import type { Memory, Model, Task } from "../../index.js";
2
- import { Trainer } from "./trainer.js";
3
- /** Class whose role is to locally (alone) train a model on a given dataset,
4
- * without any collaborators.
5
- */
6
- export declare class LocalTrainer extends Trainer {
7
- private readonly task;
8
- private readonly memory;
9
- constructor(task: Task, memory: Memory, model: Model);
10
- onRoundBegin(): Promise<void>;
11
- onRoundEnd(): Promise<void>;
12
- }
@@ -1,24 +0,0 @@
1
- import { Trainer } from "./trainer.js";
2
- /** Class whose role is to locally (alone) train a model on a given dataset,
3
- * without any collaborators.
4
- */
5
- export class LocalTrainer extends Trainer {
6
- task;
7
- memory;
8
- constructor(task, memory, model) {
9
- super(task, model);
10
- this.task = task;
11
- this.memory = memory;
12
- }
13
- async onRoundBegin() {
14
- return await Promise.resolve();
15
- }
16
- async onRoundEnd() {
17
- await this.memory.updateWorkingModel({
18
- type: 'working',
19
- taskID: this.task.id,
20
- name: this.task.trainingInformation.modelID,
21
- tensorBackend: this.task.trainingInformation.tensorBackend
22
- }, this.model);
23
- }
24
- }
@@ -1,32 +0,0 @@
1
- import type tf from "@tensorflow/tfjs";
2
- import { List } from "immutable";
3
- import type { Model, Task } from "../../index.js";
4
- import { BatchLogs, EpochLogs } from "../../models/index.js";
5
- export interface RoundLogs {
6
- epochs: List<EpochLogs>;
7
- }
8
- /** Abstract class whose role is to train a model with a given dataset. This can be either done
9
- * locally (alone) or in a distributed way with collaborators.
10
- *
11
- * 1. Call `fitModel(dataset)` to start training.
12
- * 2. which will then call onRoundEnd once the round has ended.
13
- *
14
- * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators.
15
- */
16
- export declare abstract class Trainer {
17
- #private;
18
- readonly model: Model;
19
- private training?;
20
- constructor(task: Task, model: Model);
21
- protected abstract onRoundBegin(round: number): Promise<void>;
22
- protected abstract onRoundEnd(round: number): Promise<void>;
23
- /**
24
- * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
25
- */
26
- stopTraining(): Promise<void>;
27
- /**
28
- * Start training the model with the given dataset
29
- * @param dataset
30
- */
31
- fitModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
32
- }
@@ -1,61 +0,0 @@
1
- import { List } from "immutable";
2
- import * as async_iterator from "../../utils/async_iterator.js";
3
- /** Abstract class whose role is to train a model with a given dataset. This can be either done
4
- * locally (alone) or in a distributed way with collaborators.
5
- *
6
- * 1. Call `fitModel(dataset)` to start training.
7
- * 2. which will then call onRoundEnd once the round has ended.
8
- *
9
- * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators.
10
- */
11
- export class Trainer {
12
- model;
13
- #roundDuration;
14
- #epochs;
15
- training;
16
- constructor(task, model) {
17
- this.model = model;
18
- this.#roundDuration = task.trainingInformation.roundDuration;
19
- this.#epochs = task.trainingInformation.epochs;
20
- if (!Number.isInteger(this.#epochs / this.#roundDuration))
21
- throw new Error(`round duration ${this.#roundDuration} doesn't divide number of epochs ${this.#epochs}`);
22
- }
23
- /**
24
- * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
25
- */
26
- async stopTraining() {
27
- await this.training?.return();
28
- }
29
- /**
30
- * Start training the model with the given dataset
31
- * @param dataset
32
- */
33
- async *fitModel(dataset, valDataset) {
34
- if (this.training !== undefined)
35
- throw new Error("training already running, cancel it before launching a new one");
36
- try {
37
- this.training = this.#runRounds(dataset, valDataset);
38
- yield* this.training;
39
- }
40
- finally {
41
- this.training = undefined;
42
- }
43
- }
44
- async *#runRounds(dataset, valDataset) {
45
- const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
46
- for (let round = 0; round < totalRound; round++) {
47
- await this.onRoundBegin(round);
48
- yield this.#runRound(dataset, valDataset);
49
- await this.onRoundEnd(round);
50
- }
51
- }
52
- async *#runRound(dataset, valDataset) {
53
- let epochsLogs = List();
54
- for (let epoch = 0; epoch < this.#roundDuration; epoch++) {
55
- const [gen, epochLogs] = async_iterator.split(this.model.train(dataset, valDataset));
56
- yield gen;
57
- epochsLogs = epochsLogs.push(await epochLogs);
58
- }
59
- return { epochs: epochsLogs };
60
- }
61
- }
@@ -1,23 +0,0 @@
1
- import type { client as clients, Task, Memory } from '../../index.js';
2
- import type { Trainer } from './trainer.js';
3
- /**
4
- * A class that helps build the Trainer and auxiliary classes.
5
- */
6
- export declare class TrainerBuilder {
7
- private readonly memory;
8
- private readonly task;
9
- constructor(memory: Memory, task: Task);
10
- /**
11
- * Builds a trainer object.
12
- *
13
- * @param client client to share weights with (either distributed or federated)
14
- * @param distributed whether to build a distributed or local trainer
15
- * @returns
16
- */
17
- build(client: clients.Client, distributed?: boolean): Promise<Trainer>;
18
- /**
19
- * If a model exists in memory, load it, otherwise load model from server
20
- * @returns
21
- */
22
- private getModel;
23
- }
@@ -1,47 +0,0 @@
1
- import { DistributedTrainer } from './distributed_trainer.js';
2
- import { LocalTrainer } from './local_trainer.js';
3
- /**
4
- * A class that helps build the Trainer and auxiliary classes.
5
- */
6
- export class TrainerBuilder {
7
- memory;
8
- task;
9
- constructor(memory, task) {
10
- this.memory = memory;
11
- this.task = task;
12
- }
13
- /**
14
- * Builds a trainer object.
15
- *
16
- * @param client client to share weights with (either distributed or federated)
17
- * @param distributed whether to build a distributed or local trainer
18
- * @returns
19
- */
20
- async build(client, distributed = false) {
21
- const model = await this.getModel(client);
22
- if (distributed) {
23
- return new DistributedTrainer(this.task, this.memory, model, client);
24
- }
25
- else {
26
- return new LocalTrainer(this.task, this.memory, model);
27
- }
28
- }
29
- /**
30
- * If a model exists in memory, load it, otherwise load model from server
31
- * @returns
32
- */
33
- async getModel(client) {
34
- const modelID = this.task.trainingInformation?.modelID;
35
- if (modelID === undefined) {
36
- throw new TypeError('model ID is undefined');
37
- }
38
- const info = {
39
- type: 'working',
40
- taskID: this.task.id,
41
- name: modelID,
42
- tensorBackend: 'gpt'
43
- };
44
- const model = await (await this.memory.contains(info) ? this.memory.getModel(info) : client.getLatestModel());
45
- return model;
46
- }
47
- }