@epfml/discojs 2.2.1 → 3.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. package/dist/aggregator/base.d.ts +9 -48
  2. package/dist/aggregator/base.js +8 -69
  3. package/dist/aggregator/get.d.ts +23 -11
  4. package/dist/aggregator/get.js +40 -23
  5. package/dist/aggregator/index.d.ts +1 -1
  6. package/dist/aggregator/index.js +1 -1
  7. package/dist/aggregator/mean.d.ts +25 -6
  8. package/dist/aggregator/mean.js +62 -17
  9. package/dist/aggregator/secure.d.ts +2 -2
  10. package/dist/aggregator/secure.js +4 -7
  11. package/dist/client/base.d.ts +3 -3
  12. package/dist/client/base.js +6 -8
  13. package/dist/client/decentralized/base.d.ts +27 -10
  14. package/dist/client/decentralized/base.js +123 -86
  15. package/dist/client/decentralized/peer.js +7 -12
  16. package/dist/client/decentralized/peer_pool.js +6 -2
  17. package/dist/client/event_connection.d.ts +1 -1
  18. package/dist/client/event_connection.js +3 -3
  19. package/dist/client/federated/base.d.ts +5 -21
  20. package/dist/client/federated/base.js +38 -61
  21. package/dist/client/federated/messages.d.ts +2 -10
  22. package/dist/client/federated/messages.js +0 -1
  23. package/dist/client/index.d.ts +1 -1
  24. package/dist/client/index.js +1 -1
  25. package/dist/client/local.d.ts +3 -1
  26. package/dist/client/local.js +4 -1
  27. package/dist/client/messages.d.ts +1 -2
  28. package/dist/client/messages.js +8 -3
  29. package/dist/client/utils.d.ts +4 -2
  30. package/dist/client/utils.js +18 -3
  31. package/dist/dataset/data/data.d.ts +1 -1
  32. package/dist/dataset/data/data.js +13 -2
  33. package/dist/dataset/data/preprocessing/image_preprocessing.js +6 -4
  34. package/dist/default_tasks/cifar10.js +1 -2
  35. package/dist/default_tasks/lus_covid.js +0 -5
  36. package/dist/default_tasks/mnist.js +15 -14
  37. package/dist/default_tasks/simple_face.js +0 -2
  38. package/dist/default_tasks/titanic.js +2 -4
  39. package/dist/default_tasks/wikitext.js +7 -1
  40. package/dist/index.d.ts +0 -1
  41. package/dist/index.js +0 -1
  42. package/dist/models/gpt/config.js +1 -1
  43. package/dist/privacy.d.ts +8 -10
  44. package/dist/privacy.js +25 -40
  45. package/dist/task/task_handler.js +10 -2
  46. package/dist/task/training_information.d.ts +7 -4
  47. package/dist/task/training_information.js +25 -6
  48. package/dist/training/disco.d.ts +30 -28
  49. package/dist/training/disco.js +75 -73
  50. package/dist/training/index.d.ts +1 -1
  51. package/dist/training/index.js +1 -0
  52. package/dist/training/trainer.d.ts +16 -0
  53. package/dist/training/trainer.js +72 -0
  54. package/dist/types.d.ts +0 -2
  55. package/dist/weights/weights_container.d.ts +0 -5
  56. package/dist/weights/weights_container.js +0 -7
  57. package/package.json +1 -1
  58. package/dist/async_informant.d.ts +0 -15
  59. package/dist/async_informant.js +0 -42
  60. package/dist/training/trainer/distributed_trainer.d.ts +0 -20
  61. package/dist/training/trainer/distributed_trainer.js +0 -41
  62. package/dist/training/trainer/local_trainer.d.ts +0 -12
  63. package/dist/training/trainer/local_trainer.js +0 -24
  64. package/dist/training/trainer/trainer.d.ts +0 -32
  65. package/dist/training/trainer/trainer.js +0 -61
  66. package/dist/training/trainer/trainer_builder.d.ts +0 -23
  67. package/dist/training/trainer/trainer_builder.js +0 -47
@@ -1,65 +1,73 @@
1
- import { List } from 'immutable';
2
- import { async_iterator } from '../index.js';
3
- import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js';
4
- import { MeanAggregator } from '../aggregator/mean.js';
5
- import { enumerate, split } from '../utils/async_iterator.js';
6
- import { TrainerBuilder } from './trainer/trainer_builder.js';
1
+ import { async_iterator, } from "../index.js";
2
+ import { client as clients, ConsoleLogger, EmptyMemory } from "../index.js";
3
+ import { getAggregator } from "../aggregator/index.js";
4
+ import { enumerate, split } from "../utils/async_iterator.js";
5
+ import { Trainer } from "./trainer.js";
7
6
  /**
8
7
  * Top-level class handling distributed training from a client's perspective. It is meant to be
9
8
  * a convenient object providing a reduced yet complete API that wraps model training,
10
9
  * communication with nodes, logs and model memory.
11
10
  */
12
11
  export class Disco {
13
- task;
14
- logger;
15
- memory;
16
- client;
17
12
  trainer;
18
- constructor(task, options) {
19
- if (options.scheme === undefined) {
20
- options.scheme = task.trainingInformation.scheme;
21
- }
22
- if (options.aggregator === undefined) {
23
- options.aggregator = new MeanAggregator();
13
+ #client;
14
+ #logger;
15
+ // small helper to avoid keeping Task & Memory around
16
+ #updateWorkingModel;
17
+ constructor(trainer, task, client, memory, logger) {
18
+ this.trainer = trainer;
19
+ this.#client = client;
20
+ this.#logger = logger;
21
+ this.#updateWorkingModel = () => memory.updateWorkingModel({
22
+ type: "working",
23
+ taskID: task.id,
24
+ name: task.trainingInformation.modelID,
25
+ tensorBackend: task.trainingInformation.tensorBackend,
26
+ }, this.trainer.model);
27
+ }
28
+ /**
29
+ * Connect to the given task and get ready to train.
30
+ *
31
+ * Will load the model from memory if available or fetch it from the server.
32
+ *
33
+ * @param clientConfig client to connect with or parameters on how to create one.
34
+ **/
35
+ static async fromTask(task, clientConfig, config) {
36
+ const { scheme, logger, memory } = {
37
+ scheme: task.trainingInformation.scheme,
38
+ logger: new ConsoleLogger(),
39
+ memory: new EmptyMemory(),
40
+ ...config,
41
+ };
42
+ let client;
43
+ if (clientConfig instanceof clients.Client) {
44
+ client = clientConfig;
24
45
  }
25
- if (options.client === undefined) {
26
- if (options.url === undefined) {
27
- throw new Error('could not determine client from given parameters');
28
- }
29
- if (typeof options.url === 'string') {
30
- options.url = new URL(options.url);
46
+ else {
47
+ let url, aggregator;
48
+ if (clientConfig instanceof URL) {
49
+ url = clientConfig;
50
+ aggregator = getAggregator(task, { scheme });
31
51
  }
32
- switch (options.scheme) {
33
- case 'federated':
34
- options.client = new clients.federated.FederatedClient(options.url, task, options.aggregator);
35
- break;
36
- case 'decentralized':
37
- options.client = new clients.decentralized.DecentralizedClient(options.url, task, options.aggregator);
38
- break;
39
- case 'local':
40
- options.client = new clients.Local(options.url, task, options.aggregator);
41
- break;
42
- default: {
43
- const _ = options.scheme;
44
- throw new Error('should never happen');
45
- }
52
+ else {
53
+ ({ url, aggregator } = clientConfig);
46
54
  }
55
+ client = clients.getClient(scheme, url, task, aggregator);
47
56
  }
48
- if (options.logger === undefined) {
49
- options.logger = new ConsoleLogger();
50
- }
51
- if (options.memory === undefined) {
52
- options.memory = new EmptyMemory();
53
- }
54
- if (options.client.task !== task) {
55
- throw new Error('client not setup for given task');
56
- }
57
- this.task = task;
58
- this.client = options.client;
59
- this.memory = options.memory;
60
- this.logger = options.logger;
61
- const trainerBuilder = new TrainerBuilder(this.memory, this.task);
62
- this.trainer = trainerBuilder.build(this.client, options.scheme !== 'local');
57
+ if (client.task !== task)
58
+ throw new Error("client not setup for given task");
59
+ let model;
60
+ const memoryInfo = {
61
+ type: "working",
62
+ taskID: task.id,
63
+ name: task.trainingInformation.modelID,
64
+ tensorBackend: task.trainingInformation.tensorBackend,
65
+ };
66
+ if (await memory.contains(memoryInfo))
67
+ model = await memory.getModel(memoryInfo);
68
+ else
69
+ model = await client.getLatestModel();
70
+ return new Disco(new Trainer(task, model, client), task, client, memory, logger);
63
71
  }
64
72
  /** Train on dataset, yielding logs of every round. */
65
73
  async *trainByRound(dataTuple) {
@@ -96,27 +104,24 @@ export class Disco {
96
104
  ;
97
105
  }
98
106
  /**
99
- * Train on dataset, yield the nested steps.
100
- *
101
- * Don't forget to await the yielded generator otherwise nothing will progress.
102
- * If you don't care about the whole process, use one of the other train methods.
103
- **/
104
- // TODO RoundLogs should contain number of participants but Trainer doesn't need client
107
+ * Train on dataset, yield the nested steps.
108
+ *
109
+ * Don't forget to await the yielded generator otherwise nothing will progress.
110
+ * If you don't care about the whole process, use one of the other train methods.
111
+ **/
105
112
  async *train(dataTuple) {
106
- this.logger.success("Training started.");
113
+ this.#logger.success("Training started.");
107
114
  const trainData = dataTuple.train.preprocess().batch();
108
115
  const validationData = dataTuple.validation?.preprocess().batch() ?? trainData;
109
- await this.client.connect();
110
- const trainer = await this.trainer;
111
- for await (const [round, epochs] of enumerate(trainer.fitModel(trainData.dataset, validationData.dataset))) {
116
+ await this.#client.connect();
117
+ for await (const [round, epochs] of enumerate(this.trainer.train(trainData.dataset, validationData.dataset))) {
112
118
  yield async function* () {
113
- let epochsLogs = List();
114
- for await (const [epoch, batches] of enumerate(epochs)) {
119
+ const [gen, returnedRoundLogs] = split(epochs);
120
+ for await (const [epoch, batches] of enumerate(gen)) {
115
121
  const [gen, returnedEpochLogs] = split(batches);
116
122
  yield gen;
117
123
  const epochLogs = await returnedEpochLogs;
118
- epochsLogs = epochsLogs.push(epochLogs);
119
- this.logger.success([
124
+ this.#logger.success([
120
125
  `Round: ${round}`,
121
126
  ` Epoch: ${epoch}`,
122
127
  ` Training loss: ${epochLogs.training.loss}`,
@@ -129,26 +134,23 @@ export class Disco {
129
134
  : "",
130
135
  ].join("\n"));
131
136
  }
132
- return {
133
- epochs: epochsLogs,
134
- participants: this.client.nodes.size + 1, // add ourself
135
- };
137
+ return await returnedRoundLogs;
136
138
  }.bind(this)();
139
+ await this.#updateWorkingModel(this.trainer.model);
137
140
  }
138
- this.logger.success("Training finished.");
141
+ this.#logger.success("Training finished.");
139
142
  }
140
143
  /**
141
144
  * Stops the ongoing training instance without disconnecting the client.
142
145
  */
143
146
  async pause() {
144
- const trainer = await this.trainer;
145
- await trainer.stopTraining();
147
+ await this.trainer.stopTraining();
146
148
  }
147
149
  /**
148
150
  * Completely stops the ongoing training instance.
149
151
  */
150
152
  async close() {
151
153
  await this.pause();
152
- await this.client.disconnect();
154
+ await this.#client.disconnect();
153
155
  }
154
156
  }
@@ -1,2 +1,2 @@
1
1
  export { Disco } from './disco.js';
2
- export { RoundLogs } from './trainer/trainer.js';
2
+ export { RoundLogs, Trainer } from './trainer.js';
@@ -1 +1,2 @@
1
1
  export { Disco } from './disco.js';
2
+ export { Trainer } from './trainer.js';
@@ -0,0 +1,16 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { List } from "immutable";
3
+ import type { BatchLogs, EpochLogs, Model, Task } from "../index.js";
4
+ import { Client } from "../client/index.js";
5
+ export interface RoundLogs {
6
+ epochs: List<EpochLogs>;
7
+ participants: number;
8
+ }
9
+ /** Train a model and exchange with others **/
10
+ export declare class Trainer {
11
+ #private;
12
+ readonly model: Model;
13
+ constructor(task: Task, model: Model, client: Client);
14
+ stopTraining(): Promise<void>;
15
+ train(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
16
+ }
@@ -0,0 +1,72 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { List } from "immutable";
3
+ import { privacy } from "../index.js";
4
+ import * as async_iterator from "../utils/async_iterator.js";
5
+ /** Train a model and exchange with others **/
6
+ export class Trainer {
7
+ model;
8
+ #client;
9
+ #roundDuration;
10
+ #epochs;
11
+ #privacy;
12
+ #training;
13
+ constructor(task, model, client) {
14
+ this.model = model;
15
+ this.#client = client;
16
+ this.#roundDuration = task.trainingInformation.roundDuration;
17
+ this.#epochs = task.trainingInformation.epochs;
18
+ this.#privacy = task.trainingInformation.privacy;
19
+ if (!Number.isInteger(this.#epochs / this.#roundDuration))
20
+ throw new Error(`round duration ${this.#roundDuration} doesn't divide number of epochs ${this.#epochs}`);
21
+ }
22
+ async stopTraining() {
23
+ await this.#training?.return();
24
+ }
25
+ async *train(dataset, valDataset) {
26
+ if (this.#training !== undefined)
27
+ throw new Error("training already running, stop it before launching a new one");
28
+ try {
29
+ this.#training = this.#runRounds(dataset, valDataset);
30
+ yield* this.#training;
31
+ }
32
+ finally {
33
+ this.#training = undefined;
34
+ }
35
+ }
36
+ async *#runRounds(dataset, valDataset) {
37
+ const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
38
+ let previousRoundWeights;
39
+ for (let round = 0; round < totalRound; round++) {
40
+ await this.#client.onRoundBeginCommunication(this.model.weights, round);
41
+ yield this.#runRound(dataset, valDataset);
42
+ let localWeights = this.model.weights;
43
+ if (this.#privacy !== undefined)
44
+ localWeights = await applyPrivacy(previousRoundWeights, localWeights, this.#privacy);
45
+ const networkWeights = await this.#client.onRoundEndCommunication(localWeights, round);
46
+ this.model.weights = previousRoundWeights = networkWeights;
47
+ }
48
+ }
49
+ async *#runRound(dataset, valDataset) {
50
+ let epochsLogs = List();
51
+ for (let epoch = 0; epoch < this.#roundDuration; epoch++) {
52
+ const [gen, epochLogs] = async_iterator.split(this.model.train(dataset, valDataset));
53
+ yield gen;
54
+ epochsLogs = epochsLogs.push(await epochLogs);
55
+ }
56
+ return {
57
+ epochs: epochsLogs,
58
+ participants: this.#client.nbOfParticipants,
59
+ };
60
+ }
61
+ }
62
+ async function applyPrivacy(previous, current, options) {
63
+ let ret = current;
64
+ if (options.clippingRadius !== undefined) {
65
+ const previousRoundWeights = previous ?? current.map((w) => tf.zerosLike(w));
66
+ const weightsProgress = current.sub(previousRoundWeights);
67
+ ret = previousRoundWeights.add(await privacy.clipNorm(weightsProgress, options.clippingRadius));
68
+ }
69
+ if (options.noiseScale !== undefined)
70
+ ret = privacy.addNoise(ret, options.noiseScale);
71
+ return ret;
72
+ }
package/dist/types.d.ts CHANGED
@@ -2,7 +2,5 @@ import type { Map } from 'immutable';
2
2
  import type { WeightsContainer } from './index.js';
3
3
  import type { NodeID } from './client/index.js';
4
4
  export type Path = string;
5
- export type MetadataKey = string;
6
- export type MetadataValue = string;
7
5
  export type Features = number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][];
8
6
  export type Contributions = Map<NodeID, WeightsContainer>;
@@ -51,11 +51,6 @@ export declare class WeightsContainer {
51
51
  * @returns The tensor located at the index
52
52
  */
53
53
  get(index: number): tf.Tensor | undefined;
54
- /**
55
- * Computes the weights container's Frobenius norm
56
- * @returns The Frobenius norm
57
- */
58
- frobeniusNorm(): number;
59
54
  concat(other: WeightsContainer): WeightsContainer;
60
55
  equals(other: WeightsContainer, margin?: number): boolean;
61
56
  /**
@@ -70,13 +70,6 @@ export class WeightsContainer {
70
70
  get(index) {
71
71
  return this._weights.get(index);
72
72
  }
73
- /**
74
- * Computes the weights container's Frobenius norm
75
- * @returns The Frobenius norm
76
- */
77
- frobeniusNorm() {
78
- return Math.sqrt(this.map((w) => w.square().sum()).reduce((a, b) => a.add(b)).dataSync()[0]);
79
- }
80
73
  concat(other) {
81
74
  return WeightsContainer.of(...this.weights, ...other.weights);
82
75
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "2.2.1",
3
+ "version": "3.0.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",
@@ -1,15 +0,0 @@
1
- import type { AggregatorBase } from './aggregator/index.js';
2
- export declare class AsyncInformant<T> {
3
- private readonly aggregator;
4
- private _round;
5
- private _currentNumberOfParticipants;
6
- private _totalNumberOfParticipants;
7
- private _averageNumberOfParticipants;
8
- constructor(aggregator: AggregatorBase<T>);
9
- update(): void;
10
- get round(): number;
11
- get currentNumberOfParticipants(): number;
12
- get totalNumberOfParticipants(): number;
13
- get averageNumberOfParticipants(): number;
14
- getAllStatistics(): Record<'round' | 'currentNumberOfParticipants' | 'totalNumberOfParticipants' | 'averageNumberOfParticipants', number>;
15
- }
@@ -1,42 +0,0 @@
1
- export class AsyncInformant {
2
- aggregator;
3
- _round = 0;
4
- _currentNumberOfParticipants = 0;
5
- _totalNumberOfParticipants = 0;
6
- _averageNumberOfParticipants = 0;
7
- constructor(aggregator) {
8
- this.aggregator = aggregator;
9
- }
10
- update() {
11
- if (this.round === 0 || this.round < this.aggregator.round) {
12
- this._round = this.aggregator.round;
13
- this._currentNumberOfParticipants = this.aggregator.size;
14
- this._averageNumberOfParticipants = this.totalNumberOfParticipants / this.round;
15
- this._totalNumberOfParticipants += this.currentNumberOfParticipants;
16
- }
17
- else {
18
- this._round = this.aggregator.round;
19
- }
20
- }
21
- // Getter functions
22
- get round() {
23
- return this._round;
24
- }
25
- get currentNumberOfParticipants() {
26
- return this._currentNumberOfParticipants;
27
- }
28
- get totalNumberOfParticipants() {
29
- return this._totalNumberOfParticipants;
30
- }
31
- get averageNumberOfParticipants() {
32
- return this._averageNumberOfParticipants;
33
- }
34
- getAllStatistics() {
35
- return {
36
- round: this.round,
37
- currentNumberOfParticipants: this.currentNumberOfParticipants,
38
- totalNumberOfParticipants: this.totalNumberOfParticipants,
39
- averageNumberOfParticipants: this.averageNumberOfParticipants
40
- };
41
- }
42
- }
@@ -1,20 +0,0 @@
1
- import type { Model, Memory, Task, client as clients } from "../../index.js";
2
- import { Trainer } from "./trainer.js";
3
- /**
4
- * Class whose role is to train a model in a distributed way with a given dataset.
5
- */
6
- export declare class DistributedTrainer extends Trainer {
7
- private readonly task;
8
- private readonly memory;
9
- private readonly client;
10
- private readonly aggregator;
11
- /**
12
- * DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights.
13
- */
14
- constructor(task: Task, memory: Memory, model: Model, client: clients.Client);
15
- onRoundBegin(round: number): Promise<void>;
16
- /**
17
- * Callback called every time a round is over
18
- */
19
- onRoundEnd(round: number): Promise<void>;
20
- }
@@ -1,41 +0,0 @@
1
- import { Trainer } from "./trainer.js";
2
- /**
3
- * Class whose role is to train a model in a distributed way with a given dataset.
4
- */
5
- export class DistributedTrainer extends Trainer {
6
- task;
7
- memory;
8
- client;
9
- aggregator;
10
- /**
11
- * DistributedTrainer constructor, accepts same arguments as Trainer and in additional also a client who takes care of communicating weights.
12
- */
13
- constructor(task, memory, model, client) {
14
- super(task, model);
15
- this.task = task;
16
- this.memory = memory;
17
- this.client = client;
18
- this.aggregator = this.client.aggregator;
19
- this.aggregator.setModel(model);
20
- }
21
- async onRoundBegin(round) {
22
- await this.client.onRoundBeginCommunication(this.model.weights, round);
23
- }
24
- /**
25
- * Callback called every time a round is over
26
- */
27
- async onRoundEnd(round) {
28
- await this.client.onRoundEndCommunication(this.model.weights, round);
29
- if (this.aggregator.model !== undefined) {
30
- // The aggregator's own aggregation is async. The trainer updates its model to match the aggregator's
31
- // after it has completed a round of training.
32
- this.model.weights = this.aggregator.model.weights;
33
- }
34
- await this.memory.updateWorkingModel({
35
- type: 'working',
36
- taskID: this.task.id,
37
- name: this.task.trainingInformation.modelID,
38
- tensorBackend: this.task.trainingInformation.tensorBackend
39
- }, this.model);
40
- }
41
- }
@@ -1,12 +0,0 @@
1
- import type { Memory, Model, Task } from "../../index.js";
2
- import { Trainer } from "./trainer.js";
3
- /** Class whose role is to locally (alone) train a model on a given dataset,
4
- * without any collaborators.
5
- */
6
- export declare class LocalTrainer extends Trainer {
7
- private readonly task;
8
- private readonly memory;
9
- constructor(task: Task, memory: Memory, model: Model);
10
- onRoundBegin(): Promise<void>;
11
- onRoundEnd(): Promise<void>;
12
- }
@@ -1,24 +0,0 @@
1
- import { Trainer } from "./trainer.js";
2
- /** Class whose role is to locally (alone) train a model on a given dataset,
3
- * without any collaborators.
4
- */
5
- export class LocalTrainer extends Trainer {
6
- task;
7
- memory;
8
- constructor(task, memory, model) {
9
- super(task, model);
10
- this.task = task;
11
- this.memory = memory;
12
- }
13
- async onRoundBegin() {
14
- return await Promise.resolve();
15
- }
16
- async onRoundEnd() {
17
- await this.memory.updateWorkingModel({
18
- type: 'working',
19
- taskID: this.task.id,
20
- name: this.task.trainingInformation.modelID,
21
- tensorBackend: this.task.trainingInformation.tensorBackend
22
- }, this.model);
23
- }
24
- }
@@ -1,32 +0,0 @@
1
- import type tf from "@tensorflow/tfjs";
2
- import { List } from "immutable";
3
- import type { Model, Task } from "../../index.js";
4
- import { BatchLogs, EpochLogs } from "../../models/index.js";
5
- export interface RoundLogs {
6
- epochs: List<EpochLogs>;
7
- }
8
- /** Abstract class whose role is to train a model with a given dataset. This can be either done
9
- * locally (alone) or in a distributed way with collaborators.
10
- *
11
- * 1. Call `fitModel(dataset)` to start training.
12
- * 2. which will then call onRoundEnd once the round has ended.
13
- *
14
- * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators.
15
- */
16
- export declare abstract class Trainer {
17
- #private;
18
- readonly model: Model;
19
- private training?;
20
- constructor(task: Task, model: Model);
21
- protected abstract onRoundBegin(round: number): Promise<void>;
22
- protected abstract onRoundEnd(round: number): Promise<void>;
23
- /**
24
- * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
25
- */
26
- stopTraining(): Promise<void>;
27
- /**
28
- * Start training the model with the given dataset
29
- * @param dataset
30
- */
31
- fitModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
32
- }
@@ -1,61 +0,0 @@
1
- import { List } from "immutable";
2
- import * as async_iterator from "../../utils/async_iterator.js";
3
- /** Abstract class whose role is to train a model with a given dataset. This can be either done
4
- * locally (alone) or in a distributed way with collaborators.
5
- *
6
- * 1. Call `fitModel(dataset)` to start training.
7
- * 2. which will then call onRoundEnd once the round has ended.
8
- *
9
- * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators.
10
- */
11
- export class Trainer {
12
- model;
13
- #roundDuration;
14
- #epochs;
15
- training;
16
- constructor(task, model) {
17
- this.model = model;
18
- this.#roundDuration = task.trainingInformation.roundDuration;
19
- this.#epochs = task.trainingInformation.epochs;
20
- if (!Number.isInteger(this.#epochs / this.#roundDuration))
21
- throw new Error(`round duration doesn't divide epochs`);
22
- }
23
- /**
24
- * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
25
- */
26
- async stopTraining() {
27
- await this.training?.return();
28
- }
29
- /**
30
- * Start training the model with the given dataset
31
- * @param dataset
32
- */
33
- async *fitModel(dataset, valDataset) {
34
- if (this.training !== undefined)
35
- throw new Error("training already running, cancel it before launching a new one");
36
- try {
37
- this.training = this.#runRounds(dataset, valDataset);
38
- yield* this.training;
39
- }
40
- finally {
41
- this.training = undefined;
42
- }
43
- }
44
- async *#runRounds(dataset, valDataset) {
45
- const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
46
- for (let round = 0; round < totalRound; round++) {
47
- await this.onRoundBegin(round);
48
- yield this.#runRound(dataset, valDataset);
49
- await this.onRoundEnd(round);
50
- }
51
- }
52
- async *#runRound(dataset, valDataset) {
53
- let epochsLogs = List();
54
- for (let epoch = 0; epoch < this.#roundDuration; epoch++) {
55
- const [gen, epochLogs] = async_iterator.split(this.model.train(dataset, valDataset));
56
- yield gen;
57
- epochsLogs = epochsLogs.push(await epochLogs);
58
- }
59
- return { epochs: epochsLogs };
60
- }
61
- }
@@ -1,23 +0,0 @@
1
- import type { client as clients, Task, Memory } from '../../index.js';
2
- import type { Trainer } from './trainer.js';
3
- /**
4
- * A class that helps build the Trainer and auxiliary classes.
5
- */
6
- export declare class TrainerBuilder {
7
- private readonly memory;
8
- private readonly task;
9
- constructor(memory: Memory, task: Task);
10
- /**
11
- * Builds a trainer object.
12
- *
13
- * @param client client to share weights with (either distributed or federated)
14
- * @param distributed whether to build a distributed or local trainer
15
- * @returns
16
- */
17
- build(client: clients.Client, distributed?: boolean): Promise<Trainer>;
18
- /**
19
- * If a model exists in memory, load it, otherwise load model from server
20
- * @returns
21
- */
22
- private getModel;
23
- }