@epfml/discojs 2.1.2-p20240627125649.0 → 2.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,3 @@
1
- /// <reference types="node" resolution-mode="require"/>
2
1
  import type { NodeID } from '../types.js';
3
2
  export type SignalData = {
4
3
  type: 'answer' | 'offer' | 'pranswer' | 'rollback';
@@ -20,7 +20,7 @@ export const wikitext = {
20
20
  modelID: 'llm-raw-model',
21
21
  preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
22
22
  scheme: 'federated',
23
- epochs: 5,
23
+ epochs: 6,
24
24
  // Unused by wikitext because data already comes split
25
25
  // But if set to 0 then the webapp doesn't display the validation metrics
26
26
  validationSplit: 0.1,
package/dist/index.d.ts CHANGED
@@ -10,8 +10,9 @@ export { Logger, ConsoleLogger } from './logging/index.js';
10
10
  export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
11
11
  export { Disco, RoundLogs } from './training/index.js';
12
12
  export { Validator } from './validation/index.js';
13
- export { Model, EpochLogs } from './models/index.js';
13
+ export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js';
14
14
  export * as models from './models/index.js';
15
15
  export * from './task/index.js';
16
16
  export * as defaultTasks from './default_tasks/index.js';
17
17
  export * from './types.js';
18
+ export * as async_iterator from "./utils/async_iterator.js";
package/dist/index.js CHANGED
@@ -10,8 +10,9 @@ export { ConsoleLogger } from './logging/index.js';
10
10
  export { Memory, Empty as EmptyMemory } from './memory/index.js';
11
11
  export { Disco } from './training/index.js';
12
12
  export { Validator } from './validation/index.js';
13
- export { Model } from './models/index.js';
13
+ export { Model, EpochLogs } from './models/index.js';
14
14
  export * as models from './models/index.js';
15
15
  export * from './task/index.js';
16
16
  export * as defaultTasks from './default_tasks/index.js';
17
17
  export * from './types.js';
18
+ export * as async_iterator from "./utils/async_iterator.js";
@@ -3,5 +3,5 @@ interface DataPoint extends tf.TensorContainerObject {
3
3
  xs: tf.Tensor2D;
4
4
  ys: tf.Tensor3D;
5
5
  }
6
- export default function evaluate(model: tf.LayersModel, dataset: tf.data.Dataset<DataPoint>, maxEvalBatches: number): Promise<Record<'acc' | 'val_acc' | 'val_loss' | 'val_perplexity', number>>;
6
+ export default function evaluate(model: tf.LayersModel, dataset: tf.data.Dataset<DataPoint>, maxEvalBatches: number): Promise<Record<'val_acc' | 'val_loss' | 'val_perplexity', number>>;
7
7
  export {};
@@ -38,7 +38,6 @@ export default async function evaluate(model, dataset, maxEvalBatches) {
38
38
  return {
39
39
  val_loss: loss,
40
40
  val_perplexity: Math.exp(loss),
41
- acc: acc[0] / acc[1],
42
41
  val_acc: acc[0] / acc[1]
43
42
  };
44
43
  }
@@ -5,14 +5,15 @@ import * as tf from '@tensorflow/tfjs';
5
5
  import { PreTrainedTokenizer } from '@xenova/transformers';
6
6
  import { WeightsContainer } from '../../index.js';
7
7
  import type { Dataset } from '../../dataset/index.js';
8
- import { Model } from '../model.js';
9
- import type { EpochLogs, Prediction, Sample } from '../model.js';
10
- import type { GPTConfig } from './config.js';
8
+ import { BatchLogs, Model, EpochLogs } from "../index.js";
9
+ import type { Prediction, Sample } from '../model.js';
10
+ import { type GPTConfig } from './config.js';
11
11
  export type GPTSerialization = {
12
12
  weights: WeightsContainer;
13
13
  config?: GPTConfig;
14
14
  };
15
15
  export declare class GPT extends Model {
16
+ #private;
16
17
  private readonly model;
17
18
  constructor(partialConfig?: GPTConfig, layersModel?: tf.LayersModel);
18
19
  /**
@@ -24,7 +25,7 @@ export declare class GPT extends Model {
24
25
  * @param epochs the number of passes of the training dataset
25
26
  * @param tracker
26
27
  */
27
- train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs, void>;
28
+ train(trainingData: Dataset, validationData?: Dataset): AsyncGenerator<BatchLogs, EpochLogs>;
28
29
  predict(input: Sample): Promise<Prediction>;
29
30
  generate(input: string, tokenizer: PreTrainedTokenizer, newTokens?: number): Promise<string>;
30
31
  get config(): Required<GPTConfig>;
@@ -1,14 +1,20 @@
1
1
  /**
2
2
  * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
3
3
  **/
4
+ import * as tf from '@tensorflow/tfjs';
4
5
  import { WeightsContainer } from '../../index.js';
5
- import { Model } from '../model.js';
6
+ import { Model, EpochLogs } from "../index.js";
6
7
  import { GPTForCausalLM } from './model.js';
8
+ import { DEFAULT_CONFIG } from './config.js';
9
+ import evaluate from './evaluate.js';
10
+ import { List } from 'immutable';
7
11
  export class GPT extends Model {
8
12
  model;
13
+ #maxBatchCount;
9
14
  constructor(partialConfig, layersModel) {
10
15
  super();
11
16
  this.model = new GPTForCausalLM(partialConfig, layersModel);
17
+ this.#maxBatchCount = partialConfig?.maxIter ?? DEFAULT_CONFIG.maxIter;
12
18
  }
13
19
  /**
14
20
  * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
@@ -19,45 +25,65 @@ export class GPT extends Model {
19
25
  * @param epochs the number of passes of the training dataset
20
26
  * @param tracker
21
27
  */
22
- async *train(trainingData, validationData, epochs = 1) {
28
+ async *train(trainingData, validationData) {
23
29
  this.model.compile();
30
+ const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
31
+ let batchesLogs = List();
32
+ for (let batchNumber = 0; batchNumber < this.#maxBatchCount; batchNumber++) {
33
+ const iteration = await batches.next();
34
+ if (iteration.done)
35
+ break;
36
+ const batch = iteration.value;
37
+ const batchLogs = await this.#runBatch(batch);
38
+ tf.dispose(batch);
39
+ yield batchLogs;
40
+ batchesLogs = batchesLogs.push(batchLogs);
41
+ }
42
+ const validation = validationData && (await this.#evaluate(validationData));
43
+ return new EpochLogs(batchesLogs, validation);
44
+ }
45
+ async #runBatch(batch) {
24
46
  let logs;
25
- const trainingArgs = {
26
- epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
27
- validationData,
28
- callbacks: { onEpochEnd: (_, cur) => { logs = cur; } },
47
+ await this.model.fitDataset(tf.data.array([batch]), {
48
+ epochs: 1,
49
+ verbose: 0, // don't pollute
50
+ callbacks: {
51
+ onEpochEnd: (_, cur) => {
52
+ logs = cur;
53
+ },
54
+ },
55
+ });
56
+ if (logs === undefined)
57
+ throw new Error("batch didn't gave any logs");
58
+ const { loss, acc: accuracy } = logs;
59
+ if (loss === undefined || isNaN(loss))
60
+ throw new Error("training loss is undefined or NaN");
61
+ return {
62
+ accuracy,
63
+ loss,
64
+ memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
29
65
  };
30
- for (let epoch = 0; epoch < epochs; epoch++) {
31
- await this.model.fitDataset(trainingData, trainingArgs);
32
- if (logs === undefined) {
33
- throw new Error("Epoch didn't gave any logs");
34
- }
35
- const { loss, val_acc, val_loss, peakMemory } = logs;
36
- if (loss === undefined || isNaN(loss)) {
37
- throw new Error("Training loss is undefined or nan");
38
- }
39
- const structuredLogs = {
40
- epoch,
41
- peakMemory,
42
- training: {
43
- loss: logs.loss,
44
- accuracy: logs.acc
45
- }
46
- };
47
- if (validationData !== undefined) {
48
- if (val_loss === undefined || isNaN(val_loss) ||
49
- val_acc === undefined || isNaN(val_acc)) {
50
- throw new Error("Validation accuracy or loss is undefined or nan");
51
- }
52
- structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss };
66
+ }
67
+ async #evaluate(dataset) {
68
+ const evaluation = await evaluate(this.model, dataset.map((t) => {
69
+ switch (t) {
70
+ case null:
71
+ case undefined:
72
+ throw new Error("nullish value in dataset");
73
+ default:
74
+ // TODO unsafe cast
75
+ return t;
53
76
  }
54
- yield structuredLogs;
55
- }
77
+ }), this.config.maxEvalBatches);
78
+ return {
79
+ accuracy: evaluation.val_acc,
80
+ loss: evaluation.val_loss,
81
+ };
56
82
  }
57
83
  predict(input) {
58
84
  const ret = this.model.predict(input);
59
85
  if (Array.isArray(ret)) {
60
- throw new Error('prediction yield many Tensors but should have only returned one');
86
+ throw new Error("prediction yield many Tensors but should have only returned one");
61
87
  }
62
88
  return Promise.resolve(ret);
63
89
  }
@@ -37,29 +37,44 @@ class GPTModel extends tf.LayersModel {
37
37
  const evalDataset = trainingArgs.validationData;
38
38
  await callbacks.onTrainBegin?.();
39
39
  for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) {
40
+ let accuracyFraction = [0, 0];
40
41
  let averageLoss = 0;
41
- let peakMemory = 0;
42
42
  let iteration = 1;
43
43
  const iterator = await dataset.iterator();
44
- let preprocessingTime = performance.now();
45
44
  let next = await iterator.next();
46
- preprocessingTime = performance.now() - preprocessingTime;
47
45
  while (next.done !== true && iteration <= this.config.maxIter) {
48
46
  let weightUpdateTime = performance.now();
49
47
  await callbacks.onEpochBegin?.(epoch);
50
48
  const { xs, ys } = next.value;
51
- const lossFn = () => {
49
+ let preprocessingTime = performance.now();
50
+ await Promise.all([xs.data(), ys.data()]);
51
+ preprocessingTime = performance.now() - preprocessingTime;
52
+ // TODO include as a tensor inside the model
53
+ const accTensor = tf.tidy(() => {
52
54
  const logits = this.apply(xs);
53
- if (Array.isArray(logits)) {
55
+ if (Array.isArray(logits))
54
56
  throw new Error('model outputs too many tensor');
55
- }
56
- if (logits instanceof tf.SymbolicTensor) {
57
+ if (logits instanceof tf.SymbolicTensor)
57
58
  throw new Error('model outputs symbolic tensor');
58
- }
59
- return tf.losses.softmaxCrossEntropy(ys, logits);
60
- };
59
+ return tf.metrics.categoricalAccuracy(ys, logits);
60
+ });
61
+ const accSize = accTensor.shape.reduce((l, r) => l * r, 1);
62
+ const accSumTensor = accTensor.sum();
63
+ const accSum = await accSumTensor.array();
64
+ tf.dispose(accSumTensor);
65
+ if (typeof accSum !== 'number')
66
+ throw new Error('got multiple accuracy sum');
67
+ accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize];
68
+ tf.dispose([accTensor]);
61
69
  const lossTensor = tf.tidy(() => {
62
- const { grads, value: lossTensor } = this.optimizer.computeGradients(lossFn);
70
+ const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
71
+ const logits = this.apply(xs);
72
+ if (Array.isArray(logits))
73
+ throw new Error('model outputs too many tensor');
74
+ if (logits instanceof tf.SymbolicTensor)
75
+ throw new Error('model outputs symbolic tensor');
76
+ return tf.losses.softmaxCrossEntropy(ys, logits);
77
+ });
63
78
  const gradsClipped = clipByGlobalNormObj(grads, 1);
64
79
  this.optimizer.applyGradients(gradsClipped);
65
80
  return lossTensor;
@@ -75,9 +90,6 @@ class GPTModel extends tf.LayersModel {
75
90
  console.log(iterationLogs);
76
91
  }
77
92
  const memory = tf.memory().numBytes / 1024 / 1024 / 1024;
78
- if (memory > peakMemory) {
79
- peakMemory = memory;
80
- }
81
93
  console.log(`Epoch: ${epoch}`, `\tStep: ${iteration} / ${this.config.maxIter}`, `\tLoss: ${loss.toFixed(3)}`, `\tMemory: ${memory.toFixed(2)} GB`, `\tNumber of tensors allocated: ${tf.memory().numTensors}`, `\tPreprocessing time: ${preprocessingTime.toFixed(0)} ms`, `\tWeight update time: ${weightUpdateTime.toFixed(0)} ms`);
82
94
  iteration++;
83
95
  next = await iterator.next();
@@ -89,7 +101,7 @@ class GPTModel extends tf.LayersModel {
89
101
  }
90
102
  let logs = {
91
103
  'loss': averageLoss / iteration,
92
- 'peakMemory': peakMemory
104
+ 'acc': accuracyFraction[0] / accuracyFraction[1],
93
105
  };
94
106
  if (evalDataset !== undefined) {
95
107
  logs = { ...logs, ...await evaluate(this, evalDataset, this.config.maxEvalBatches) };
@@ -1,4 +1,5 @@
1
- export { EpochLogs, Model } from './model.js';
1
+ export { Model } from './model.js';
2
+ export { BatchLogs, EpochLogs, ValidationMetrics } from "./logs.js";
2
3
  export { GPT } from './gpt/index.js';
3
4
  export { GPTConfig } from './gpt/config.js';
4
5
  export { TFJS } from './tfjs.js';
@@ -1,4 +1,5 @@
1
1
  export { Model } from './model.js';
2
+ export { EpochLogs } from "./logs.js";
2
3
  export { GPT } from './gpt/index.js';
3
4
  export { TFJS } from './tfjs.js';
4
5
  export { getTaskTokenizer } from './tokenizer.js';
@@ -0,0 +1,17 @@
1
+ import { List } from "immutable";
2
+ export interface ValidationMetrics {
3
+ accuracy: number;
4
+ loss: number;
5
+ }
6
+ export interface BatchLogs {
7
+ accuracy: number;
8
+ loss: number;
9
+ memoryUsage: number;
10
+ }
11
+ export declare class EpochLogs {
12
+ readonly validation?: ValidationMetrics | undefined;
13
+ readonly batches: List<BatchLogs>;
14
+ constructor(batches: Iterable<BatchLogs>, validation?: ValidationMetrics | undefined);
15
+ get training(): Record<"accuracy" | "loss", number>;
16
+ get peakMemory(): number;
17
+ }
@@ -0,0 +1,22 @@
1
+ import { List } from "immutable";
2
+ export class EpochLogs {
3
+ validation;
4
+ batches;
5
+ constructor(batches, validation) {
6
+ this.validation = validation;
7
+ this.batches = List(batches);
8
+ }
9
+ get training() {
10
+ const sum = this.batches.reduce((acc, batch) => ({
11
+ accuracy: acc.accuracy + batch.accuracy,
12
+ loss: acc.loss + batch.loss,
13
+ }), { loss: 0, accuracy: 0 });
14
+ return {
15
+ accuracy: sum.accuracy / this.batches.size,
16
+ loss: sum.loss / this.batches.size,
17
+ };
18
+ }
19
+ get peakMemory() {
20
+ return this.batches.map((batch) => batch.memoryUsage).max() ?? 0;
21
+ }
22
+ }
@@ -1,19 +1,7 @@
1
- /// <reference types="node" resolution-mode="require"/>
2
1
  import type tf from "@tensorflow/tfjs";
3
2
  import type { WeightsContainer } from "../index.js";
4
3
  import type { Dataset } from "../dataset/index.js";
5
- export interface EpochLogs {
6
- epoch: number;
7
- training: {
8
- loss: number;
9
- accuracy?: number;
10
- };
11
- validation?: {
12
- loss: number;
13
- accuracy: number;
14
- };
15
- peakMemory: number;
16
- }
4
+ import type { BatchLogs, EpochLogs } from "./logs.js";
17
5
  export type Prediction = tf.Tensor;
18
6
  export type Sample = tf.Tensor;
19
7
  /**
@@ -35,7 +23,7 @@ export declare abstract class Model implements Disposable {
35
23
  * @param tracker watch the various steps
36
24
  * @yields on every epoch, training can be stop by `return`ing it
37
25
  */
38
- abstract train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs, void>;
26
+ abstract train(trainingData: Dataset, validationData?: Dataset): AsyncGenerator<BatchLogs, EpochLogs>;
39
27
  /** Predict likely values */
40
28
  abstract predict(input: Sample): Promise<Prediction>;
41
29
  /**
@@ -1,16 +1,18 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
2
  import { WeightsContainer } from '../index.js';
3
- import { Model } from './index.js';
4
- import type { EpochLogs, Prediction, Sample } from './model.js';
5
3
  import type { Dataset } from '../dataset/index.js';
4
+ import { BatchLogs, EpochLogs } from './index.js';
5
+ import { Model } from './index.js';
6
+ import type { Prediction, Sample } from './model.js';
6
7
  /** TensorFlow JavaScript model with standard training */
7
8
  export declare class TFJS extends Model {
9
+ #private;
8
10
  private readonly model;
9
11
  /** Wrap the given trainable model */
10
12
  constructor(model: tf.LayersModel);
11
13
  get weights(): WeightsContainer;
12
14
  set weights(ws: WeightsContainer);
13
- train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs>;
15
+ train(trainingData: Dataset, validationData?: Dataset): AsyncGenerator<BatchLogs, EpochLogs>;
14
16
  predict(input: Sample): Promise<Prediction>;
15
17
  static deserialize(raw: tf.io.ModelArtifacts): Promise<Model>;
16
18
  serialize(): Promise<tf.io.ModelArtifacts>;
@@ -1,5 +1,7 @@
1
+ import { List, Map } from 'immutable';
1
2
  import * as tf from '@tensorflow/tfjs';
2
3
  import { WeightsContainer } from '../index.js';
4
+ import { EpochLogs } from './index.js';
3
5
  import { Model } from './index.js';
4
6
  /** TensorFlow JavaScript model with standard training */
5
7
  export class TFJS extends Model {
@@ -18,50 +20,71 @@ export class TFJS extends Model {
18
20
  set weights(ws) {
19
21
  this.model.setWeights(ws.weights);
20
22
  }
21
- async *train(trainingData, validationData, epochs = 1) {
22
- for (let epoch = 0; epoch < epochs; epoch++) {
23
- let logs;
24
- let peakMemory = 0;
25
- await this.model.fitDataset(trainingData, {
26
- epochs: 1,
27
- validationData,
28
- callbacks: {
29
- onBatchEnd: (_) => {
30
- const currentMemory = tf.memory().numBytes / 1024 / 1024 / 1024; // GB
31
- if (currentMemory > peakMemory) {
32
- peakMemory = currentMemory;
33
- }
34
- },
35
- onEpochEnd: (_, cur) => { logs = cur; }
36
- },
37
- });
38
- if (logs === undefined) {
39
- throw new Error("Epoch didn't gave any logs");
40
- }
41
- const { loss, acc, val_acc, val_loss } = logs;
42
- if (loss === undefined || isNaN(loss) || acc === undefined || isNaN(acc)) {
43
- throw new Error("Training loss is undefined or nan");
44
- }
45
- const structuredLogs = {
46
- epoch,
47
- peakMemory,
48
- training: {
49
- loss: logs.loss,
50
- accuracy: logs.acc,
51
- }
23
+ async *train(trainingData, validationData) {
24
+ const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
25
+ let batchesLogs = List();
26
+ for (let batchNumber = 0; true; batchNumber++) {
27
+ const iteration = await batches.next();
28
+ if (iteration.done)
29
+ break;
30
+ const batch = iteration.value;
31
+ const batchLogs = {
32
+ batch: batchNumber,
33
+ ...(await this.#runBatch(batch)),
52
34
  };
53
- if (validationData !== undefined) {
54
- if (val_loss === undefined || isNaN(val_loss) ||
55
- val_acc === undefined || isNaN(val_acc)) {
56
- throw new Error("Invalid validation logs");
57
- }
58
- structuredLogs.validation = {
59
- accuracy: logs.val_acc,
60
- loss: logs.val_loss
61
- };
62
- }
63
- yield structuredLogs;
35
+ tf.dispose(batch);
36
+ yield batchLogs;
37
+ batchesLogs = batchesLogs.push(batchLogs);
64
38
  }
39
+ const validation = validationData && (await this.#evaluate(validationData));
40
+ return new EpochLogs(batchesLogs, validation);
41
+ }
42
+ async #runBatch(batch) {
43
+ let logs;
44
+ await this.model.fitDataset(tf.data.array([batch]), {
45
+ epochs: 1,
46
+ verbose: 0, // don't pollute
47
+ callbacks: {
48
+ onEpochEnd: (_, cur) => {
49
+ logs = cur;
50
+ },
51
+ },
52
+ });
53
+ if (logs === undefined)
54
+ throw new Error("batch didn't gave any logs");
55
+ const { loss, acc: accuracy } = logs;
56
+ if (loss === undefined || isNaN(loss))
57
+ throw new Error("training loss is undefined or NaN");
58
+ return {
59
+ accuracy,
60
+ loss,
61
+ memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
62
+ };
63
+ }
64
+ async #evaluate(dataset) {
65
+ const evaluation = await this.model.evaluateDataset(dataset.map((t) => {
66
+ switch (t) {
67
+ case null:
68
+ case undefined:
69
+ throw new Error("nullish value in dataset");
70
+ default:
71
+ return t;
72
+ }
73
+ }));
74
+ const metricToValue = Map(List(this.model.metricsNames).zip(Array.isArray(evaluation)
75
+ ? List(await Promise.all(evaluation.map((t) => t.data())))
76
+ : List.of(await evaluation.data()))).map((values) => {
77
+ if (values.length !== 1)
78
+ throw new Error("more than one metric value");
79
+ return values[0];
80
+ });
81
+ const [accuracy, loss] = [
82
+ metricToValue.get("acc"),
83
+ metricToValue.get("loss"),
84
+ ];
85
+ if (accuracy === undefined || loss === undefined)
86
+ throw new Error("some needed metrics are missing");
87
+ return { accuracy, loss };
65
88
  }
66
89
  predict(input) {
67
90
  const ret = this.model.predict(input);
@@ -1,4 +1,4 @@
1
- import type { data, Logger, Memory, Task, TrainingInformation } from '../index.js';
1
+ import { BatchLogs, data, EpochLogs, Logger, Memory, Task, TrainingInformation } from '../index.js';
2
2
  import { client as clients } from '../index.js';
3
3
  import type { Aggregator } from '../aggregator/index.js';
4
4
  import type { RoundLogs } from './trainer/trainer.js';
@@ -22,13 +22,25 @@ export declare class Disco {
22
22
  private readonly client;
23
23
  private readonly trainer;
24
24
  constructor(task: Task, options: DiscoOptions);
25
- /**
26
- * Starts a training instance for the Disco object's task on the provided data tuple.
27
- * @param dataTuple The data tuple
28
- */
29
- fit(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & {
25
+ /** Train on dataset, yielding logs of every round. */
26
+ trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & {
30
27
  participants: number;
31
28
  }>;
29
+ /** Train on dataset, yielding logs of every epoch. */
30
+ trainByEpoch(dataTuple: data.DataSplit): AsyncGenerator<EpochLogs>;
31
+ /** Train on dataset, yielding logs of every batch. */
32
+ trainByBatch(dataTuple: data.DataSplit): AsyncGenerator<BatchLogs>;
33
+ /** Run whole train on dataset. */
34
+ trainFully(dataTuple: data.DataSplit): Promise<void>;
35
+ /**
36
+ * Train on dataset, yield the nested steps.
37
+ *
38
+ * Don't forget to await the yielded generator otherwise nothing will progress.
39
+ * If you don't care about the whole process, use one of the other train methods.
40
+ **/
41
+ train(dataTuple: data.DataSplit): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs & {
42
+ participants: number;
43
+ }>>;
32
44
  /**
33
45
  * Stops the ongoing training instance without disconnecting the client.
34
46
  */
@@ -1,5 +1,8 @@
1
+ import { List } from 'immutable';
2
+ import { async_iterator } from '../index.js';
1
3
  import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js';
2
4
  import { MeanAggregator } from '../aggregator/mean.js';
5
+ import { enumerate, split } from '../utils/async_iterator.js';
3
6
  import { TrainerBuilder } from './trainer/trainer_builder.js';
4
7
  /**
5
8
  * Top-level class handling distributed training from a client's perspective. It is meant to be
@@ -58,35 +61,79 @@ export class Disco {
58
61
  const trainerBuilder = new TrainerBuilder(this.memory, this.task);
59
62
  this.trainer = trainerBuilder.build(this.client, options.scheme !== 'local');
60
63
  }
64
+ /** Train on dataset, yielding logs of every round. */
65
+ async *trainByRound(dataTuple) {
66
+ for await (const round of this.train(dataTuple)) {
67
+ const [roundGen, roundLogs] = async_iterator.split(round);
68
+ for await (const epoch of roundGen)
69
+ for await (const _ of epoch)
70
+ ;
71
+ yield await roundLogs;
72
+ }
73
+ }
74
+ /** Train on dataset, yielding logs of every epoch. */
75
+ async *trainByEpoch(dataTuple) {
76
+ for await (const round of this.train(dataTuple)) {
77
+ for await (const epoch of round) {
78
+ const [epochGen, epochLogs] = async_iterator.split(epoch);
79
+ for await (const _ of epochGen)
80
+ ;
81
+ yield await epochLogs;
82
+ }
83
+ }
84
+ }
85
+ /** Train on dataset, yielding logs of every batch. */
86
+ async *trainByBatch(dataTuple) {
87
+ for await (const round of this.train(dataTuple))
88
+ for await (const epoch of round)
89
+ yield* epoch;
90
+ }
91
+ /** Run whole train on dataset. */
92
+ async trainFully(dataTuple) {
93
+ for await (const round of this.train(dataTuple))
94
+ for await (const epoch of round)
95
+ for await (const _ of epoch)
96
+ ;
97
+ }
61
98
  /**
62
- * Starts a training instance for the Disco object's task on the provided data tuple.
63
- * @param dataTuple The data tuple
64
- */
99
+ * Train on dataset, yield the nested steps.
100
+ *
101
+ * Don't forget to await the yielded generator otherwise nothing will progress.
102
+ * If you don't care about the whole process, use one of the other train methods.
103
+ **/
65
104
  // TODO RoundLogs should contain number of participants but Trainer doesn't need client
66
- async *fit(dataTuple) {
105
+ async *train(dataTuple) {
67
106
  this.logger.success("Training started.");
68
107
  const trainData = dataTuple.train.preprocess().batch();
69
108
  const validationData = dataTuple.validation?.preprocess().batch() ?? trainData;
70
109
  await this.client.connect();
71
110
  const trainer = await this.trainer;
72
- for await (const roundLogs of trainer.fitModel(trainData.dataset, validationData.dataset)) {
73
- let msg = `Round: ${roundLogs.round}\n`;
74
- for (const epochLogs of roundLogs.epochs.values()) {
75
- msg += ` Epoch: ${epochLogs.epoch}\n`;
76
- msg += ` Training loss: ${epochLogs.training.loss}\n`;
77
- if (epochLogs.training.accuracy !== undefined) {
78
- msg += ` Training accuracy: ${epochLogs.training.accuracy}\n`;
111
+ for await (const [round, epochs] of enumerate(trainer.fitModel(trainData.dataset, validationData.dataset))) {
112
+ yield async function* () {
113
+ let epochsLogs = List();
114
+ for await (const [epoch, batches] of enumerate(epochs)) {
115
+ const [gen, returnedEpochLogs] = split(batches);
116
+ yield gen;
117
+ const epochLogs = await returnedEpochLogs;
118
+ epochsLogs = epochsLogs.push(epochLogs);
119
+ this.logger.success([
120
+ `Round: ${round}`,
121
+ ` Epoch: ${epoch}`,
122
+ ` Training loss: ${epochLogs.training.loss}`,
123
+ ` Training accuracy: ${epochLogs.training.accuracy}`,
124
+ epochLogs.validation !== undefined
125
+ ? ` Validation loss: ${epochLogs.validation.loss}`
126
+ : "",
127
+ epochLogs.validation !== undefined
128
+ ? ` Validation accuracy: ${epochLogs.validation.accuracy}`
129
+ : "",
130
+ ].join("\n"));
79
131
  }
80
- if (epochLogs.validation !== undefined) {
81
- msg += ` Validation loss: ${epochLogs.validation.loss}\n`;
82
- msg += ` Validation accuracy: ${epochLogs.validation.accuracy}\n`;
83
- }
84
- }
85
- this.logger.success(msg);
86
- yield {
87
- ...roundLogs,
88
- participants: this.client.nodes.size + 1 // add ourself
89
- };
132
+ return {
133
+ epochs: epochsLogs,
134
+ participants: this.client.nodes.size + 1, // add ourself
135
+ };
136
+ }.bind(this)();
90
137
  }
91
138
  this.logger.success("Training finished.");
92
139
  }
@@ -1,9 +1,8 @@
1
1
  import type tf from "@tensorflow/tfjs";
2
2
  import { List } from "immutable";
3
3
  import type { Model, Task } from "../../index.js";
4
- import { EpochLogs } from "../../models/model.js";
4
+ import { BatchLogs, EpochLogs } from "../../models/index.js";
5
5
  export interface RoundLogs {
6
- round: number;
7
6
  epochs: List<EpochLogs>;
8
7
  }
9
8
  /** Abstract class whose role is to train a model with a given dataset. This can be either done
@@ -29,5 +28,5 @@ export declare abstract class Trainer {
29
28
  * Start training the model with the given dataset
30
29
  * @param dataset
31
30
  */
32
- fitModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<RoundLogs>;
31
+ fitModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
33
32
  }
@@ -1,4 +1,5 @@
1
1
  import { List } from "immutable";
2
+ import * as async_iterator from "../../utils/async_iterator.js";
2
3
  /** Abstract class whose role is to train a model with a given dataset. This can be either done
3
4
  * locally (alone) or in a distributed way with collaborators.
4
5
  *
@@ -16,6 +17,8 @@ export class Trainer {
16
17
  this.model = model;
17
18
  this.#roundDuration = task.trainingInformation.roundDuration;
18
19
  this.#epochs = task.trainingInformation.epochs;
20
+ if (!Number.isInteger(this.#epochs / this.#roundDuration))
21
+ throw new Error(`round duration doesn't divide epochs`);
19
22
  }
20
23
  /**
21
24
  * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
@@ -28,25 +31,31 @@ export class Trainer {
28
31
  * @param dataset
29
32
  */
30
33
  async *fitModel(dataset, valDataset) {
31
- if (this.training !== undefined) {
34
+ if (this.training !== undefined)
32
35
  throw new Error("training already running, cancel it before launching a new one");
36
+ try {
37
+ this.training = this.#runRounds(dataset, valDataset);
38
+ yield* this.training;
33
39
  }
34
- await this.onRoundBegin(0);
35
- this.training = this.model.train(dataset, valDataset, this.#epochs);
36
- for await (const logs of this.training) {
37
- // for now, round (sharing on network) == epoch (full pass over local data)
38
- yield {
39
- round: logs.epoch,
40
- epochs: List.of(logs),
41
- };
42
- if (logs.epoch % this.#roundDuration === 0) {
43
- const round = Math.trunc(logs.epoch / this.#roundDuration);
44
- await this.onRoundEnd(round);
45
- await this.onRoundBegin(round);
46
- }
40
+ finally {
41
+ this.training = undefined;
47
42
  }
48
- const round = Math.trunc(this.#epochs / this.#roundDuration);
49
- await this.onRoundEnd(round);
50
- this.training = undefined;
43
+ }
44
+ async *#runRounds(dataset, valDataset) {
45
+ const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
46
+ for (let round = 0; round < totalRound; round++) {
47
+ await this.onRoundBegin(round);
48
+ yield this.#runRound(dataset, valDataset);
49
+ await this.onRoundEnd(round);
50
+ }
51
+ }
52
+ async *#runRound(dataset, valDataset) {
53
+ let epochsLogs = List();
54
+ for (let epoch = 0; epoch < this.#roundDuration; epoch++) {
55
+ const [gen, epochLogs] = async_iterator.split(this.model.train(dataset, valDataset));
56
+ yield gen;
57
+ epochsLogs = epochsLogs.push(await epochLogs);
58
+ }
59
+ return { epochs: epochsLogs };
51
60
  }
52
61
  }
@@ -0,0 +1,11 @@
1
+ import { List } from "immutable";
2
+ /**
3
+ * Split yields from return value
4
+ *
5
+ * You need to consume the iterator to resolve the returned value
6
+ **/
7
+ export declare function split<T, U>(iter: AsyncIterator<T, U>): [AsyncGenerator<T, U>, Promise<U>];
8
+ /** Zip iterator with a infinite counter */
9
+ export declare function enumerate<T, U>(iter: AsyncIterator<T, U> | Iterator<T, U>): AsyncGenerator<[number, T], U>;
10
+ /** Run the whole iterator to get yielded & returned */
11
+ export declare function gather<T, U>(iter: AsyncIterator<T, U>): Promise<[List<T>, U]>;
@@ -0,0 +1,63 @@
1
+ import { List } from "immutable";
2
+ // `Promise.withResolvers` not widely deployed
3
+ function PromiseWithResolvers() {
4
+ let resolve, reject;
5
+ resolve = reject = () => {
6
+ // should not happen as Promise are run on creation
7
+ throw new Error("race condition triggered");
8
+ };
9
+ const promise = new Promise((res, rej) => {
10
+ resolve = res;
11
+ reject = rej;
12
+ });
13
+ return [promise, resolve, reject];
14
+ }
15
+ /**
16
+ * Split yields from return value
17
+ *
18
+ * You need to consume the iterator to resolve the returned value
19
+ **/
20
+ export function split(iter) {
21
+ const [returnPromise, returnResolve, returnReject] = PromiseWithResolvers();
22
+ return [
23
+ (async function* () {
24
+ try {
25
+ while (true) {
26
+ const v = await iter.next();
27
+ if (!v.done) {
28
+ yield v.value;
29
+ continue;
30
+ }
31
+ returnResolve(v.value);
32
+ return v.value;
33
+ }
34
+ }
35
+ catch (e) {
36
+ returnReject(e);
37
+ throw e;
38
+ }
39
+ })(),
40
+ returnPromise,
41
+ ];
42
+ }
43
+ /** Zip iterator with a infinite counter */
44
+ export function enumerate(iter) {
45
+ return (async function* () {
46
+ for (let i = 0;; i++) {
47
+ const v = await iter.next();
48
+ if (v.done)
49
+ return v.value;
50
+ yield [i, v.value];
51
+ }
52
+ })();
53
+ }
54
+ /** Run the whole iterator to get yielded & returned */
55
+ export async function gather(iter) {
56
+ let elems = List();
57
+ for (;;) {
58
+ const v = await iter.next();
59
+ if (v.done)
60
+ return [elems, v.value];
61
+ elems = elems.push(v.value);
62
+ }
63
+ }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "2.1.2-p20240627125649.0",
3
+ "version": "2.2.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",