@epfml/discojs 2.1.2-p20240624145915.0 → 2.1.2-p20240702170238.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,3 @@
- /// <reference types="node" resolution-mode="require"/>
  import type { NodeID } from '../types.js';
  export type SignalData = {
  type: 'answer' | 'offer' | 'pranswer' | 'rollback';
@@ -8,13 +8,15 @@ export const cifar10 = {
  displayInformation: {
  taskTitle: 'CIFAR10',
  summary: {
- preview: 'In this challenge, we ask you to classify images into categories based on the objects shown on the image.',
- overview: 'The CIFAR-10 dataset is a collection of images that are commonly used to train machine learning and computer vision algorithms. It is one of the most widely used datasets for machine learning research.'
+ preview: 'CIFAR-10 is a classic image classification task, and one of the most widely used datasets for machine learning research.',
+ overview: "The dataset contains 60,000 32x32 color images in 10 different classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. The official CIFAR-10 website can be found <a class='underline text-blue-400' href='https://www.cs.toronto.edu/~kriz/cifar.html' target='_blank'>here</a>. You can find a link to a sample dataset at the next step (Connect Your Data)."
  },
- dataFormatInformation: 'Images should be of .png format and of size 32x32. <br> The label file should be .csv, where each row contains a file_name, class. <br> <br> e.g. if you have images: 0.png (of a frog) and 1.png (of a car) <br> labels.csv contains: (Note that no header is needed)<br> 0.png, frog <br> 1.png, car',
+ model: 'The model is a pretrained <a class="underline text-blue-400" target="_blank" href="https://github.com/tensorflow/tfjs-models/tree/master/mobilenet">MobileNetV1 model</a> trained in Tensorflow.js. The last output layer is replaced with a fully connected layer with softmax activation and one output neuron per CIFAR10 category. The data preprocessing reshapes images into 224x224 pixels and normalizes values between 0 and 1. The neural network is optimized via Stochastic Gradient Descent and a categorical Cross Entropy loss.',
+ dataFormatInformation: 'Images should be of .png format and of size 32x32. <br> The CSV file should start with the exact header "filename,label", and each row should contain an image filename (without extension) and its label.<br><br> For example if you have images: 0.png (of a frog) and 1.png (of a car) <br> The CSV file should be: <br>filename, label <br><br> 0, frog <br> 1, car',
  dataExampleText: 'Below you can find 10 random examples from each of the 10 classes in the dataset.',
  dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png',
- sampleDatasetLink: 'https://www.kaggle.com/competitions/cifar-10/data'
+ sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data.tar.gz',
+ sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, use the CSV option below and select the file named "cifar10-labels.csv". You can now connect the images located in the "CIFAR10" folder. Note that there are only 24 images in this sample dataset which is far too few to successfully train a machine learning model.'
  },
  trainingInformation: {
  modelID: 'cifar10-model',
@@ -23,7 +25,7 @@ export const cifar10 = {
  validationSplit: 0.2,
  batchSize: 10,
  dataType: 'image',
- preprocessingFunctions: [data.ImagePreprocessing.Resize],
+ preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
  IMAGE_H: 224,
  IMAGE_W: 224,
  LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
@@ -1,6 +1,5 @@
  export { cifar10 } from './cifar10.js';
  export { lusCovid } from './lus_covid.js';
- export { skinCondition } from './skin_condition.js';
  export { mnist } from './mnist.js';
  export { simpleFace } from './simple_face.js';
  export { titanic } from './titanic.js';
@@ -1,6 +1,5 @@
  export { cifar10 } from './cifar10.js';
  export { lusCovid } from './lus_covid.js';
- export { skinCondition } from './skin_condition.js';
  export { mnist } from './mnist.js';
  export { simpleFace } from './simple_face.js';
  export { titanic } from './titanic.js';
@@ -5,16 +5,17 @@ export const lusCovid = {
  return {
  id: 'lus_covid',
  displayInformation: {
- taskTitle: 'COVID Lung Ultrasound',
+ taskTitle: 'COVID-19 Diagnosis from Lung Ultrasounds',
  summary: {
- preview: 'Do you have a data of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task.',
- overview: "Don't have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
+ preview: "Medical images are a typical example of data that exists in huge quantity yet that can't be shared due to confidentiality reasons. Medical applications would immensely benefit from training on data currently locked. More data diversity leads to better generalization and bias mitigation.",
+ overview: "Disco allows data owners to collaboratively train machine learning models using their respective data without any privacy breach. This example problem is about diagnosing whether patients are positive or negative to COVID-19 from lung ultrasounds images. <br>Don't have a dataset of your own? You can find a link to a sample dataset at the next step."
  },
- model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
- dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
- dataExampleText: 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png',
+ model: "The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 100x100 pixels and normalizes values between 0 and 1",
+ dataFormatInformation: 'This model takes as input an image dataset of lung ultrasounds. The images are resized automatically.',
+ dataExampleText: 'Below you can find an example of an expected lung image.',
  dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png',
- sampleDatasetLink: 'https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly'
+ sampleDatasetLink: 'https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly',
+ sampleDatasetInstructions: 'Opening the link will take you to a Switch Drive folder. You can click on the Download button in the top right corner. Unzip the file and you will get two subfolders: "COVID-" and "COVID+". You can connect the data by using the Group option and selecting each image group in its respective field.'
  },
  trainingInformation: {
  modelID: 'lus-covid-model',
@@ -1,30 +1,32 @@
  import * as tf from '@tensorflow/tfjs';
- import { models } from '../index.js';
+ import { data, models } from '../index.js';
  export const mnist = {
  getTask() {
  return {
  id: 'mnist',
  displayInformation: {
- taskTitle: 'MNIST',
+ taskTitle: 'Handwritten Digit Recognition',
  summary: {
- preview: "Test our platform by using a publicly available <b>image</b> dataset. <br><br> Download the classic MNIST imagebank of hand-written numbers <a class='underline text-primary-dark dark:text-primary-light' href='https://www.kaggle.com/scolianni/mnistasjpg'>here</a>. <br> This model learns to identify hand written numbers.",
- overview: 'The MNIST handwritten digit classification problem is a standard dataset used in computer vision and deep learning. Although the dataset is effectively solved, we use it to test our Decentralised Learning algorithms and platform.'
+ preview: "The MNIST handwritten digit classification problem is a classic dataset used in computer vision and deep learning. The objective is to classify handwritten digits from 28x28 pixel images.",
+ overview: "Download the classic MNIST dataset of hand-written numbers <a class='underline text-blue-400' target='_blank' href='https://www.kaggle.com/scolianni/mnistasjpg'>here</a>. You can also find a sample dataset at the next step."
  },
- model: 'The current model is a very simple CNN and its main goal is to test the app and the Decentralizsed Learning functionality.',
- dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can upload each digit image of your dataset in the box corresponding to its label. The model taskes images of size 28x28 as input.',
+ model: "The model is a simple Convolutional Neural Network composed of three convolutional layers with ReLU activations and max pooling layers, followed by two fully connected layers. The data preprocessing simply normalizes values between 0 and 1. The neural network is optimized via RMSProp and a categorical cross-entropy loss.",
+ dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can connect your own images of each digit in the box corresponding to its label. The model takes images of size 28x28 as input.',
  dataExampleText: 'Below you can find an example of an expected image representing the digit 9.',
- dataExampleImage: 'http://storage.googleapis.com/deai-313515.appspot.com/example_training_data/9-mnist-example.png'
+ dataExampleImage: 'http://storage.googleapis.com/deai-313515.appspot.com/example_training_data/9-mnist-example.png',
+ sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/MNIST_samples.tar.gz',
+ sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. You can connect the data with the CSV option below using the CSV file named "mnist_labels.csv". After selecting in the CSV file, you will be able to connect the data under in the "images" folder.'
  },
  trainingInformation: {
  modelID: 'mnist-model',
- epochs: 10,
+ epochs: 20,
  roundDuration: 10,
  validationSplit: 0.2,
  batchSize: 30,
  dataType: 'image',
  IMAGE_H: 28,
  IMAGE_W: 28,
- preprocessingFunctions: [],
+ preprocessingFunctions: [data.ImagePreprocessing.Normalize],
  LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
  scheme: 'decentralized',
  noiseScale: undefined,
@@ -9,11 +9,13 @@ export const simpleFace = {
  taskTitle: 'Simple Face',
  summary: {
  preview: 'Can you detect if the person in a picture is a child or an adult?',
- overview: 'Simple face is a small subset of face_task from Kaggle'
+ overview: 'Simple face is a small subset of the public face_task dataset from Kaggle'
  },
  dataFormatInformation: '',
  dataExampleText: 'Below you can find an example',
- dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png'
+ dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png',
+ sampleDatasetLink: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data.tar.gz",
+ sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. Inside the "example_training_data" directory you should find the "simple_face" folder which contains the "adult" and "child" folders. To connect the data, select the Group option below and connect adults and children image groups.'
  },
  trainingInformation: {
  modelID: 'simple_face-model',
@@ -5,14 +5,14 @@ export const titanic = {
  return {
  id: 'titanic',
  displayInformation: {
- taskTitle: 'Titanic',
+ taskTitle: 'Titanic Prediction',
  summary: {
- preview: "Test our platform by using a publicly available <b>tabular</b> dataset. <br><br> Download the passenger list from the Titanic shipwreck here: <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/epfml/disco/raw/develop/example_training_data/titanic_train.csv'>titanic_train.csv</a> (more info <a class='underline text-primary-dark dark:text-primary-light' href='https://www.kaggle.com/c/titanic'>here</a>). <br> This model predicts the type of person most likely to survive/die in the historic ship accident, based on their characteristics (sex, age, class etc.).",
- overview: 'We all know the unfortunate story of the Titanic: this flamboyant new transatlantic boat that sunk in 1912 in the North Atlantic Ocean. Today, we revist this tragedy by trying to predict the survival odds of the passenger given some basic features.'
+ preview: "The Titanic classification task is one of the main entrypoints into machine learning. Using passenger data (name, age, gender, socio-economic class, etc), the goal is to identify who was more likely to survive the infamous shipwreck.",
+ overview: "The original competition can be found on <a target='_blank' class='underline text-blue-400' href='https://www.kaggle.com/c/titanic'>Kaggle</a> and a link to the training set can be found here <a target='_blank' class='underline text-blue-400' href='https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv'>here</a>."
  },
- model: 'The current model does not normalize the given data and applies only a very simple pre-processing of the data.',
- dataFormatInformation: 'This model takes as input a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.<br><br>pclass: A proxy for socio-economic status (SES)<br>1st = Upper<br>2nd = Middle<br>3rd = Lower<br><br>age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5<br><br>sibsp: The dataset defines family relations in this way:<br>Sibling = brother, sister, stepbrother, stepsister<br>Spouse = husband, wife (mistresses and fiancés were ignored)<br><br>parch: The dataset defines family relations in this way:<br>Parent = mother, father<br>Child = daughter, son, stepdaughter, stepson<br>Some children travelled only with a nanny, therefore parch=0 for them.<br><br>The first line of the CSV contains the header:<br> PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked<br><br>Each susequent row contains the corresponding data.',
- dataExampleText: 'Below one can find an example of a datapoint taken as input by our model. In this datapoint, the person is young man named Owen Harris that unfortunnalty perished with the Titanic. He boarded the boat in South Hamptons and was a 3rd class passenger. On the testing & validation page, the data should not contain the label column (Survived).',
+ model: 'The model is a simple 5-layer feedforward network with ReLU activations. The model is optimized with Adam and binary cross-entropy loss. The preprocessing only fills missing value with a placeholder value (0).',
+ dataFormatInformation: 'The expected format for the tabular dataset is exactly the same as the sample data provided above or in the Kaggle competition. It is a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.<br>The first line of the CSV contains the header: "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked"<br>Each subsequent row contains passenger data.',
+ dataExampleText: "Here's an example of one data point:",
  dataExample: [
  { columnName: 'PassengerId', columnData: '1' },
  { columnName: 'Survived', columnData: '0' },
@@ -40,11 +40,13 @@ export const titanic = {
  'Cabin',
  'Embarked',
  'Pclass'
- ]
+ ],
+ sampleDatasetLink: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv",
+ sampleDatasetInstructions: 'Opening the link should start downloading a CSV file which you can drag and drop in the field below.'
  },
  trainingInformation: {
  modelID: 'titanic-model',
- epochs: 20,
+ epochs: 40,
  roundDuration: 10,
  validationSplit: 0.2,
  batchSize: 30,
@@ -79,7 +81,7 @@ export const titanic = {
  model.add(tf.layers.dense({ units: 32, activation: 'relu' }));
  model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
  model.compile({
- optimizer: tf.train.sgd(0.001),
+ optimizer: 'adam',
  loss: 'binaryCrossentropy',
  metrics: ['accuracy']
  });
@@ -2,28 +2,30 @@ import { data, models } from '../index.js';
  export const wikitext = {
  getTask() {
  return {
- id: 'wikitext-103',
+ id: 'llm_task',
  displayInformation: {
- taskTitle: 'Language modelling on wikitext',
+ taskTitle: 'GPT Language Modeling',
  summary: {
- preview: 'In this challenge, we ask you to do next word prediction on a dataset of Wikipedia articles.',
- overview: 'Wikitext-103-raw is a dataset comprising unprocessed text excerpts from Wikipedia articles, designed for tasks related to natural language processing and language modeling.'
+ preview: 'Train a language model (L)LM in your browser, collaboratively and from scratch.',
+ overview: "You can train a GPT-2 model in your browser and in a collaborative manner on any textual dataset. As an example, you can try the Wikitext-103 dataset, composed of Wikipedia articles, widely used in natural language modeling, which you can download <a class='underline text-blue-400' target='_blank' href='https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz'>here</a>. More information on how to connect the dataset at the next step."
  },
- dataFormatInformation: 'The dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.',
- dataExampleText: 'An example excerpt from the dataset could be: "The history of artificial intelligence dates back to ancient times, with philosophical discussions on the nature of thought and reasoning."',
- sampleDatasetLink: 'https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz'
+ model: 'The model follows the exact GPT-2 architecture and is implemented in TensorFlow.js. The tokenizer used for preprocessing is the GPT-2 Byte-Pair encoding tokenizer. The model is trained via an Adam optimizer with unit gradient clipping and softmax cross-entropy loss. To accommodate all devices, the context length is currently kept at 128 and the batch size at 1.',
+ dataFormatInformation: 'You can use any natural language (text) dataset you like. For example the Wikitext-103 dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.',
+ dataExampleText: 'An example excerpt from the dataset is: <i>"For the first twenty years of its existence , the only staged performances of Parsifal took place in the Bayreuth Festspielhaus , the venue for which Wagner conceived the work ( except eight private performances for Ludwig II at Munich in 1884 and 1885 ) ."</i>',
+ sampleDatasetLink: 'https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz',
+ sampleDatasetInstructions: 'Opening the link should start downloading a zip file. Unzip it and drag and drop the training set named "wiki.train.tokens" in the field below (or use the "Select File" button). Even though the file extension is ".tokens" it is indeed a text file. You can use "wiki.test.tokens" at the evaluation step after training a language model.'
  },
  trainingInformation: {
  dataType: 'text',
- modelID: 'wikitext-103-raw-model',
+ modelID: 'llm-raw-model',
  preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
  scheme: 'federated',
- epochs: 5,
+ epochs: 6,
  // Unused by wikitext because data already comes split
  // But if set to 0 then the webapp doesn't display the validation metrics
  validationSplit: 0.1,
  roundDuration: 2,
- batchSize: 1, // If set too high (e.g. 16) then firefox raises a WebGL error
+ batchSize: 1, // If set too high (e.g. 16) firefox raises a WebGL error
  tokenizer: 'Xenova/gpt2',
  maxSequenceLength: 128,
  tensorBackend: 'gpt'
package/dist/index.d.ts CHANGED
@@ -10,8 +10,9 @@ export { Logger, ConsoleLogger } from './logging/index.js';
  export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
  export { Disco, RoundLogs } from './training/index.js';
  export { Validator } from './validation/index.js';
- export { Model, EpochLogs } from './models/index.js';
+ export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js';
  export * as models from './models/index.js';
  export * from './task/index.js';
  export * as defaultTasks from './default_tasks/index.js';
  export * from './types.js';
+ export * as async_iterator from "./utils/async_iterator.js";
package/dist/index.js CHANGED
@@ -10,8 +10,9 @@ export { ConsoleLogger } from './logging/index.js';
  export { Memory, Empty as EmptyMemory } from './memory/index.js';
  export { Disco } from './training/index.js';
  export { Validator } from './validation/index.js';
- export { Model } from './models/index.js';
+ export { Model, EpochLogs } from './models/index.js';
  export * as models from './models/index.js';
  export * from './task/index.js';
  export * as defaultTasks from './default_tasks/index.js';
  export * from './types.js';
+ export * as async_iterator from "./utils/async_iterator.js";
@@ -3,5 +3,5 @@ interface DataPoint extends tf.TensorContainerObject {
  xs: tf.Tensor2D;
  ys: tf.Tensor3D;
  }
- export default function evaluate(model: tf.LayersModel, dataset: tf.data.Dataset<DataPoint>, maxEvalBatches: number): Promise<Record<'acc' | 'val_acc' | 'val_loss' | 'val_perplexity', number>>;
+ export default function evaluate(model: tf.LayersModel, dataset: tf.data.Dataset<DataPoint>, maxEvalBatches: number): Promise<Record<'val_acc' | 'val_loss' | 'val_perplexity', number>>;
  export {};
@@ -38,7 +38,6 @@ export default async function evaluate(model, dataset, maxEvalBatches) {
  return {
  val_loss: loss,
  val_perplexity: Math.exp(loss),
- acc: acc[0] / acc[1],
  val_acc: acc[0] / acc[1]
  };
  }
@@ -5,14 +5,15 @@ import * as tf from '@tensorflow/tfjs';
5
5
  import { PreTrainedTokenizer } from '@xenova/transformers';
6
6
  import { WeightsContainer } from '../../index.js';
7
7
  import type { Dataset } from '../../dataset/index.js';
8
- import { Model } from '../model.js';
9
- import type { EpochLogs, Prediction, Sample } from '../model.js';
10
- import type { GPTConfig } from './config.js';
8
+ import { BatchLogs, Model, EpochLogs } from "../index.js";
9
+ import type { Prediction, Sample } from '../model.js';
10
+ import { type GPTConfig } from './config.js';
11
11
  export type GPTSerialization = {
12
12
  weights: WeightsContainer;
13
13
  config?: GPTConfig;
14
14
  };
15
15
  export declare class GPT extends Model {
16
+ #private;
16
17
  private readonly model;
17
18
  constructor(partialConfig?: GPTConfig, layersModel?: tf.LayersModel);
18
19
  /**
@@ -24,7 +25,7 @@ export declare class GPT extends Model {
24
25
  * @param epochs the number of passes of the training dataset
25
26
  * @param tracker
26
27
  */
27
- train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs, void>;
28
+ train(trainingData: Dataset, validationData?: Dataset): AsyncGenerator<BatchLogs, EpochLogs>;
28
29
  predict(input: Sample): Promise<Prediction>;
29
30
  generate(input: string, tokenizer: PreTrainedTokenizer, newTokens?: number): Promise<string>;
30
31
  get config(): Required<GPTConfig>;
@@ -1,14 +1,20 @@
1
1
  /**
2
2
  * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
3
3
  **/
4
+ import * as tf from '@tensorflow/tfjs';
4
5
  import { WeightsContainer } from '../../index.js';
5
- import { Model } from '../model.js';
6
+ import { Model, EpochLogs } from "../index.js";
6
7
  import { GPTForCausalLM } from './model.js';
8
+ import { DEFAULT_CONFIG } from './config.js';
9
+ import evaluate from './evaluate.js';
10
+ import { List } from 'immutable';
7
11
  export class GPT extends Model {
8
12
  model;
13
+ #maxBatchCount;
9
14
  constructor(partialConfig, layersModel) {
10
15
  super();
11
16
  this.model = new GPTForCausalLM(partialConfig, layersModel);
17
+ this.#maxBatchCount = partialConfig?.maxIter ?? DEFAULT_CONFIG.maxIter;
12
18
  }
13
19
  /**
14
20
  * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
@@ -19,45 +25,65 @@ export class GPT extends Model {
19
25
  * @param epochs the number of passes of the training dataset
20
26
  * @param tracker
21
27
  */
22
- async *train(trainingData, validationData, epochs = 1) {
28
+ async *train(trainingData, validationData) {
23
29
  this.model.compile();
30
+ const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
31
+ let batchesLogs = List();
32
+ for (let batchNumber = 0; batchNumber < this.#maxBatchCount; batchNumber++) {
33
+ const iteration = await batches.next();
34
+ if (iteration.done)
35
+ break;
36
+ const batch = iteration.value;
37
+ const batchLogs = await this.#runBatch(batch);
38
+ tf.dispose(batch);
39
+ yield batchLogs;
40
+ batchesLogs = batchesLogs.push(batchLogs);
41
+ }
42
+ const validation = validationData && (await this.#evaluate(validationData));
43
+ return new EpochLogs(batchesLogs, validation);
44
+ }
45
+ async #runBatch(batch) {
24
46
  let logs;
25
- const trainingArgs = {
26
- epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
27
- validationData,
28
- callbacks: { onEpochEnd: (_, cur) => { logs = cur; } },
47
+ await this.model.fitDataset(tf.data.array([batch]), {
48
+ epochs: 1,
49
+ verbose: 0, // don't pollute
50
+ callbacks: {
51
+ onEpochEnd: (_, cur) => {
52
+ logs = cur;
53
+ },
54
+ },
55
+ });
56
+ if (logs === undefined)
57
+ throw new Error("batch didn't gave any logs");
58
+ const { loss, acc: accuracy } = logs;
59
+ if (loss === undefined || isNaN(loss))
60
+ throw new Error("training loss is undefined or NaN");
61
+ return {
62
+ accuracy,
63
+ loss,
64
+ memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
29
65
  };
30
- for (let epoch = 0; epoch < epochs; epoch++) {
31
- await this.model.fitDataset(trainingData, trainingArgs);
32
- if (logs === undefined) {
33
- throw new Error("Epoch didn't gave any logs");
34
- }
35
- const { loss, val_acc, val_loss, peakMemory } = logs;
36
- if (loss === undefined || isNaN(loss)) {
37
- throw new Error("Training loss is undefined or nan");
38
- }
39
- const structuredLogs = {
40
- epoch,
41
- peakMemory,
42
- training: {
43
- loss: logs.loss,
44
- accuracy: logs.acc
45
- }
46
- };
47
- if (validationData !== undefined) {
48
- if (val_loss === undefined || isNaN(val_loss) ||
49
- val_acc === undefined || isNaN(val_acc)) {
50
- throw new Error("Validation accuracy or loss is undefined or nan");
51
- }
52
- structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss };
66
+ }
67
+ async #evaluate(dataset) {
68
+ const evaluation = await evaluate(this.model, dataset.map((t) => {
69
+ switch (t) {
70
+ case null:
71
+ case undefined:
72
+ throw new Error("nullish value in dataset");
73
+ default:
74
+ // TODO unsafe cast
75
+ return t;
53
76
  }
54
- yield structuredLogs;
55
- }
77
+ }), this.config.maxEvalBatches);
78
+ return {
79
+ accuracy: evaluation.val_acc,
80
+ loss: evaluation.val_loss,
81
+ };
56
82
  }
57
83
  predict(input) {
58
84
  const ret = this.model.predict(input);
59
85
  if (Array.isArray(ret)) {
60
- throw new Error('prediction yield many Tensors but should have only returned one');
86
+ throw new Error("prediction yield many Tensors but should have only returned one");
61
87
  }
62
88
  return Promise.resolve(ret);
63
89
  }
@@ -37,29 +37,44 @@ class GPTModel extends tf.LayersModel {
37
37
  const evalDataset = trainingArgs.validationData;
38
38
  await callbacks.onTrainBegin?.();
39
39
  for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) {
40
+ let accuracyFraction = [0, 0];
40
41
  let averageLoss = 0;
41
- let peakMemory = 0;
42
42
  let iteration = 1;
43
43
  const iterator = await dataset.iterator();
44
- let preprocessingTime = performance.now();
45
44
  let next = await iterator.next();
46
- preprocessingTime = performance.now() - preprocessingTime;
47
45
  while (next.done !== true && iteration <= this.config.maxIter) {
48
46
  let weightUpdateTime = performance.now();
49
47
  await callbacks.onEpochBegin?.(epoch);
50
48
  const { xs, ys } = next.value;
51
- const lossFn = () => {
49
+ let preprocessingTime = performance.now();
50
+ await Promise.all([xs.data(), ys.data()]);
51
+ preprocessingTime = performance.now() - preprocessingTime;
52
+ // TODO include as a tensor inside the model
53
+ const accTensor = tf.tidy(() => {
52
54
  const logits = this.apply(xs);
53
- if (Array.isArray(logits)) {
55
+ if (Array.isArray(logits))
54
56
  throw new Error('model outputs too many tensor');
55
- }
56
- if (logits instanceof tf.SymbolicTensor) {
57
+ if (logits instanceof tf.SymbolicTensor)
57
58
  throw new Error('model outputs symbolic tensor');
58
- }
59
- return tf.losses.softmaxCrossEntropy(ys, logits);
60
- };
59
+ return tf.metrics.categoricalAccuracy(ys, logits);
60
+ });
61
+ const accSize = accTensor.shape.reduce((l, r) => l * r, 1);
62
+ const accSumTensor = accTensor.sum();
63
+ const accSum = await accSumTensor.array();
64
+ tf.dispose(accSumTensor);
65
+ if (typeof accSum !== 'number')
66
+ throw new Error('got multiple accuracy sum');
67
+ accuracyFraction = [accuracyFraction[0] + accSum, accuracyFraction[1] + accSize];
68
+ tf.dispose([accTensor]);
61
69
  const lossTensor = tf.tidy(() => {
62
- const { grads, value: lossTensor } = this.optimizer.computeGradients(lossFn);
70
+ const { grads, value: lossTensor } = this.optimizer.computeGradients(() => {
71
+ const logits = this.apply(xs);
72
+ if (Array.isArray(logits))
73
+ throw new Error('model outputs too many tensor');
74
+ if (logits instanceof tf.SymbolicTensor)
75
+ throw new Error('model outputs symbolic tensor');
76
+ return tf.losses.softmaxCrossEntropy(ys, logits);
77
+ });
63
78
  const gradsClipped = clipByGlobalNormObj(grads, 1);
64
79
  this.optimizer.applyGradients(gradsClipped);
65
80
  return lossTensor;
@@ -75,9 +90,6 @@ class GPTModel extends tf.LayersModel {
75
90
  console.log(iterationLogs);
76
91
  }
77
92
  const memory = tf.memory().numBytes / 1024 / 1024 / 1024;
78
- if (memory > peakMemory) {
79
- peakMemory = memory;
80
- }
81
93
  console.log(`Epoch: ${epoch}`, `\tStep: ${iteration} / ${this.config.maxIter}`, `\tLoss: ${loss.toFixed(3)}`, `\tMemory: ${memory.toFixed(2)} GB`, `\tNumber of tensors allocated: ${tf.memory().numTensors}`, `\tPreprocessing time: ${preprocessingTime.toFixed(0)} ms`, `\tWeight update time: ${weightUpdateTime.toFixed(0)} ms`);
82
94
  iteration++;
83
95
  next = await iterator.next();
@@ -89,7 +101,7 @@ class GPTModel extends tf.LayersModel {
89
101
  }
90
102
  let logs = {
91
103
  'loss': averageLoss / iteration,
92
- 'peakMemory': peakMemory
104
+ 'acc': accuracyFraction[0] / accuracyFraction[1],
93
105
  };
94
106
  if (evalDataset !== undefined) {
95
107
  logs = { ...logs, ...await evaluate(this, evalDataset, this.config.maxEvalBatches) };
@@ -1,4 +1,5 @@
- export { EpochLogs, Model } from './model.js';
+ export { Model } from './model.js';
+ export { BatchLogs, EpochLogs, ValidationMetrics } from "./logs.js";
  export { GPT } from './gpt/index.js';
  export { GPTConfig } from './gpt/config.js';
  export { TFJS } from './tfjs.js';
@@ -1,4 +1,5 @@
  export { Model } from './model.js';
+ export { EpochLogs } from "./logs.js";
  export { GPT } from './gpt/index.js';
  export { TFJS } from './tfjs.js';
  export { getTaskTokenizer } from './tokenizer.js';
@@ -0,0 +1,17 @@
+ import { List } from "immutable";
+ export interface ValidationMetrics {
+ accuracy: number;
+ loss: number;
+ }
+ export interface BatchLogs {
+ accuracy: number;
+ loss: number;
+ memoryUsage: number;
+ }
+ export declare class EpochLogs {
+ readonly validation?: ValidationMetrics | undefined;
+ readonly batches: List<BatchLogs>;
+ constructor(batches: Iterable<BatchLogs>, validation?: ValidationMetrics | undefined);
+ get training(): Record<"accuracy" | "loss", number>;
+ get peakMemory(): number;
+ }
@@ -0,0 +1,22 @@
+ import { List } from "immutable";
+ export class EpochLogs {
+ validation;
+ batches;
+ constructor(batches, validation) {
+ this.validation = validation;
+ this.batches = List(batches);
+ }
+ get training() {
+ const sum = this.batches.reduce((acc, batch) => ({
+ accuracy: acc.accuracy + batch.accuracy,
+ loss: acc.loss + batch.loss,
+ }), { loss: 0, accuracy: 0 });
+ return {
+ accuracy: sum.accuracy / this.batches.size,
+ loss: sum.loss / this.batches.size,
+ };
+ }
+ get peakMemory() {
+ return this.batches.map((batch) => batch.memoryUsage).max() ?? 0;
+ }
+ }
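
The new EpochLogs class aggregates the per-batch metrics it is given: `training` averages accuracy and loss over the batches, and `peakMemory` takes the maximum reported memory usage. A minimal sketch of that behavior, with made-up batch values for illustration:

    import { EpochLogs, type BatchLogs } from "@epfml/discojs";

    // illustrative per-batch metrics (memoryUsage is reported in GB)
    const batches: BatchLogs[] = [
      { accuracy: 0.5, loss: 1.2, memoryUsage: 0.8 },
      { accuracy: 0.6, loss: 1.0, memoryUsage: 1.1 },
    ];

    const logs = new EpochLogs(batches); // validation metrics are optional
    console.log(logs.training);   // { accuracy: 0.55, loss: 1.1 }, averaged over the batches
    console.log(logs.peakMemory); // 1.1, the maximum memoryUsage across batches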
@@ -1,19 +1,7 @@
- /// <reference types="node" resolution-mode="require"/>
  import type tf from "@tensorflow/tfjs";
  import type { WeightsContainer } from "../index.js";
  import type { Dataset } from "../dataset/index.js";
- export interface EpochLogs {
- epoch: number;
- training: {
- loss: number;
- accuracy?: number;
- };
- validation?: {
- loss: number;
- accuracy: number;
- };
- peakMemory: number;
- }
+ import type { BatchLogs, EpochLogs } from "./logs.js";
  export type Prediction = tf.Tensor;
  export type Sample = tf.Tensor;
  /**
@@ -35,7 +23,7 @@ export declare abstract class Model implements Disposable {
  * @param tracker watch the various steps
  * @yields on every epoch, training can be stop by `return`ing it
  */
- abstract train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs, void>;
+ abstract train(trainingData: Dataset, validationData?: Dataset): AsyncGenerator<BatchLogs, EpochLogs>;
  /** Predict likely values */
  abstract predict(input: Sample): Promise<Prediction>;
  /**
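
With this signature, `train` yields BatchLogs and hands back the EpochLogs summary as the generator's return value, so callers have to drive the generator to completion (or use the new `async_iterator.split` helper). A rough consumption sketch; the dataset parameter type is derived from the signature itself rather than assuming where it is re-exported:

    import { Model, type BatchLogs, type EpochLogs } from "@epfml/discojs";

    // take the dataset type straight from Model#train's first parameter
    type TrainDataset = Parameters<Model["train"]>[0];

    async function runOneEpoch(
      model: Model,
      dataset: TrainDataset,
      validation?: TrainDataset,
    ): Promise<EpochLogs> {
      const epoch = model.train(dataset, validation);
      for (;;) {
        const step = await epoch.next();
        if (step.done) return step.value;    // EpochLogs is the generator's return value
        const batch: BatchLogs = step.value; // per-batch metrics are yielded
        console.log(`loss=${batch.loss.toFixed(3)} acc=${batch.accuracy.toFixed(3)}`);
      }
    }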
@@ -1,16 +1,18 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
2
  import { WeightsContainer } from '../index.js';
3
- import { Model } from './index.js';
4
- import type { EpochLogs, Prediction, Sample } from './model.js';
5
3
  import type { Dataset } from '../dataset/index.js';
4
+ import { BatchLogs, EpochLogs } from './index.js';
5
+ import { Model } from './index.js';
6
+ import type { Prediction, Sample } from './model.js';
6
7
  /** TensorFlow JavaScript model with standard training */
7
8
  export declare class TFJS extends Model {
9
+ #private;
8
10
  private readonly model;
9
11
  /** Wrap the given trainable model */
10
12
  constructor(model: tf.LayersModel);
11
13
  get weights(): WeightsContainer;
12
14
  set weights(ws: WeightsContainer);
13
- train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs>;
15
+ train(trainingData: Dataset, validationData?: Dataset): AsyncGenerator<BatchLogs, EpochLogs>;
14
16
  predict(input: Sample): Promise<Prediction>;
15
17
  static deserialize(raw: tf.io.ModelArtifacts): Promise<Model>;
16
18
  serialize(): Promise<tf.io.ModelArtifacts>;
@@ -1,5 +1,7 @@
1
+ import { List, Map } from 'immutable';
1
2
  import * as tf from '@tensorflow/tfjs';
2
3
  import { WeightsContainer } from '../index.js';
4
+ import { EpochLogs } from './index.js';
3
5
  import { Model } from './index.js';
4
6
  /** TensorFlow JavaScript model with standard training */
5
7
  export class TFJS extends Model {
@@ -18,50 +20,71 @@ export class TFJS extends Model {
18
20
  set weights(ws) {
19
21
  this.model.setWeights(ws.weights);
20
22
  }
21
- async *train(trainingData, validationData, epochs = 1) {
22
- for (let epoch = 0; epoch < epochs; epoch++) {
23
- let logs;
24
- let peakMemory = 0;
25
- await this.model.fitDataset(trainingData, {
26
- epochs: 1,
27
- validationData,
28
- callbacks: {
29
- onBatchEnd: (_) => {
30
- const currentMemory = tf.memory().numBytes / 1024 / 1024 / 1024; // GB
31
- if (currentMemory > peakMemory) {
32
- peakMemory = currentMemory;
33
- }
34
- },
35
- onEpochEnd: (_, cur) => { logs = cur; }
36
- },
37
- });
38
- if (logs === undefined) {
39
- throw new Error("Epoch didn't gave any logs");
40
- }
41
- const { loss, acc, val_acc, val_loss } = logs;
42
- if (loss === undefined || isNaN(loss) || acc === undefined || isNaN(acc)) {
43
- throw new Error("Training loss is undefined or nan");
44
- }
45
- const structuredLogs = {
46
- epoch,
47
- peakMemory,
48
- training: {
49
- loss: logs.loss,
50
- accuracy: logs.acc,
51
- }
23
+ async *train(trainingData, validationData) {
24
+ const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
25
+ let batchesLogs = List();
26
+ for (let batchNumber = 0; true; batchNumber++) {
27
+ const iteration = await batches.next();
28
+ if (iteration.done)
29
+ break;
30
+ const batch = iteration.value;
31
+ const batchLogs = {
32
+ batch: batchNumber,
33
+ ...(await this.#runBatch(batch)),
52
34
  };
53
- if (validationData !== undefined) {
54
- if (val_loss === undefined || isNaN(val_loss) ||
55
- val_acc === undefined || isNaN(val_acc)) {
56
- throw new Error("Invalid validation logs");
57
- }
58
- structuredLogs.validation = {
59
- accuracy: logs.val_acc,
60
- loss: logs.val_loss
61
- };
62
- }
63
- yield structuredLogs;
35
+ tf.dispose(batch);
36
+ yield batchLogs;
37
+ batchesLogs = batchesLogs.push(batchLogs);
64
38
  }
39
+ const validation = validationData && (await this.#evaluate(validationData));
40
+ return new EpochLogs(batchesLogs, validation);
41
+ }
42
+ async #runBatch(batch) {
43
+ let logs;
44
+ await this.model.fitDataset(tf.data.array([batch]), {
45
+ epochs: 1,
46
+ verbose: 0, // don't pollute
47
+ callbacks: {
48
+ onEpochEnd: (_, cur) => {
49
+ logs = cur;
50
+ },
51
+ },
52
+ });
53
+ if (logs === undefined)
54
+ throw new Error("batch didn't gave any logs");
55
+ const { loss, acc: accuracy } = logs;
56
+ if (loss === undefined || isNaN(loss))
57
+ throw new Error("training loss is undefined or NaN");
58
+ return {
59
+ accuracy,
60
+ loss,
61
+ memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
62
+ };
63
+ }
64
+ async #evaluate(dataset) {
65
+ const evaluation = await this.model.evaluateDataset(dataset.map((t) => {
66
+ switch (t) {
67
+ case null:
68
+ case undefined:
69
+ throw new Error("nullish value in dataset");
70
+ default:
71
+ return t;
72
+ }
73
+ }));
74
+ const metricToValue = Map(List(this.model.metricsNames).zip(Array.isArray(evaluation)
75
+ ? List(await Promise.all(evaluation.map((t) => t.data())))
76
+ : List.of(await evaluation.data()))).map((values) => {
77
+ if (values.length !== 1)
78
+ throw new Error("more than one metric value");
79
+ return values[0];
80
+ });
81
+ const [accuracy, loss] = [
82
+ metricToValue.get("acc"),
83
+ metricToValue.get("loss"),
84
+ ];
85
+ if (accuracy === undefined || loss === undefined)
86
+ throw new Error("some needed metrics are missing");
87
+ return { accuracy, loss };
65
88
  }
66
89
  predict(input) {
67
90
  const ret = this.model.predict(input);
@@ -10,5 +10,6 @@ export interface DisplayInformation {
  headers?: string[];
  dataExampleImage?: string;
  sampleDatasetLink?: string;
+ sampleDatasetInstructions?: string;
  }
  export declare function isDisplayInformation(raw: unknown): raw is DisplayInformation;
@@ -4,13 +4,14 @@ export function isDisplayInformation(raw) {
4
4
  if (typeof raw !== 'object' || raw === null) {
5
5
  return false;
6
6
  }
7
- const { dataExample, dataExampleImage, dataExampleText, dataFormatInformation, sampleDatasetLink, headers, model, summary, taskTitle, } = raw;
7
+ const { dataExample, dataExampleImage, dataExampleText, dataFormatInformation, sampleDatasetLink, sampleDatasetInstructions, headers, model, summary, taskTitle, } = raw;
8
8
  if (typeof taskTitle !== 'string' ||
9
9
  (dataExampleText !== undefined && typeof dataExampleText !== 'string') ||
10
10
  (sampleDatasetLink !== undefined && typeof sampleDatasetLink !== 'string') ||
11
11
  (dataFormatInformation !== undefined && typeof dataFormatInformation !== 'string') ||
12
12
  (model !== undefined && typeof model !== 'string') ||
13
- (dataExampleImage !== undefined && typeof dataExampleImage !== 'string')) {
13
+ (dataExampleImage !== undefined && typeof dataExampleImage !== 'string') ||
14
+ (sampleDatasetInstructions !== undefined && typeof sampleDatasetInstructions !== 'string')) {
14
15
  return false;
15
16
  }
16
17
  if (!isSummary(summary)) {
@@ -46,6 +47,7 @@ export function isDisplayInformation(raw) {
46
47
  dataExampleText,
47
48
  dataFormatInformation,
48
49
  sampleDatasetLink,
50
+ sampleDatasetInstructions,
49
51
  headers,
50
52
  model,
51
53
  summary,
@@ -1,4 +1,4 @@
- import type { data, Logger, Memory, Task, TrainingInformation } from '../index.js';
+ import { BatchLogs, data, EpochLogs, Logger, Memory, Task, TrainingInformation } from '../index.js';
  import { client as clients } from '../index.js';
  import type { Aggregator } from '../aggregator/index.js';
  import type { RoundLogs } from './trainer/trainer.js';
@@ -22,13 +22,25 @@ export declare class Disco {
  private readonly client;
  private readonly trainer;
  constructor(task: Task, options: DiscoOptions);
- /**
- * Starts a training instance for the Disco object's task on the provided data tuple.
- * @param dataTuple The data tuple
- */
- fit(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & {
+ /** Train on dataset, yielding logs of every round. */
+ trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & {
  participants: number;
  }>;
+ /** Train on dataset, yielding logs of every epoch. */
+ trainByEpoch(dataTuple: data.DataSplit): AsyncGenerator<EpochLogs>;
+ /** Train on dataset, yielding logs of every batch. */
+ trainByBatch(dataTuple: data.DataSplit): AsyncGenerator<BatchLogs>;
+ /** Run whole train on dataset. */
+ trainFully(dataTuple: data.DataSplit): Promise<void>;
+ /**
+ * Train on dataset, yield the nested steps.
+ *
+ * Don't forget to await the yielded generator otherwise nothing will progress.
+ * If you don't care about the whole process, use one of the other train methods.
+ **/
+ train(dataTuple: data.DataSplit): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs & {
+ participants: number;
+ }>>;
  /**
  * Stops the ongoing training instance without disconnecting the client.
  */
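
The former `fit` generator is replaced by this family of methods; most callers can simply pick the granularity they care about instead of driving the nested `train` generators by hand. A rough usage sketch; the `dataSplit` value is a placeholder, and DiscoOptions fields other than `scheme` (server URL, logger, etc.) are elided assumptions, not part of this diff:

    import { Disco, defaultTasks } from "@epfml/discojs";
    import type { data } from "@epfml/discojs";

    async function example(dataSplit: data.DataSplit): Promise<void> {
      const task = defaultTasks.titanic.getTask();
      // other DiscoOptions fields may be needed depending on the scheme; assumption
      const disco = new Disco(task, { scheme: "local" });

      // watch progress epoch by epoch...
      for await (const epochLogs of disco.trainByEpoch(dataSplit)) {
        console.log(epochLogs.training.loss, epochLogs.validation?.accuracy);
      }

      // ...or just run the whole training:
      // await disco.trainFully(dataSplit);
    }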
@@ -1,5 +1,8 @@
1
+ import { List } from 'immutable';
2
+ import { async_iterator } from '../index.js';
1
3
  import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js';
2
4
  import { MeanAggregator } from '../aggregator/mean.js';
5
+ import { enumerate, split } from '../utils/async_iterator.js';
3
6
  import { TrainerBuilder } from './trainer/trainer_builder.js';
4
7
  /**
5
8
  * Top-level class handling distributed training from a client's perspective. It is meant to be
@@ -58,35 +61,79 @@ export class Disco {
58
61
  const trainerBuilder = new TrainerBuilder(this.memory, this.task);
59
62
  this.trainer = trainerBuilder.build(this.client, options.scheme !== 'local');
60
63
  }
64
+ /** Train on dataset, yielding logs of every round. */
65
+ async *trainByRound(dataTuple) {
66
+ for await (const round of this.train(dataTuple)) {
67
+ const [roundGen, roundLogs] = async_iterator.split(round);
68
+ for await (const epoch of roundGen)
69
+ for await (const _ of epoch)
70
+ ;
71
+ yield await roundLogs;
72
+ }
73
+ }
74
+ /** Train on dataset, yielding logs of every epoch. */
75
+ async *trainByEpoch(dataTuple) {
76
+ for await (const round of this.train(dataTuple)) {
77
+ for await (const epoch of round) {
78
+ const [epochGen, epochLogs] = async_iterator.split(epoch);
79
+ for await (const _ of epochGen)
80
+ ;
81
+ yield await epochLogs;
82
+ }
83
+ }
84
+ }
85
+ /** Train on dataset, yielding logs of every batch. */
86
+ async *trainByBatch(dataTuple) {
87
+ for await (const round of this.train(dataTuple))
88
+ for await (const epoch of round)
89
+ yield* epoch;
90
+ }
91
+ /** Run whole train on dataset. */
92
+ async trainFully(dataTuple) {
93
+ for await (const round of this.train(dataTuple))
94
+ for await (const epoch of round)
95
+ for await (const _ of epoch)
96
+ ;
97
+ }
61
98
  /**
62
- * Starts a training instance for the Disco object's task on the provided data tuple.
63
- * @param dataTuple The data tuple
64
- */
99
+ * Train on dataset, yield the nested steps.
100
+ *
101
+ * Don't forget to await the yielded generator otherwise nothing will progress.
102
+ * If you don't care about the whole process, use one of the other train methods.
103
+ **/
65
104
  // TODO RoundLogs should contain number of participants but Trainer doesn't need client
66
- async *fit(dataTuple) {
105
+ async *train(dataTuple) {
67
106
  this.logger.success("Training started.");
68
107
  const trainData = dataTuple.train.preprocess().batch();
69
108
  const validationData = dataTuple.validation?.preprocess().batch() ?? trainData;
70
109
  await this.client.connect();
71
110
  const trainer = await this.trainer;
72
- for await (const roundLogs of trainer.fitModel(trainData.dataset, validationData.dataset)) {
73
- let msg = `Round: ${roundLogs.round}\n`;
74
- for (const epochLogs of roundLogs.epochs.values()) {
75
- msg += ` Epoch: ${epochLogs.epoch}\n`;
76
- msg += ` Training loss: ${epochLogs.training.loss}\n`;
77
- if (epochLogs.training.accuracy !== undefined) {
78
- msg += ` Training accuracy: ${epochLogs.training.accuracy}\n`;
111
+ for await (const [round, epochs] of enumerate(trainer.fitModel(trainData.dataset, validationData.dataset))) {
112
+ yield async function* () {
113
+ let epochsLogs = List();
114
+ for await (const [epoch, batches] of enumerate(epochs)) {
115
+ const [gen, returnedEpochLogs] = split(batches);
116
+ yield gen;
117
+ const epochLogs = await returnedEpochLogs;
118
+ epochsLogs = epochsLogs.push(epochLogs);
119
+ this.logger.success([
120
+ `Round: ${round}`,
121
+ ` Epoch: ${epoch}`,
122
+ ` Training loss: ${epochLogs.training.loss}`,
123
+ ` Training accuracy: ${epochLogs.training.accuracy}`,
124
+ epochLogs.validation !== undefined
125
+ ? ` Validation loss: ${epochLogs.validation.loss}`
126
+ : "",
127
+ epochLogs.validation !== undefined
128
+ ? ` Validation accuracy: ${epochLogs.validation.accuracy}`
129
+ : "",
130
+ ].join("\n"));
79
131
  }
80
- if (epochLogs.validation !== undefined) {
81
- msg += ` Validation loss: ${epochLogs.validation.loss}\n`;
82
- msg += ` Validation accuracy: ${epochLogs.validation.accuracy}\n`;
83
- }
84
- }
85
- this.logger.success(msg);
86
- yield {
87
- ...roundLogs,
88
- participants: this.client.nodes.size + 1 // add ourself
89
- };
132
+ return {
133
+ epochs: epochsLogs,
134
+ participants: this.client.nodes.size + 1, // add ourself
135
+ };
136
+ }.bind(this)();
90
137
  }
91
138
  this.logger.success("Training finished.");
92
139
  }
@@ -1,9 +1,8 @@
  import type tf from "@tensorflow/tfjs";
  import { List } from "immutable";
  import type { Model, Task } from "../../index.js";
- import { EpochLogs } from "../../models/model.js";
+ import { BatchLogs, EpochLogs } from "../../models/index.js";
  export interface RoundLogs {
- round: number;
  epochs: List<EpochLogs>;
  }
  /** Abstract class whose role is to train a model with a given dataset. This can be either done
@@ -29,5 +28,5 @@ export declare abstract class Trainer {
  * Start training the model with the given dataset
  * @param dataset
  */
- fitModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<RoundLogs>;
+ fitModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
  }
@@ -1,4 +1,5 @@
  import { List } from "immutable";
+ import * as async_iterator from "../../utils/async_iterator.js";
  /** Abstract class whose role is to train a model with a given dataset. This can be either done
  * locally (alone) or in a distributed way with collaborators.
  *
@@ -16,6 +17,8 @@ export class Trainer {
  this.model = model;
  this.#roundDuration = task.trainingInformation.roundDuration;
  this.#epochs = task.trainingInformation.epochs;
+ if (!Number.isInteger(this.#epochs / this.#roundDuration))
+ throw new Error(`round duration doesn't divide epochs`);
  }
  /**
  * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
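
The added constructor check turns a previously implicit assumption into an error: a task's `epochs` must be an integer multiple of its `roundDuration`, otherwise the Trainer now throws at construction time. Illustrative values, not taken from any task in this diff:

    // 20 epochs with roundDuration 10 gives exactly 2 rounds: accepted
    const ok = { epochs: 20, roundDuration: 10 };

    // 20 epochs with roundDuration 8 gives 2.5 rounds: the Trainer now throws
    // "round duration doesn't divide epochs" when constructed
    const rejected = { epochs: 20, roundDuration: 8 };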
@@ -28,25 +31,31 @@ export class Trainer {
28
31
  * @param dataset
29
32
  */
30
33
  async *fitModel(dataset, valDataset) {
31
- if (this.training !== undefined) {
34
+ if (this.training !== undefined)
32
35
  throw new Error("training already running, cancel it before launching a new one");
36
+ try {
37
+ this.training = this.#runRounds(dataset, valDataset);
38
+ yield* this.training;
33
39
  }
34
- await this.onRoundBegin(0);
35
- this.training = this.model.train(dataset, valDataset, this.#epochs);
36
- for await (const logs of this.training) {
37
- // for now, round (sharing on network) == epoch (full pass over local data)
38
- yield {
39
- round: logs.epoch,
40
- epochs: List.of(logs),
41
- };
42
- if (logs.epoch % this.#roundDuration === 0) {
43
- const round = Math.trunc(logs.epoch / this.#roundDuration);
44
- await this.onRoundEnd(round);
45
- await this.onRoundBegin(round);
46
- }
40
+ finally {
41
+ this.training = undefined;
47
42
  }
48
- const round = Math.trunc(this.#epochs / this.#roundDuration);
49
- await this.onRoundEnd(round);
50
- this.training = undefined;
43
+ }
44
+ async *#runRounds(dataset, valDataset) {
45
+ const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
46
+ for (let round = 0; round < totalRound; round++) {
47
+ await this.onRoundBegin(round);
48
+ yield this.#runRound(dataset, valDataset);
49
+ await this.onRoundEnd(round);
50
+ }
51
+ }
52
+ async *#runRound(dataset, valDataset) {
53
+ let epochsLogs = List();
54
+ for (let epoch = 0; epoch < this.#roundDuration; epoch++) {
55
+ const [gen, epochLogs] = async_iterator.split(this.model.train(dataset, valDataset));
56
+ yield gen;
57
+ epochsLogs = epochsLogs.push(await epochLogs);
58
+ }
59
+ return { epochs: epochsLogs };
51
60
  }
52
61
  }
@@ -0,0 +1,11 @@
+ import { List } from "immutable";
+ /**
+ * Split yields from return value
+ *
+ * You need to consume the iterator to resolve the returned value
+ **/
+ export declare function split<T, U>(iter: AsyncIterator<T, U>): [AsyncGenerator<T, U>, Promise<U>];
+ /** Zip iterator with a infinite counter */
+ export declare function enumerate<T, U>(iter: AsyncIterator<T, U> | Iterator<T, U>): AsyncGenerator<[number, T], U>;
+ /** Run the whole iterator to get yielded & returned */
+ export declare function gather<T, U>(iter: AsyncIterator<T, U>): Promise<[List<T>, U]>;
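
These helpers underpin the new nested training generators; `split` in particular lets a caller stream the yielded values while still receiving the generator's return value as a promise. A small self-contained sketch with a toy generator (the generator itself is not from the package):

    import { async_iterator } from "@epfml/discojs";

    async function* countdown(): AsyncGenerator<number, string> {
      yield 3;
      yield 2;
      yield 1;
      return "done";
    }

    const [yields, returned] = async_iterator.split(countdown());
    for await (const n of yields) console.log(n); // 3, 2, 1
    console.log(await returned); // "done", resolves once the generator has been consumed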
@@ -0,0 +1,63 @@
1
+ import { List } from "immutable";
2
+ // `Promise.withResolvers` not widely deployed
3
+ function PromiseWithResolvers() {
4
+ let resolve, reject;
5
+ resolve = reject = () => {
6
+ // should not happen as Promise are run on creation
7
+ throw new Error("race condition triggered");
8
+ };
9
+ const promise = new Promise((res, rej) => {
10
+ resolve = res;
11
+ reject = rej;
12
+ });
13
+ return [promise, resolve, reject];
14
+ }
15
+ /**
16
+ * Split yields from return value
17
+ *
18
+ * You need to consume the iterator to resolve the returned value
19
+ **/
20
+ export function split(iter) {
21
+ const [returnPromise, returnResolve, returnReject] = PromiseWithResolvers();
22
+ return [
23
+ (async function* () {
24
+ try {
25
+ while (true) {
26
+ const v = await iter.next();
27
+ if (!v.done) {
28
+ yield v.value;
29
+ continue;
30
+ }
31
+ returnResolve(v.value);
32
+ return v.value;
33
+ }
34
+ }
35
+ catch (e) {
36
+ returnReject(e);
37
+ throw e;
38
+ }
39
+ })(),
40
+ returnPromise,
41
+ ];
42
+ }
43
+ /** Zip iterator with a infinite counter */
44
+ export function enumerate(iter) {
45
+ return (async function* () {
46
+ for (let i = 0;; i++) {
47
+ const v = await iter.next();
48
+ if (v.done)
49
+ return v.value;
50
+ yield [i, v.value];
51
+ }
52
+ })();
53
+ }
54
+ /** Run the whole iterator to get yielded & returned */
55
+ export async function gather(iter) {
56
+ let elems = List();
57
+ for (;;) {
58
+ const v = await iter.next();
59
+ if (v.done)
60
+ return [elems, v.value];
61
+ elems = elems.push(v.value);
62
+ }
63
+ }
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "@epfml/discojs",
- "version": "2.1.2-p20240624145915.0",
+ "version": "2.1.2-p20240702170238.0",
  "type": "module",
  "main": "dist/index.js",
  "types": "dist/index.d.ts",
@@ -1,2 +0,0 @@
- import type { TaskProvider } from '../index.js';
- export declare const skinCondition: TaskProvider;
@@ -1,80 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { data, models } from '../index.js';
3
- const IMAGE_SIZE = 128;
4
- const LABELS = ['Eczema', 'Allergic Contact Dermatitis', 'Urticaria'];
5
- export const skinCondition = {
6
- getTask() {
7
- return {
8
- id: 'skin_condition',
9
- displayInformation: {
10
- taskTitle: 'Skin Condition Classification',
11
- summary: {
12
- preview: "Identify common skin conditions from volunteer image contributions. You can find a sample dataset of 400 images <a class='underline text-primary-dark dark:text-primary-light' href='https://storage.googleapis.com/deai-313515.appspot.com/scin_sample.zip'>here</a> or see the full <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/google-research-datasets/scin/tree/main'>SCIN dataset</a>. You can find how to download and preprocess the dataset <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/epfml/disco/blob/develop/docs/examples/scin_dataset.ipynb'>in this notebook</a>.",
13
- overview: "The <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/google-research-datasets/scin/tree/main'>SCIN (Skin Condition Image Network) open access dataset</a> aims to supplement publicly available dermatology datasets from health system sources with representative images from internet users. To this end, the SCIN dataset was collected from Google Search users in the United States through a voluntary, consented image donation application. The SCIN dataset is intended for health education and research, and to increase the diversity of dermatology images available for public use. The SCIN dataset contains 5,000+ volunteer contributions (10,000+ images) of common dermatology conditions. Contributions include Images, self-reported demographic, history, and symptom information, and self-reported Fitzpatrick skin type (sFST). In addition, dermatologist labels of the skin condition are provided for each contribution. You can find more information on the dataset and classification task <a class='underline text-primary-dark dark:text-primary-light' href='https://arxiv.org/abs/2402.18545'>here</a>."
14
- },
15
- dataFormatInformation: "There are hundreds of skin condition labels in the SCIN dataset. For the sake of simplicity, we only include the 3 most common conditions in the sample dataset: 'Eczema', 'Allergic Contact Dermatitis' and 'Urticaria'. Therefore, each image is expected to be labeled with one of these three categories.",
16
- sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/scin_sample.zip'
17
- },
18
- trainingInformation: {
19
- modelID: 'skin-condition-model',
20
- epochs: 10,
21
- roundDuration: 2,
22
- validationSplit: 0.3,
23
- batchSize: 8,
24
- preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
25
- dataType: 'image',
26
- IMAGE_H: IMAGE_SIZE,
27
- IMAGE_W: IMAGE_SIZE,
28
- LABEL_LIST: LABELS,
29
- scheme: 'federated',
30
- noiseScale: undefined,
31
- clippingRadius: undefined,
32
- tensorBackend: 'tfjs'
33
- }
34
- };
35
- },
36
- async getModel() {
37
- const imageChannels = 3;
38
- const numOutputClasses = LABELS.length;
39
- const model = tf.sequential();
40
- model.add(tf.layers.conv2d({
41
- inputShape: [IMAGE_SIZE, IMAGE_SIZE, imageChannels],
42
- filters: 8,
43
- kernelSize: 3,
44
- strides: 1,
45
- kernelInitializer: 'varianceScaling',
46
- activation: 'relu'
47
- }));
48
- model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
49
- model.add(tf.layers.dropout({ rate: 0.2 }));
50
- const convFilters = [16, 32, 64, 128];
51
- for (const filters of convFilters) {
52
- model.add(tf.layers.conv2d({
53
- filters: filters,
54
- kernelSize: 3,
55
- strides: 1,
56
- kernelInitializer: 'varianceScaling',
57
- activation: 'relu'
58
- }));
59
- model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
60
- model.add(tf.layers.dropout({ rate: 0.2 }));
61
- }
62
- model.add(tf.layers.flatten());
63
- model.add(tf.layers.dense({
64
- units: 64,
65
- kernelInitializer: 'varianceScaling',
66
- activation: 'relu',
67
- }));
68
- model.add(tf.layers.dense({
69
- units: numOutputClasses,
70
- kernelInitializer: 'varianceScaling',
71
- activation: 'softmax'
72
- }));
73
- model.compile({
74
- optimizer: tf.train.adam(),
75
- loss: 'categoricalCrossentropy',
76
- metrics: ['accuracy']
77
- });
78
- return Promise.resolve(new models.TFJS(model));
79
- }
80
- };