@epfml/discojs 2.1.2-p20240507140056.0 → 2.1.2-p20240515132210.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +24 -0
- package/dist/dataset/dataset_builder.js +1 -1
- package/dist/default_tasks/cifar10/index.js +3 -4
- package/dist/default_tasks/index.d.ts +0 -2
- package/dist/default_tasks/index.js +0 -2
- package/dist/default_tasks/lus_covid.js +3 -3
- package/dist/default_tasks/mnist.js +0 -1
- package/dist/default_tasks/simple_face/index.js +0 -2
- package/dist/default_tasks/titanic.js +0 -1
- package/dist/default_tasks/wikitext.js +2 -3
- package/dist/index.d.ts +1 -1
- package/dist/index.js +1 -1
- package/dist/memory/base.d.ts +6 -19
- package/dist/memory/empty.d.ts +2 -2
- package/dist/memory/empty.js +2 -2
- package/dist/memory/index.d.ts +1 -1
- package/dist/memory/index.js +1 -1
- package/dist/memory/model_type.d.ts +1 -1
- package/dist/memory/model_type.js +5 -5
- package/dist/models/gpt/config.d.ts +32 -0
- package/dist/models/gpt/config.js +42 -0
- package/dist/models/gpt/evaluate.d.ts +7 -0
- package/dist/models/gpt/evaluate.js +44 -0
- package/dist/models/gpt/index.d.ts +37 -0
- package/dist/models/gpt/index.js +107 -0
- package/dist/models/gpt/layers.d.ts +13 -0
- package/dist/models/gpt/layers.js +272 -0
- package/dist/models/gpt/model.d.ts +43 -0
- package/dist/models/gpt/model.js +191 -0
- package/dist/models/gpt/optimizers.d.ts +4 -0
- package/dist/models/gpt/optimizers.js +95 -0
- package/dist/models/index.d.ts +5 -0
- package/dist/models/index.js +4 -0
- package/dist/models/model.d.ts +51 -0
- package/dist/models/model.js +8 -0
- package/dist/models/tfjs.d.ts +24 -0
- package/dist/models/tfjs.js +107 -0
- package/dist/models/tokenizer.d.ts +14 -0
- package/dist/models/tokenizer.js +23 -0
- package/dist/task/display_information.d.ts +3 -6
- package/dist/task/display_information.js +21 -10
- package/dist/task/index.d.ts +0 -1
- package/dist/task/index.js +0 -1
- package/dist/training/trainer/trainer_builder.js +2 -2
- package/package.json +1 -1
- package/dist/default_tasks/geotags/index.d.ts +0 -2
- package/dist/default_tasks/geotags/index.js +0 -65
- package/dist/default_tasks/geotags/model.d.ts +0 -593
- package/dist/default_tasks/geotags/model.js +0 -4715
- package/dist/default_tasks/skin_mnist.d.ts +0 -2
- package/dist/default_tasks/skin_mnist.js +0 -80
- package/dist/task/label_type.d.ts +0 -9
- package/dist/task/label_type.js +0 -28
package/README.md
ADDED
@@ -0,0 +1,24 @@
+ # @epfml/discojs
+
+ Decentralized & federated privacy-preserving ML training in TypeScript.
+
+ This is the core library of the Disco.js project.
+
+ It is platform-agnostic, and has two companions library:
+ - [`discojs-node`](../discojs-node) for Node.js
+ - [`discojs-web`](../discojs-web) for web browsers
+
+ The easiest way to start using it is through the `Disco` object.
+ Create your own `Task` or load one from our `default_tasks`,
+ setup the `Dataset` you want, and train with it.
+
+ ```ts
+ import { Disco } from '@epfml/discojs'
+
+ const url = ...; // url to a Disco.js server
+ const dataset = ...;
+ const task = ...;
+
+ const disco = new Disco(task, { url })
+ for await (const _ of disco.fit(dataset));
+ ```
package/dist/dataset/dataset_builder.js
CHANGED
@@ -92,7 +92,7 @@ export class DatasetBuilder {
      async build(config) {
          // Require that at least one source collection is non-empty, but not both
          if ((this._sources.length > 0) === (this.labelledSources.size > 0)) {
-             throw new Error('Please provide dataset input files');
+             throw new Error('Please provide dataset input files'); // This error message is parsed in DatasetInput.vue
          }
          let dataTuple;
          if (this._sources.length > 0) {
package/dist/default_tasks/cifar10/index.js
CHANGED
@@ -11,11 +11,10 @@ export const cifar10 = {
            preview: 'In this challenge, we ask you to classify images into categories based on the objects shown on the image.',
            overview: 'The CIFAR-10 dataset is a collection of images that are commonly used to train machine learning and computer vision algorithms. It is one of the most widely used datasets for machine learning research.'
        },
-       limitations: 'The training data is limited to small images of size 32x32.',
-       tradeoffs: 'Training success strongly depends on label distribution',
        dataFormatInformation: 'Images should be of .png format and of size 32x32. <br> The label file should be .csv, where each row contains a file_name, class. <br> <br> e.g. if you have images: 0.png (of a frog) and 1.png (of a car) <br> labels.csv contains: (Note that no header is needed)<br> 0.png, frog <br> 1.png, car',
        dataExampleText: 'Below you can find 10 random examples from each of the 10 classes in the dataset.',
-       dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png'
+       dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png',
+       sampleDatasetLink: 'https://www.kaggle.com/competitions/cifar-10/data'
    },
    trainingInformation: {
        modelID: 'cifar10-model',
@@ -27,7 +26,7 @@ export const cifar10 = {
        preprocessingFunctions: [data.ImagePreprocessing.Resize],
        IMAGE_H: 224,
        IMAGE_W: 224,
-       LABEL_LIST: ['
+       LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
        scheme: 'decentralized',
        noiseScale: undefined,
        clippingRadius: 20,
package/dist/default_tasks/index.d.ts
CHANGED
@@ -1,8 +1,6 @@
  export { cifar10 } from './cifar10/index.js';
- export { geotags } from './geotags/index.js';
  export { lusCovid } from './lus_covid.js';
  export { mnist } from './mnist.js';
  export { simpleFace } from './simple_face/index.js';
- export { skinMnist } from './skin_mnist.js';
  export { titanic } from './titanic.js';
  export { wikitext } from './wikitext.js';
package/dist/default_tasks/index.js
CHANGED
@@ -1,8 +1,6 @@
  export { cifar10 } from './cifar10/index.js';
- export { geotags } from './geotags/index.js';
  export { lusCovid } from './lus_covid.js';
  export { mnist } from './mnist.js';
  export { simpleFace } from './simple_face/index.js';
- export { skinMnist } from './skin_mnist.js';
  export { titanic } from './titanic.js';
  export { wikitext } from './wikitext.js';
package/dist/default_tasks/lus_covid.js
CHANGED
@@ -11,16 +11,16 @@ export const lusCovid = {
            overview: "Don’t have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
        },
        model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
-       tradeoffs: 'We are using a simpler version of DeepChest in order to be able to run it on the browser.',
        dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
        dataExampleText: 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png',
-       dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png'
+       dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png',
+       sampleDatasetLink: 'https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly'
    },
    trainingInformation: {
        modelID: 'lus-covid-model',
        epochs: 50,
        roundDuration: 2,
-       validationSplit: 0,
+       validationSplit: 0.2,
        batchSize: 5,
        IMAGE_H: 100,
        IMAGE_W: 100,
package/dist/default_tasks/mnist.js
CHANGED
@@ -11,7 +11,6 @@ export const mnist = {
            overview: 'The MNIST handwritten digit classification problem is a standard dataset used in computer vision and deep learning. Although the dataset is effectively solved, we use it to test our Decentralised Learning algorithms and platform.'
        },
        model: 'The current model is a very simple CNN and its main goal is to test the app and the Decentralizsed Learning functionality.',
-       tradeoffs: 'We are using a simple model, first a 2d convolutional layer > max pooling > 2d convolutional layer > max pooling > convolutional layer > 2 dense layers.',
        dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can upload each digit image of your dataset in the box corresponding to its label. The model taskes images of size 28x28 as input.',
        dataExampleText: 'Below you can find an example of an expected image representing the digit 9.',
        dataExampleImage: 'http://storage.googleapis.com/deai-313515.appspot.com/example_training_data/9-mnist-example.png'
package/dist/default_tasks/simple_face/index.js
CHANGED
@@ -11,8 +11,6 @@ export const simpleFace = {
            preview: 'Can you detect if the person in a picture is a child or an adult?',
            overview: 'Simple face is a small subset of face_task from Kaggle'
        },
-       limitations: 'The training data is limited to small images of size 200x200.',
-       tradeoffs: 'Training success strongly depends on label distribution',
        dataFormatInformation: '',
        dataExampleText: 'Below you find an example',
        dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png'
package/dist/default_tasks/titanic.js
CHANGED
@@ -11,7 +11,6 @@ export const titanic = {
            overview: 'We all know the unfortunate story of the Titanic: this flamboyant new transatlantic boat that sunk in 1912 in the North Atlantic Ocean. Today, we revist this tragedy by trying to predict the survival odds of the passenger given some basic features.'
        },
        model: 'The current model does not normalize the given data and applies only a very simple pre-processing of the data.',
-       tradeoffs: 'We are using a small model for this task: 4 fully connected layers with few neurons. This allows fast training but can yield to reduced accuracy.',
        dataFormatInformation: 'This model takes as input a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.<br><br>pclass: A proxy for socio-economic status (SES)<br>1st = Upper<br>2nd = Middle<br>3rd = Lower<br><br>age: Age is fractional if less than 1. If the age is estimated, it is in the form of xx.5<br><br>sibsp: The dataset defines family relations in this way:<br>Sibling = brother, sister, stepbrother, stepsister<br>Spouse = husband, wife (mistresses and fiancés were ignored)<br><br>parch: The dataset defines family relations in this way:<br>Parent = mother, father<br>Child = daughter, son, stepdaughter, stepson<br>Some children travelled only with a nanny, therefore parch=0 for them.<br><br>The first line of the CSV contains the header:<br> PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked<br><br>Each susequent row contains the corresponding data.',
        dataExampleText: 'Below one can find an example of a datapoint taken as input by our model. In this datapoint, the person is young man named Owen Harris that unfortunnalty perished with the Titanic. He boarded the boat in South Hamptons and was a 3rd class passenger. On the testing & validation page, the data should not contain the label column (Survived).',
        dataExample: [
package/dist/default_tasks/wikitext.js
CHANGED
@@ -9,10 +9,9 @@ export const wikitext = {
            preview: 'In this challenge, we ask you to do next word prediction on a dataset of Wikipedia articles.',
            overview: 'Wikitext-103-raw is a dataset comprising unprocessed text excerpts from Wikipedia articles, designed for tasks related to natural language processing and language modeling.'
        },
-       limitations: 'The dataset may contain noise, inconsistencies, and unstructured content due to its raw nature, potentially posing challenges for certain NLP tasks.',
-       tradeoffs: 'The raw format may lack structured annotations and may require additional preprocessing for specific applications.',
        dataFormatInformation: 'The dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.',
-       dataExampleText: 'An example excerpt from the dataset could be: "The history of artificial intelligence dates back to ancient times, with philosophical discussions on the nature of thought and reasoning."'
+       dataExampleText: 'An example excerpt from the dataset could be: "The history of artificial intelligence dates back to ancient times, with philosophical discussions on the nature of thought and reasoning."',
+       sampleDatasetLink: 'https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz'
    },
    trainingInformation: {
        dataType: 'text',
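Across these task hunks the pattern is the same: the free-text `limitations` and `tradeoffs` fields are dropped from the display information and an optional `sampleDatasetLink` is added next to `dataExampleImage`. A minimal sketch of the affected fields in a hypothetical custom task, using only field names visible in the hunks above (the full shape lives in package/dist/task/display_information.d.ts, whose diff is not shown in this excerpt):

```ts
// Hypothetical excerpt of a custom task's display information after this change.
const displayInformation = {
  dataFormatInformation: 'Images should be 32x32 .png files listed in a labels.csv file.',
  dataExampleText: 'Below you can find a random example from the dataset.',
  dataExampleImage: 'https://example.org/my-task-example.png', // placeholder URL
  sampleDatasetLink: 'https://example.org/my-task-sample.zip',  // new optional field
  // `limitations` and `tradeoffs` are no longer part of the display information.
}
```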
package/dist/index.d.ts
CHANGED
@@ -8,7 +8,7 @@ export * as aggregator from './aggregator/index.js';
  export { WeightsContainer, aggregation } from './weights/index.js';
  export { AsyncInformant } from './async_informant.js';
  export { Logger, ConsoleLogger } from './logging/index.js';
- export { Memory,
+ export { Memory, StoredModelType, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
  export { Disco, RoundLogs } from './training/index.js';
  export { Validator } from './validation/index.js';
  export { Model, EpochLogs } from './models/index.js';
package/dist/index.js
CHANGED
@@ -8,7 +8,7 @@ export * as aggregator from './aggregator/index.js';
  export { WeightsContainer, aggregation } from './weights/index.js';
  export { AsyncInformant } from './async_informant.js';
  export { ConsoleLogger } from './logging/index.js';
- export { Memory,
+ export { Memory, StoredModelType, Empty as EmptyMemory } from './memory/index.js';
  export { Disco } from './training/index.js';
  export { Validator } from './validation/index.js';
  export { Model } from './models/index.js';
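Both entry points now export the renamed `StoredModelType` enum (see package/dist/memory/model_type.* below) and alias `Empty` as `EmptyMemory`. A minimal sketch of the resulting import surface, assuming the package is consumed as `@epfml/discojs`:

```ts
import {
  Disco,
  Memory,          // abstract model store
  StoredModelType, // runtime enum with WORKING and SAVED members
  EmptyMemory,     // no-op Memory implementation
} from '@epfml/discojs'
import type { ModelInfo, ModelSource, Path } from '@epfml/discojs'

console.log(StoredModelType.WORKING) // "working"
```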
package/dist/memory/base.d.ts
CHANGED
@@ -1,5 +1,5 @@
  import type { Model, TaskID } from '../index.js';
- import type {
+ import type { StoredModelType } from './model_type.js';
  /**
   * Model path which uniquely identifies a model in memory.
   */
@@ -8,22 +8,9 @@ export type Path = string;
   * Model information which uniquely identifies a model in memory.
   */
  export interface ModelInfo {
-     /**
-      * The model's type: "working" or "saved" model.
-      */
-     type?: ModelType;
-     /**
-      * The model's version, to allow for multiple saved models of a same task without
-      * causing id conflicts
-      */
+     type?: StoredModelType;
      version?: number;
-     /**
-      * The model's corresponding task
-      */
      taskID: TaskID;
-     /**
-      * The model's name
-      */
      name: string;
  }
  /**
@@ -95,21 +82,21 @@ export declare abstract class Memory {
      /**
       * Computes the path in memory corresponding to the given model source, be it a path or model information.
       * This is used to easily switch between model path and information, which are both unique model identifiers
-      * with a one-to-one
+      * with a one-to-one equivalence. Returns undefined instead if no path could be inferred from the given
       * model source.
       * @param source The model source
       * @returns The model path
       */
-     abstract
+     abstract getModelMemoryPath(source: ModelSource): Path | undefined;
      /**
       * Computes the model information corresponding to the given model source, be it a path or model information.
       * This is used to easily switch between model path and information, which are both unique model identifiers
-      * with a one-to-one
+      * with a one-to-one equivalence. Returns undefined instead if no unique model information could be inferred
       * from the given model source.
       * @param source The model source
       * @returns The model information
       */
-     abstract
+     abstract getModelInfo(source: ModelSource): ModelInfo | undefined;
      /**
       * Computes the lowest version a model source can have without conflicting with model versions currently in memory.
       * @param source The model source
package/dist/memory/empty.d.ts
CHANGED
@@ -14,7 +14,7 @@ export declare class Empty extends Memory {
      saveModel(): Promise<undefined>;
      deleteModel(): Promise<void>;
      downloadModel(): Promise<void>;
-
-
+     getModelMemoryPath(): Path;
+     getModelInfo(): ModelInfo;
      duplicateSource(): Promise<undefined>;
  }
package/dist/memory/empty.js
CHANGED
@@ -31,10 +31,10 @@ export class Empty extends Memory {
      downloadModel() {
          return Promise.reject(new Error('empty'));
      }
-
+     getModelMemoryPath() {
          throw new Error('empty');
      }
-
+     getModelInfo() {
          throw new Error('empty');
      }
      duplicateSource() {
package/dist/memory/index.d.ts
CHANGED
package/dist/memory/index.js
CHANGED
package/dist/memory/model_type.d.ts
CHANGED
@@ -3,7 +3,7 @@
   * being trained ("working model") or a regular model saved in memory ("saved model").
   * There can only be a single working model for a given task.
   */
- export declare enum
+ export declare enum StoredModelType {
      WORKING = "working",
      SAVED = "saved"
  }
package/dist/memory/model_type.js
CHANGED
@@ -3,8 +3,8 @@
   * being trained ("working model") or a regular model saved in memory ("saved model").
   * There can only be a single working model for a given task.
   */
- export var
- (function (
-
-
- })(
+ export var StoredModelType;
+ (function (StoredModelType) {
+     StoredModelType["WORKING"] = "working";
+     StoredModelType["SAVED"] = "saved";
+ })(StoredModelType || (StoredModelType = {}));
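The enum body is unchanged; only the name moves from `ModelType` to `StoredModelType`, and `ModelInfo.type` (see base.d.ts above) now refers to it. A short sketch of describing a saved model with the renamed types, assuming `TaskID` is a plain string identifier and using a placeholder task id:

```ts
import { StoredModelType } from '@epfml/discojs'
import type { ModelInfo } from '@epfml/discojs'

// Identifies version 1 of a saved model for a hypothetical 'titanic' task.
const info: ModelInfo = {
  type: StoredModelType.SAVED,
  version: 1,
  taskID: 'titanic',
  name: 'titanic-model',
}

console.log(info.type) // "saved"
```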
package/dist/models/gpt/config.d.ts
ADDED
@@ -0,0 +1,32 @@
+ type GPTModelType = 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl' | 'gpt-mini' | 'gpt-micro' | 'gpt-nano';
+ export interface GPTConfig {
+     lr: number;
+     blockSize: number;
+     vocabSize: number;
+     modelType: GPTModelType;
+     name?: string;
+     evaluate?: boolean;
+     maxEvalBatches?: number;
+     evaluateEvery?: number;
+     maxIter?: number;
+     weightDecay?: number;
+     verbose?: 0 | 1;
+     bias?: boolean;
+     debug?: boolean;
+     dropout?: number;
+     residDrop?: number;
+     embdDrop?: number;
+     tokEmb?: boolean;
+     lmHead?: boolean;
+     nLayer?: number;
+     nHead?: number;
+     nEmbd?: number;
+ }
+ export declare const DEFAULT_CONFIG: Required<GPTConfig>;
+ export type ModelSize = {
+     nLayer: number;
+     nHead: number;
+     nEmbd: number;
+ };
+ export declare function getModelSizes(modelType: GPTModelType): Required<ModelSize>;
+ export {};
package/dist/models/gpt/config.js
ADDED
@@ -0,0 +1,42 @@
+ // for a benchmark of performance, see https://github.com/epfml/disco/pull/659
+ export const DEFAULT_CONFIG = {
+     name: 'transformer',
+     lr: 0.001,
+     weightDecay: 0,
+     maxIter: 5,
+     verbose: 0,
+     modelType: 'gpt-nano',
+     evaluate: true,
+     maxEvalBatches: 12,
+     evaluateEvery: 100,
+     blockSize: 128,
+     vocabSize: 50258,
+     bias: true,
+     debug: false,
+     dropout: 0.2,
+     residDrop: 0.2,
+     embdDrop: 0.2,
+     tokEmb: true,
+     lmHead: true,
+     nLayer: 3,
+     nHead: 3,
+     nEmbd: 48,
+ };
+ export function getModelSizes(modelType) {
+     switch (modelType) {
+         case 'gpt2':
+             return { nLayer: 12, nHead: 12, nEmbd: 768 };
+         case 'gpt2-medium':
+             return { nLayer: 24, nHead: 16, nEmbd: 1024 };
+         case 'gpt2-large':
+             return { nLayer: 36, nHead: 20, nEmbd: 1280 };
+         case 'gpt2-xl':
+             return { nLayer: 48, nHead: 25, nEmbd: 1600 };
+         case 'gpt-mini':
+             return { nLayer: 6, nHead: 6, nEmbd: 192 };
+         case 'gpt-micro':
+             return { nLayer: 4, nHead: 4, nEmbd: 128 };
+         case 'gpt-nano':
+             return { nLayer: 3, nHead: 3, nEmbd: 48 };
+     }
+ }
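A short sketch of how these two new files fit together: a partial `GPTConfig` picks a preset via `modelType`, `getModelSizes` maps that preset to layer counts, and `DEFAULT_CONFIG` carries defaults for everything else (how the merge actually happens lives in model.js, not shown in this excerpt). The relative import path is illustrative; the module sits at dist/models/gpt/config.js and is not re-exported from the package root:

```ts
import { DEFAULT_CONFIG, getModelSizes, type GPTConfig } from './models/gpt/config.js'

// Only the required GPTConfig fields; the model presumably merges the rest from DEFAULT_CONFIG.
const config: GPTConfig = {
  modelType: 'gpt-nano',
  lr: 0.001,
  blockSize: 128,
  vocabSize: 50258,
}

const { nLayer, nHead, nEmbd } = getModelSizes(config.modelType)
console.log(nLayer, nHead, nEmbd)                 // 3 3 48 for 'gpt-nano'
const resolved = { ...DEFAULT_CONFIG, ...config } // fully populated configuration
console.log(resolved.maxIter)                     // 5 (taken from DEFAULT_CONFIG)
```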
package/dist/models/gpt/evaluate.d.ts
ADDED
@@ -0,0 +1,7 @@
+ import * as tf from '@tensorflow/tfjs';
+ interface DataPoint extends tf.TensorContainerObject {
+     xs: tf.Tensor2D;
+     ys: tf.Tensor3D;
+ }
+ export default function evaluate(model: tf.LayersModel, dataset: tf.data.Dataset<DataPoint>, maxEvalBatches: number): Promise<Record<'acc' | 'val_acc' | 'val_loss' | 'val_perplexity', number>>;
+ export {};
package/dist/models/gpt/evaluate.js
ADDED
@@ -0,0 +1,44 @@
+ import * as tf from '@tensorflow/tfjs';
+ export default async function evaluate(model, dataset, maxEvalBatches) {
+     let datasetSize = 0;
+     let totalLoss = 0;
+     const acc = [0, 0];
+     await dataset.take(maxEvalBatches).map(({ xs, ys }) => {
+         const logits = model.apply(xs);
+         if (Array.isArray(logits)) {
+             throw new Error('model output too many tensor');
+         }
+         if (logits instanceof tf.SymbolicTensor) {
+             throw new Error('model output symbolic tensor');
+         }
+         xs.dispose();
+         return { logits, ys };
+     }).mapAsync(async ({ logits, ys }) => {
+         const lossTensor = tf.losses.softmaxCrossEntropy(ys, logits);
+         const loss = await lossTensor.array();
+         if (typeof loss !== 'number') {
+             throw new Error('got multiple loss');
+         }
+         const accTensor = tf.metrics.categoricalAccuracy(ys, logits);
+         const accSize = accTensor.shape.reduce((l, r) => l * r, 1);
+         const accSum = accTensor.sum();
+         const accSummed = await accSum.array();
+         if (typeof accSummed !== 'number') {
+             throw new Error('got multiple accuracy sum');
+         }
+         tf.dispose([ys, logits, accTensor, accSum, lossTensor]);
+         return { loss, accSummed, accSize };
+     }).forEachAsync(({ loss, accSummed, accSize }) => {
+         datasetSize += 1;
+         totalLoss += loss;
+         acc[0] += accSummed;
+         acc[1] += accSize;
+     });
+     const loss = totalLoss / datasetSize;
+     return {
+         val_loss: loss,
+         val_perplexity: Math.exp(loss),
+         acc: acc[0] / acc[1],
+         val_acc: acc[0] / acc[1]
+     };
+ }
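`evaluate` is an internal, default-exported helper: it streams at most `maxEvalBatches` batches through the model, averages the per-batch softmax cross-entropy loss, and reports accuracy plus perplexity. A hedged usage sketch, with the model and dataset left as placeholders and an illustrative relative import path:

```ts
import * as tf from '@tensorflow/tfjs'
import evaluate from './models/gpt/evaluate.js'

// Placeholders: a language model and a dataset of { xs, ys } tensor batches.
declare const model: tf.LayersModel
declare const dataset: tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>

const metrics = await evaluate(model, dataset, 12)
console.log(metrics.val_loss, metrics.val_perplexity, metrics.val_acc)
```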
package/dist/models/gpt/index.d.ts
ADDED
@@ -0,0 +1,37 @@
+ /**
+  * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
+  **/
+ import * as tf from '@tensorflow/tfjs';
+ import { PreTrainedTokenizer } from '@xenova/transformers';
+ import { WeightsContainer } from '../../index.js';
+ import type { Dataset } from '../../dataset/index.js';
+ import { Model } from '../model.js';
+ import type { EpochLogs, Prediction, Sample } from '../model.js';
+ import type { GPTConfig } from './config.js';
+ export declare class GPT extends Model {
+     private readonly model;
+     constructor(partialConfig?: GPTConfig);
+     /**
+      * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
+      * This allows for getting logs and stopping training without callbacks.
+      *
+      * @param trainingData training dataset
+      * @param validationData validation dataset
+      * @param epochs the number of passes of the training dataset
+      * @param tracker
+      */
+     train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs, void>;
+     predict(input: Sample): Promise<Prediction>;
+     generate(input: string, tokenizer: PreTrainedTokenizer, newTokens?: number): Promise<string>;
+     get config(): Required<GPTConfig>;
+     get weights(): WeightsContainer;
+     set weights(ws: WeightsContainer);
+     static deserialize(data: GPTSerialization): Model;
+     serialize(): GPTSerialization;
+     extract(): tf.LayersModel;
+     [Symbol.dispose](): void;
+ }
+ export type GPTSerialization = {
+     weights: WeightsContainer;
+     config?: GPTConfig;
+ };
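To tie the declaration together: `train` is an async generator that yields one structured log object per epoch (so a caller can observe or stop training without callbacks), and `generate` runs tokenization followed by greedy decoding. A minimal usage sketch based only on the declared API; the dataset and tokenizer construction are placeholders and the relative import paths are illustrative:

```ts
import type { PreTrainedTokenizer } from '@xenova/transformers'
import { GPT } from './models/gpt/index.js'
import type { Dataset } from './dataset/index.js'

// Placeholders: a tokenized text Dataset and a pretrained tokenizer.
declare const trainingData: Dataset
declare const tokenizer: PreTrainedTokenizer

const gpt = new GPT({ modelType: 'gpt-nano', lr: 0.001, blockSize: 128, vocabSize: 50258 })

// One log object per epoch; break out of the loop to stop training early.
for await (const logs of gpt.train(trainingData, undefined, 2)) {
  console.log(logs)
}

console.log(await gpt.generate('The history of artificial intelligence', tokenizer, 10))
```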
package/dist/models/gpt/index.js
ADDED
@@ -0,0 +1,107 @@
+ /**
+  * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
+  **/
+ import { WeightsContainer } from '../../index.js';
+ import { Model } from '../model.js';
+ import { GPTForCausalLM } from './model.js';
+ export class GPT extends Model {
+     model;
+     constructor(partialConfig) {
+         super();
+         this.model = new GPTForCausalLM(partialConfig);
+     }
+     /**
+      * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
+      * This allows for getting logs and stopping training without callbacks.
+      *
+      * @param trainingData training dataset
+      * @param validationData validation dataset
+      * @param epochs the number of passes of the training dataset
+      * @param tracker
+      */
+     async *train(trainingData, validationData, epochs = 1) {
+         this.model.compile();
+         let logs;
+         const trainingArgs = {
+             epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
+             validationData,
+             callbacks: { onEpochEnd: (_, cur) => { logs = cur; } },
+         };
+         for (let epoch = 0; epoch < epochs; epoch++) {
+             await this.model.fitDataset(trainingData, trainingArgs);
+             if (logs === undefined) {
+                 throw new Error("Epoch didn't gave any logs");
+             }
+             const { loss, val_acc, val_loss, peakMemory } = logs;
+             if (loss === undefined || isNaN(loss)) {
+                 throw new Error("Training loss is undefined or nan");
+             }
+             const structuredLogs = {
+                 epoch,
+                 peakMemory,
+                 training: {
+                     loss: logs.loss
+                 }
+             };
+             if (validationData !== undefined) {
+                 if (val_loss === undefined || isNaN(val_loss) ||
+                     val_acc === undefined || isNaN(val_acc)) {
+                     throw new Error("Invalid validation logs");
+                 }
+                 structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss };
+             }
+             yield structuredLogs;
+         }
+     }
+     predict(input) {
+         const ret = this.model.predict(input);
+         if (Array.isArray(ret)) {
+             throw new Error('prediction yield many Tensors but should have only returned one');
+         }
+         return Promise.resolve(ret);
+     }
+     async generate(input, tokenizer, newTokens = 10) {
+         const { input_ids: tokens } = await tokenizer(input, { return_tensor: false });
+         const generationConfig = {
+             maxNewTokens: newTokens,
+             temperature: 1.0,
+             doSample: false
+         };
+         const predictedTokens = await this.model.generate(tokens, generationConfig);
+         const generatedWords = tokenizer.decode(predictedTokens[0]);
+         return generatedWords;
+     }
+     get config() {
+         return this.model.getGPTConfig;
+     }
+     get weights() {
+         return new WeightsContainer(this.model.weights.map((w) => w.read()));
+     }
+     set weights(ws) {
+         this.model.setWeights(ws.weights);
+     }
+     static deserialize(data) {
+         const model = new GPT(data.config);
+         model.weights = data.weights;
+         return model;
+     }
+     serialize() {
+         return {
+             weights: this.weights,
+             config: this.config
+         };
+     }
+     extract() {
+         return this.model;
+     }
+     [Symbol.dispose]() {
+         console.log("Disposing model");
+         if (this.model.optimizer !== undefined) {
+             this.model.optimizer.dispose();
+         }
+         // Some tensors are not cleaned up when model.dispose is called
+         // So we dispose them manually
+         this.model.disposeRefs();
+         this.model.dispose();
+     }
+ }
package/dist/models/gpt/layers.d.ts
ADDED
@@ -0,0 +1,13 @@
+ import * as tf from '@tensorflow/tfjs';
+ import type { GPTConfig } from './config.js';
+ /**
+  * The GPTArchitecture specifically defines a GPT forward pass, i.e.,
+  * what are the inputs, the successive transformer blocks and the outputs. It is then
+  * used to create a GPTModel
+  *
+  * @param conf GPTConfig
+  * @returns model, tf.LayersModel, which supports model(inputs), model.predict and model.apply
+  */
+ export declare function GPTArchitecture(config: Required<GPTConfig>, disposalRefs: tf.TensorContainer[], peakMemory: {
+     value: number;
+ }): tf.LayersModel;