@epfml/discojs 2.1.2-p20240528164510.0 → 2.1.2-p20240603114517.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/dataset/dataset_builder.d.ts +2 -11
- package/dist/dataset/dataset_builder.js +22 -46
- package/dist/default_tasks/cifar10.d.ts +2 -0
- package/dist/default_tasks/{cifar10/index.js → cifar10.js} +2 -2
- package/dist/default_tasks/index.d.ts +3 -2
- package/dist/default_tasks/index.js +3 -2
- package/dist/default_tasks/lus_covid.js +1 -1
- package/dist/default_tasks/simple_face.d.ts +2 -0
- package/dist/default_tasks/{simple_face/index.js → simple_face.js} +3 -3
- package/dist/default_tasks/skin_condition.d.ts +2 -0
- package/dist/default_tasks/skin_condition.js +79 -0
- package/dist/models/gpt/config.d.ts +32 -0
- package/dist/models/gpt/config.js +42 -0
- package/dist/models/gpt/evaluate.d.ts +7 -0
- package/dist/models/gpt/evaluate.js +44 -0
- package/dist/models/gpt/index.d.ts +35 -0
- package/dist/models/gpt/index.js +104 -0
- package/dist/models/gpt/layers.d.ts +13 -0
- package/dist/models/gpt/layers.js +272 -0
- package/dist/models/gpt/model.d.ts +43 -0
- package/dist/models/gpt/model.js +191 -0
- package/dist/models/gpt/optimizers.d.ts +4 -0
- package/dist/models/gpt/optimizers.js +95 -0
- package/dist/models/index.d.ts +5 -0
- package/dist/models/index.js +4 -0
- package/dist/{default_tasks/simple_face/model.js → models/mobileNetV2_35_alpha_2_classes.js} +2 -0
- package/dist/{default_tasks/cifar10/model.js → models/mobileNet_v1_025_224.js} +1 -0
- package/dist/models/model.d.ts +51 -0
- package/dist/models/model.js +8 -0
- package/dist/models/tfjs.d.ts +24 -0
- package/dist/models/tfjs.js +107 -0
- package/dist/models/tokenizer.d.ts +14 -0
- package/dist/models/tokenizer.js +22 -0
- package/dist/validation/validator.js +8 -7
- package/package.json +1 -1
- package/dist/default_tasks/cifar10/index.d.ts +0 -2
- package/dist/default_tasks/simple_face/index.d.ts +0 -2
- /package/dist/{default_tasks/simple_face/model.d.ts → models/mobileNetV2_35_alpha_2_classes.d.ts} +0 -0
- /package/dist/{default_tasks/cifar10/model.d.ts → models/mobileNet_v1_025_224.d.ts} +0 -0
|
@@ -17,15 +17,11 @@ export declare class DatasetBuilder<Source> {
|
|
|
17
17
|
/**
|
|
18
18
|
* The buffer of unlabelled file sources.
|
|
19
19
|
*/
|
|
20
|
-
private
|
|
20
|
+
private _unlabeledSources;
|
|
21
21
|
/**
|
|
22
22
|
* The buffer of labelled file sources.
|
|
23
23
|
*/
|
|
24
|
-
private
|
|
25
|
-
/**
|
|
26
|
-
* Whether a dataset was already produced.
|
|
27
|
-
*/
|
|
28
|
-
private _built;
|
|
24
|
+
private _labeledSources;
|
|
29
25
|
constructor(
|
|
30
26
|
/**
|
|
31
27
|
* The data loader used to load the data contained in the provided files.
|
|
@@ -48,13 +44,8 @@ export declare class DatasetBuilder<Source> {
|
|
|
48
44
|
* @param label The file sources label
|
|
49
45
|
*/
|
|
50
46
|
clearFiles(label?: string): void;
|
|
51
|
-
private resetBuiltState;
|
|
52
47
|
private getLabels;
|
|
53
48
|
build(config?: DataConfig): Promise<DataSplit>;
|
|
54
|
-
/**
|
|
55
|
-
* Whether the dataset builder has already been consumed to produce a dataset.
|
|
56
|
-
*/
|
|
57
|
-
get built(): boolean;
|
|
58
49
|
get size(): number;
|
|
59
50
|
get sources(): Source[];
|
|
60
51
|
}
|
|
@@ -9,16 +9,11 @@ export class DatasetBuilder {
|
|
|
9
9
|
/**
|
|
10
10
|
* The buffer of unlabelled file sources.
|
|
11
11
|
*/
|
|
12
|
-
|
|
12
|
+
_unlabeledSources;
|
|
13
13
|
/**
|
|
14
14
|
* The buffer of labelled file sources.
|
|
15
15
|
*/
|
|
16
|
-
|
|
17
|
-
/**
|
|
18
|
-
* Whether a dataset was already produced.
|
|
19
|
-
*/
|
|
20
|
-
// TODO useless, responsibility on callers
|
|
21
|
-
_built;
|
|
16
|
+
_labeledSources;
|
|
22
17
|
constructor(
|
|
23
18
|
/**
|
|
24
19
|
* The data loader used to load the data contained in the provided files.
|
|
@@ -30,9 +25,9 @@ export class DatasetBuilder {
|
|
|
30
25
|
task) {
|
|
31
26
|
this.dataLoader = dataLoader;
|
|
32
27
|
this.task = task;
|
|
33
|
-
this.
|
|
34
|
-
|
|
35
|
-
this.
|
|
28
|
+
this._unlabeledSources = [];
|
|
29
|
+
// Map from label to sources
|
|
30
|
+
this._labeledSources = Map();
|
|
36
31
|
}
|
|
37
32
|
/**
|
|
38
33
|
* Adds the given file sources to the builder's buffer. Sources may be provided a label in the case
|
|
@@ -41,19 +36,16 @@ export class DatasetBuilder {
|
|
|
41
36
|
* @param label The file sources label
|
|
42
37
|
*/
|
|
43
38
|
addFiles(sources, label) {
|
|
44
|
-
if (this.built) {
|
|
45
|
-
this.resetBuiltState();
|
|
46
|
-
}
|
|
47
39
|
if (label === undefined) {
|
|
48
|
-
this.
|
|
40
|
+
this._unlabeledSources = this._unlabeledSources.concat(sources);
|
|
49
41
|
}
|
|
50
42
|
else {
|
|
51
|
-
const currentSources = this.
|
|
43
|
+
const currentSources = this._labeledSources.get(label);
|
|
52
44
|
if (currentSources === undefined) {
|
|
53
|
-
this.
|
|
45
|
+
this._labeledSources = this._labeledSources.set(label, sources);
|
|
54
46
|
}
|
|
55
47
|
else {
|
|
56
|
-
this.
|
|
48
|
+
this._labeledSources = this._labeledSources.set(label, currentSources.concat(sources));
|
|
57
49
|
}
|
|
58
50
|
}
|
|
59
51
|
}
|
|
@@ -63,27 +55,19 @@ export class DatasetBuilder {
|
|
|
63
55
|
* @param label The file sources label
|
|
64
56
|
*/
|
|
65
57
|
clearFiles(label) {
|
|
66
|
-
if (this.built) {
|
|
67
|
-
this.resetBuiltState();
|
|
68
|
-
}
|
|
69
58
|
if (label === undefined) {
|
|
70
|
-
this.
|
|
59
|
+
this._unlabeledSources = [];
|
|
71
60
|
}
|
|
72
61
|
else {
|
|
73
|
-
this.
|
|
62
|
+
this._labeledSources = this._labeledSources.delete(label);
|
|
74
63
|
}
|
|
75
64
|
}
|
|
76
|
-
// If files are added or removed, then this should be called since the latest
|
|
77
|
-
// version of the dataset_builder has not yet been built.
|
|
78
|
-
resetBuiltState() {
|
|
79
|
-
this._built = false;
|
|
80
|
-
}
|
|
81
65
|
getLabels() {
|
|
82
66
|
// We need to duplicate the labels as we need one for each source.
|
|
83
67
|
// Say for label A we have sources [img1, img2, img3], then we
|
|
84
68
|
// need labels [A, A, A].
|
|
85
69
|
let labels = [];
|
|
86
|
-
this.
|
|
70
|
+
this._labeledSources.forEach((sources, label) => {
|
|
87
71
|
const sourcesLabels = Array.from({ length: sources.length }, (_) => label);
|
|
88
72
|
labels = labels.concat(sourcesLabels);
|
|
89
73
|
});
|
|
@@ -91,17 +75,17 @@ export class DatasetBuilder {
|
|
|
91
75
|
}
|
|
92
76
|
async build(config) {
|
|
93
77
|
// Require that at least one source collection is non-empty, but not both
|
|
94
|
-
if (
|
|
95
|
-
throw new Error('
|
|
78
|
+
if (this._unlabeledSources.length + this._labeledSources.size === 0) {
|
|
79
|
+
throw new Error('No input files connected'); // This error message is parsed in Trainer.vue
|
|
96
80
|
}
|
|
97
81
|
let dataTuple;
|
|
98
|
-
if (this.
|
|
82
|
+
if (this._unlabeledSources.length > 0) {
|
|
99
83
|
let defaultConfig = {};
|
|
100
84
|
if (config?.inference === true) {
|
|
101
85
|
// Inferring model, no labels needed
|
|
102
86
|
defaultConfig = {
|
|
103
87
|
features: this.task.trainingInformation.inputColumns,
|
|
104
|
-
shuffle:
|
|
88
|
+
shuffle: true
|
|
105
89
|
};
|
|
106
90
|
}
|
|
107
91
|
else {
|
|
@@ -109,34 +93,26 @@ export class DatasetBuilder {
|
|
|
109
93
|
defaultConfig = {
|
|
110
94
|
features: this.task.trainingInformation.inputColumns,
|
|
111
95
|
labels: this.task.trainingInformation.outputColumns,
|
|
112
|
-
shuffle:
|
|
96
|
+
shuffle: true
|
|
113
97
|
};
|
|
114
98
|
}
|
|
115
|
-
dataTuple = await this.dataLoader.loadAll(this.
|
|
99
|
+
dataTuple = await this.dataLoader.loadAll(this._unlabeledSources, { ...defaultConfig, ...config });
|
|
116
100
|
}
|
|
117
101
|
else {
|
|
118
102
|
// Labels are inferred from the file selection boxes
|
|
119
103
|
const defaultConfig = {
|
|
120
104
|
labels: this.getLabels(),
|
|
121
|
-
shuffle:
|
|
105
|
+
shuffle: true
|
|
122
106
|
};
|
|
123
|
-
const sources = this.
|
|
107
|
+
const sources = this._labeledSources.valueSeq().toArray().flat();
|
|
124
108
|
dataTuple = await this.dataLoader.loadAll(sources, { ...defaultConfig, ...config });
|
|
125
109
|
}
|
|
126
|
-
// TODO @s314cy: Support .csv labels for image datasets (supervised training or testing)
|
|
127
|
-
this._built = true;
|
|
128
110
|
return dataTuple;
|
|
129
111
|
}
|
|
130
|
-
/**
|
|
131
|
-
* Whether the dataset builder has already been consumed to produce a dataset.
|
|
132
|
-
*/
|
|
133
|
-
get built() {
|
|
134
|
-
return this._built;
|
|
135
|
-
}
|
|
136
112
|
get size() {
|
|
137
|
-
return Math.max(this.
|
|
113
|
+
return Math.max(this._unlabeledSources.length, this._labeledSources.size);
|
|
138
114
|
}
|
|
139
115
|
get sources() {
|
|
140
|
-
return this.
|
|
116
|
+
return this._unlabeledSources.length > 0 ? this._unlabeledSources : this._labeledSources.valueSeq().toArray().flat();
|
|
141
117
|
}
|
|
142
118
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { data, models } from '
|
|
3
|
-
import baseModel from '
|
|
2
|
+
import { data, models } from '../index.js';
|
|
3
|
+
import baseModel from '../models/mobileNet_v1_025_224.js';
|
|
4
4
|
export const cifar10 = {
|
|
5
5
|
getTask() {
|
|
6
6
|
return {
|
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
export { cifar10 } from './cifar10
|
|
1
|
+
export { cifar10 } from './cifar10.js';
|
|
2
2
|
export { lusCovid } from './lus_covid.js';
|
|
3
|
+
export { skinCondition } from './skin_condition.js';
|
|
3
4
|
export { mnist } from './mnist.js';
|
|
4
|
-
export { simpleFace } from './simple_face
|
|
5
|
+
export { simpleFace } from './simple_face.js';
|
|
5
6
|
export { titanic } from './titanic.js';
|
|
6
7
|
export { wikitext } from './wikitext.js';
|
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
export { cifar10 } from './cifar10
|
|
1
|
+
export { cifar10 } from './cifar10.js';
|
|
2
2
|
export { lusCovid } from './lus_covid.js';
|
|
3
|
+
export { skinCondition } from './skin_condition.js';
|
|
3
4
|
export { mnist } from './mnist.js';
|
|
4
|
-
export { simpleFace } from './simple_face
|
|
5
|
+
export { simpleFace } from './simple_face.js';
|
|
5
6
|
export { titanic } from './titanic.js';
|
|
6
7
|
export { wikitext } from './wikitext.js';
|
|
@@ -8,7 +8,7 @@ export const lusCovid = {
|
|
|
8
8
|
taskTitle: 'COVID Lung Ultrasound',
|
|
9
9
|
summary: {
|
|
10
10
|
preview: 'Do you have a data of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task.',
|
|
11
|
-
overview: "Don
|
|
11
|
+
overview: "Don't have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
|
|
12
12
|
},
|
|
13
13
|
model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
|
|
14
14
|
dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { data, models } from '
|
|
3
|
-
import baseModel from '
|
|
2
|
+
import { data, models } from '../index.js';
|
|
3
|
+
import baseModel from '../models/mobileNetV2_35_alpha_2_classes.js';
|
|
4
4
|
export const simpleFace = {
|
|
5
5
|
getTask() {
|
|
6
6
|
return {
|
|
@@ -12,7 +12,7 @@ export const simpleFace = {
|
|
|
12
12
|
overview: 'Simple face is a small subset of face_task from Kaggle'
|
|
13
13
|
},
|
|
14
14
|
dataFormatInformation: '',
|
|
15
|
-
dataExampleText: 'Below you find an example',
|
|
15
|
+
dataExampleText: 'Below you can find an example',
|
|
16
16
|
dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png'
|
|
17
17
|
},
|
|
18
18
|
trainingInformation: {
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { data, models } from '../index.js';
|
|
3
|
+
const IMAGE_SIZE = 128;
|
|
4
|
+
const LABELS = ['Eczema', 'Allergic Contact Dermatitis', 'Urticaria'];
|
|
5
|
+
export const skinCondition = {
|
|
6
|
+
getTask() {
|
|
7
|
+
return {
|
|
8
|
+
id: 'skin_condition',
|
|
9
|
+
displayInformation: {
|
|
10
|
+
taskTitle: 'Skin Condition Classification',
|
|
11
|
+
summary: {
|
|
12
|
+
preview: "Identify common skin conditions from volunteer image contributions. You can find a sample dataset of 400 images <a class='underline text-primary-dark dark:text-primary-light' href='https://storage.googleapis.com/deai-313515.appspot.com/scin_sample.zip'>here</a> or see the full <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/google-research-datasets/scin/tree/main'>SCIN dataset</a>. You can find how to download and preprocess the dataset <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/epfml/disco/blob/develop/docs/examples/scin_dataset.ipynb'>in this notebook</a>.",
|
|
13
|
+
overview: "The <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/google-research-datasets/scin/tree/main'>SCIN (Skin Condition Image Network) open access dataset</a> aims to supplement publicly available dermatology datasets from health system sources with representative images from internet users. To this end, the SCIN dataset was collected from Google Search users in the United States through a voluntary, consented image donation application. The SCIN dataset is intended for health education and research, and to increase the diversity of dermatology images available for public use. The SCIN dataset contains 5,000+ volunteer contributions (10,000+ images) of common dermatology conditions. Contributions include Images, self-reported demographic, history, and symptom information, and self-reported Fitzpatrick skin type (sFST). In addition, dermatologist labels of the skin condition are provided for each contribution. You can find more information on the dataset and classification task <a class='underline text-primary-dark dark:text-primary-light' href='https://arxiv.org/abs/2402.18545'>here</a>."
|
|
14
|
+
},
|
|
15
|
+
dataFormatInformation: "There are hundreds of skin condition labels in the SCIN dataset. For the sake of simplicity, we only include the 3 most common conditions in the sample dataset: 'Eczema', 'Allergic Contact Dermatitis' and 'Urticaria'. Therefore, each image is expected to be labeled with one of these three categories.",
|
|
16
|
+
sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/scin_sample.zip'
|
|
17
|
+
},
|
|
18
|
+
trainingInformation: {
|
|
19
|
+
modelID: 'skin-condition-model',
|
|
20
|
+
epochs: 10,
|
|
21
|
+
roundDuration: 2,
|
|
22
|
+
validationSplit: 0.3,
|
|
23
|
+
batchSize: 8,
|
|
24
|
+
preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
|
|
25
|
+
dataType: 'image',
|
|
26
|
+
IMAGE_H: IMAGE_SIZE,
|
|
27
|
+
IMAGE_W: IMAGE_SIZE,
|
|
28
|
+
LABEL_LIST: LABELS,
|
|
29
|
+
scheme: 'federated',
|
|
30
|
+
noiseScale: undefined,
|
|
31
|
+
clippingRadius: undefined
|
|
32
|
+
}
|
|
33
|
+
};
|
|
34
|
+
},
|
|
35
|
+
async getModel() {
|
|
36
|
+
const imageChannels = 3;
|
|
37
|
+
const numOutputClasses = LABELS.length;
|
|
38
|
+
const model = tf.sequential();
|
|
39
|
+
model.add(tf.layers.conv2d({
|
|
40
|
+
inputShape: [IMAGE_SIZE, IMAGE_SIZE, imageChannels],
|
|
41
|
+
filters: 8,
|
|
42
|
+
kernelSize: 3,
|
|
43
|
+
strides: 1,
|
|
44
|
+
kernelInitializer: 'varianceScaling',
|
|
45
|
+
activation: 'relu'
|
|
46
|
+
}));
|
|
47
|
+
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
|
|
48
|
+
model.add(tf.layers.dropout({ rate: 0.2 }));
|
|
49
|
+
const convFilters = [16, 32, 64, 128];
|
|
50
|
+
for (const filters of convFilters) {
|
|
51
|
+
model.add(tf.layers.conv2d({
|
|
52
|
+
filters: filters,
|
|
53
|
+
kernelSize: 3,
|
|
54
|
+
strides: 1,
|
|
55
|
+
kernelInitializer: 'varianceScaling',
|
|
56
|
+
activation: 'relu'
|
|
57
|
+
}));
|
|
58
|
+
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
|
|
59
|
+
model.add(tf.layers.dropout({ rate: 0.2 }));
|
|
60
|
+
}
|
|
61
|
+
model.add(tf.layers.flatten());
|
|
62
|
+
model.add(tf.layers.dense({
|
|
63
|
+
units: 64,
|
|
64
|
+
kernelInitializer: 'varianceScaling',
|
|
65
|
+
activation: 'relu',
|
|
66
|
+
}));
|
|
67
|
+
model.add(tf.layers.dense({
|
|
68
|
+
units: numOutputClasses,
|
|
69
|
+
kernelInitializer: 'varianceScaling',
|
|
70
|
+
activation: 'softmax'
|
|
71
|
+
}));
|
|
72
|
+
model.compile({
|
|
73
|
+
optimizer: tf.train.adam(),
|
|
74
|
+
loss: 'categoricalCrossentropy',
|
|
75
|
+
metrics: ['accuracy']
|
|
76
|
+
});
|
|
77
|
+
return Promise.resolve(new models.TFJS(model));
|
|
78
|
+
}
|
|
79
|
+
};
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
type ModelType = 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl' | 'gpt-mini' | 'gpt-micro' | 'gpt-nano';
|
|
2
|
+
export interface GPTConfig {
|
|
3
|
+
lr: number;
|
|
4
|
+
blockSize: number;
|
|
5
|
+
vocabSize: number;
|
|
6
|
+
modelType: ModelType;
|
|
7
|
+
name?: string;
|
|
8
|
+
evaluate?: boolean;
|
|
9
|
+
maxEvalBatches?: number;
|
|
10
|
+
evaluateEvery?: number;
|
|
11
|
+
maxIter?: number;
|
|
12
|
+
weightDecay?: number;
|
|
13
|
+
verbose?: 0 | 1;
|
|
14
|
+
bias?: boolean;
|
|
15
|
+
debug?: boolean;
|
|
16
|
+
dropout?: number;
|
|
17
|
+
residDrop?: number;
|
|
18
|
+
embdDrop?: number;
|
|
19
|
+
tokEmb?: boolean;
|
|
20
|
+
lmHead?: boolean;
|
|
21
|
+
nLayer?: number;
|
|
22
|
+
nHead?: number;
|
|
23
|
+
nEmbd?: number;
|
|
24
|
+
}
|
|
25
|
+
export declare const DEFAULT_CONFIG: Required<GPTConfig>;
|
|
26
|
+
export type ModelSize = {
|
|
27
|
+
nLayer: number;
|
|
28
|
+
nHead: number;
|
|
29
|
+
nEmbd: number;
|
|
30
|
+
};
|
|
31
|
+
export declare function getModelSizes(modelType: ModelType): Required<ModelSize>;
|
|
32
|
+
export {};
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
// for a benchmark of performance, see https://github.com/epfml/disco/pull/659
|
|
2
|
+
export const DEFAULT_CONFIG = {
|
|
3
|
+
name: 'transformer',
|
|
4
|
+
lr: 0.001,
|
|
5
|
+
weightDecay: 0,
|
|
6
|
+
maxIter: 5,
|
|
7
|
+
verbose: 0,
|
|
8
|
+
modelType: 'gpt-nano',
|
|
9
|
+
evaluate: true,
|
|
10
|
+
maxEvalBatches: 12,
|
|
11
|
+
evaluateEvery: 100,
|
|
12
|
+
blockSize: 128,
|
|
13
|
+
vocabSize: 50258,
|
|
14
|
+
bias: true,
|
|
15
|
+
debug: false,
|
|
16
|
+
dropout: 0.2,
|
|
17
|
+
residDrop: 0.2,
|
|
18
|
+
embdDrop: 0.2,
|
|
19
|
+
tokEmb: true,
|
|
20
|
+
lmHead: true,
|
|
21
|
+
nLayer: 3,
|
|
22
|
+
nHead: 3,
|
|
23
|
+
nEmbd: 48,
|
|
24
|
+
};
|
|
25
|
+
export function getModelSizes(modelType) {
|
|
26
|
+
switch (modelType) {
|
|
27
|
+
case 'gpt2':
|
|
28
|
+
return { nLayer: 12, nHead: 12, nEmbd: 768 };
|
|
29
|
+
case 'gpt2-medium':
|
|
30
|
+
return { nLayer: 24, nHead: 16, nEmbd: 1024 };
|
|
31
|
+
case 'gpt2-large':
|
|
32
|
+
return { nLayer: 36, nHead: 20, nEmbd: 1280 };
|
|
33
|
+
case 'gpt2-xl':
|
|
34
|
+
return { nLayer: 48, nHead: 25, nEmbd: 1600 };
|
|
35
|
+
case 'gpt-mini':
|
|
36
|
+
return { nLayer: 6, nHead: 6, nEmbd: 192 };
|
|
37
|
+
case 'gpt-micro':
|
|
38
|
+
return { nLayer: 4, nHead: 4, nEmbd: 128 };
|
|
39
|
+
case 'gpt-nano':
|
|
40
|
+
return { nLayer: 3, nHead: 3, nEmbd: 48 };
|
|
41
|
+
}
|
|
42
|
+
}
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
interface DataPoint extends tf.TensorContainerObject {
|
|
3
|
+
xs: tf.Tensor2D;
|
|
4
|
+
ys: tf.Tensor3D;
|
|
5
|
+
}
|
|
6
|
+
export default function evaluate(model: tf.LayersModel, dataset: tf.data.Dataset<DataPoint>, maxEvalBatches: number): Promise<Record<'acc' | 'val_acc' | 'val_loss' | 'val_perplexity', number>>;
|
|
7
|
+
export {};
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
export default async function evaluate(model, dataset, maxEvalBatches) {
|
|
3
|
+
let datasetSize = 0;
|
|
4
|
+
let totalLoss = 0;
|
|
5
|
+
const acc = [0, 0];
|
|
6
|
+
await dataset.take(maxEvalBatches).map(({ xs, ys }) => {
|
|
7
|
+
const logits = model.apply(xs);
|
|
8
|
+
if (Array.isArray(logits)) {
|
|
9
|
+
throw new Error('model output too many tensor');
|
|
10
|
+
}
|
|
11
|
+
if (logits instanceof tf.SymbolicTensor) {
|
|
12
|
+
throw new Error('model output symbolic tensor');
|
|
13
|
+
}
|
|
14
|
+
xs.dispose();
|
|
15
|
+
return { logits, ys };
|
|
16
|
+
}).mapAsync(async ({ logits, ys }) => {
|
|
17
|
+
const lossTensor = tf.losses.softmaxCrossEntropy(ys, logits);
|
|
18
|
+
const loss = await lossTensor.array();
|
|
19
|
+
if (typeof loss !== 'number') {
|
|
20
|
+
throw new Error('got multiple loss');
|
|
21
|
+
}
|
|
22
|
+
const accTensor = tf.metrics.categoricalAccuracy(ys, logits);
|
|
23
|
+
const accSize = accTensor.shape.reduce((l, r) => l * r, 1);
|
|
24
|
+
const accSum = accTensor.sum();
|
|
25
|
+
const accSummed = await accSum.array();
|
|
26
|
+
if (typeof accSummed !== 'number') {
|
|
27
|
+
throw new Error('got multiple accuracy sum');
|
|
28
|
+
}
|
|
29
|
+
tf.dispose([ys, logits, accTensor, accSum, lossTensor]);
|
|
30
|
+
return { loss, accSummed, accSize };
|
|
31
|
+
}).forEachAsync(({ loss, accSummed, accSize }) => {
|
|
32
|
+
datasetSize += 1;
|
|
33
|
+
totalLoss += loss;
|
|
34
|
+
acc[0] += accSummed;
|
|
35
|
+
acc[1] += accSize;
|
|
36
|
+
});
|
|
37
|
+
const loss = totalLoss / datasetSize;
|
|
38
|
+
return {
|
|
39
|
+
val_loss: loss,
|
|
40
|
+
val_perplexity: Math.exp(loss),
|
|
41
|
+
acc: acc[0] / acc[1],
|
|
42
|
+
val_acc: acc[0] / acc[1]
|
|
43
|
+
};
|
|
44
|
+
}
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
|
|
3
|
+
**/
|
|
4
|
+
import { PreTrainedTokenizer } from '@xenova/transformers';
|
|
5
|
+
import { WeightsContainer } from '../../index.js';
|
|
6
|
+
import type { Dataset } from '../../dataset/index.js';
|
|
7
|
+
import { Model } from '../model.js';
|
|
8
|
+
import type { EpochLogs, Prediction, Sample } from '../model.js';
|
|
9
|
+
import type { GPTConfig } from './config.js';
|
|
10
|
+
export declare class GPT extends Model {
|
|
11
|
+
private readonly model;
|
|
12
|
+
constructor(partialConfig?: GPTConfig);
|
|
13
|
+
/**
|
|
14
|
+
* The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
|
|
15
|
+
* This allows for getting logs and stopping training without callbacks.
|
|
16
|
+
*
|
|
17
|
+
* @param trainingData training dataset
|
|
18
|
+
* @param validationData validation dataset
|
|
19
|
+
* @param epochs the number of passes of the training dataset
|
|
20
|
+
* @param tracker
|
|
21
|
+
*/
|
|
22
|
+
train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs, void>;
|
|
23
|
+
predict(input: Sample): Promise<Prediction>;
|
|
24
|
+
generate(input: string, tokenizer: PreTrainedTokenizer, newTokens?: number): Promise<string>;
|
|
25
|
+
get config(): Required<GPTConfig>;
|
|
26
|
+
get weights(): WeightsContainer;
|
|
27
|
+
set weights(ws: WeightsContainer);
|
|
28
|
+
static deserialize(data: GPTSerialization): Model;
|
|
29
|
+
serialize(): GPTSerialization;
|
|
30
|
+
[Symbol.dispose](): void;
|
|
31
|
+
}
|
|
32
|
+
export type GPTSerialization = {
|
|
33
|
+
weights: WeightsContainer;
|
|
34
|
+
config?: GPTConfig;
|
|
35
|
+
};
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
|
|
3
|
+
**/
|
|
4
|
+
import { WeightsContainer } from '../../index.js';
|
|
5
|
+
import { Model } from '../model.js';
|
|
6
|
+
import { GPTForCausalLM } from './model.js';
|
|
7
|
+
export class GPT extends Model {
|
|
8
|
+
model;
|
|
9
|
+
constructor(partialConfig) {
|
|
10
|
+
super();
|
|
11
|
+
this.model = new GPTForCausalLM(partialConfig);
|
|
12
|
+
}
|
|
13
|
+
/**
|
|
14
|
+
* The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
|
|
15
|
+
* This allows for getting logs and stopping training without callbacks.
|
|
16
|
+
*
|
|
17
|
+
* @param trainingData training dataset
|
|
18
|
+
* @param validationData validation dataset
|
|
19
|
+
* @param epochs the number of passes of the training dataset
|
|
20
|
+
* @param tracker
|
|
21
|
+
*/
|
|
22
|
+
async *train(trainingData, validationData, epochs = 1) {
|
|
23
|
+
this.model.compile();
|
|
24
|
+
let logs;
|
|
25
|
+
const trainingArgs = {
|
|
26
|
+
epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
|
|
27
|
+
validationData,
|
|
28
|
+
callbacks: { onEpochEnd: (_, cur) => { logs = cur; } },
|
|
29
|
+
};
|
|
30
|
+
for (let epoch = 0; epoch < epochs; epoch++) {
|
|
31
|
+
await this.model.fitDataset(trainingData, trainingArgs);
|
|
32
|
+
if (logs === undefined) {
|
|
33
|
+
throw new Error("Epoch didn't gave any logs");
|
|
34
|
+
}
|
|
35
|
+
const { loss, val_acc, val_loss, peakMemory } = logs;
|
|
36
|
+
if (loss === undefined || isNaN(loss)) {
|
|
37
|
+
throw new Error("Training loss is undefined or nan");
|
|
38
|
+
}
|
|
39
|
+
const structuredLogs = {
|
|
40
|
+
epoch,
|
|
41
|
+
peakMemory,
|
|
42
|
+
training: {
|
|
43
|
+
loss: logs.loss
|
|
44
|
+
}
|
|
45
|
+
};
|
|
46
|
+
if (validationData !== undefined) {
|
|
47
|
+
if (val_loss === undefined || isNaN(val_loss) ||
|
|
48
|
+
val_acc === undefined || isNaN(val_acc)) {
|
|
49
|
+
throw new Error("Invalid validation logs");
|
|
50
|
+
}
|
|
51
|
+
structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss };
|
|
52
|
+
}
|
|
53
|
+
yield structuredLogs;
|
|
54
|
+
}
|
|
55
|
+
}
|
|
56
|
+
predict(input) {
|
|
57
|
+
const ret = this.model.predict(input);
|
|
58
|
+
if (Array.isArray(ret)) {
|
|
59
|
+
throw new Error('prediction yield many Tensors but should have only returned one');
|
|
60
|
+
}
|
|
61
|
+
return Promise.resolve(ret);
|
|
62
|
+
}
|
|
63
|
+
async generate(input, tokenizer, newTokens = 10) {
|
|
64
|
+
const { input_ids: tokens } = await tokenizer(input, { return_tensor: false });
|
|
65
|
+
const generationConfig = {
|
|
66
|
+
maxNewTokens: newTokens,
|
|
67
|
+
temperature: 1.0,
|
|
68
|
+
doSample: false
|
|
69
|
+
};
|
|
70
|
+
const predictedTokens = await this.model.generate(tokens, generationConfig);
|
|
71
|
+
const generatedWords = tokenizer.decode(predictedTokens[0]);
|
|
72
|
+
return generatedWords;
|
|
73
|
+
}
|
|
74
|
+
get config() {
|
|
75
|
+
return this.model.getGPTConfig;
|
|
76
|
+
}
|
|
77
|
+
get weights() {
|
|
78
|
+
return new WeightsContainer(this.model.weights.map((w) => w.read()));
|
|
79
|
+
}
|
|
80
|
+
set weights(ws) {
|
|
81
|
+
this.model.setWeights(ws.weights);
|
|
82
|
+
}
|
|
83
|
+
static deserialize(data) {
|
|
84
|
+
const model = new GPT(data.config);
|
|
85
|
+
model.weights = data.weights;
|
|
86
|
+
return model;
|
|
87
|
+
}
|
|
88
|
+
serialize() {
|
|
89
|
+
return {
|
|
90
|
+
weights: this.weights,
|
|
91
|
+
config: this.config
|
|
92
|
+
};
|
|
93
|
+
}
|
|
94
|
+
[Symbol.dispose]() {
|
|
95
|
+
console.log("Disposing model");
|
|
96
|
+
if (this.model.optimizer !== undefined) {
|
|
97
|
+
this.model.optimizer.dispose();
|
|
98
|
+
}
|
|
99
|
+
// Some tensors are not cleaned up when model.dispose is called
|
|
100
|
+
// So we dispose them manually
|
|
101
|
+
this.model.disposeRefs();
|
|
102
|
+
this.model.dispose();
|
|
103
|
+
}
|
|
104
|
+
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import type { GPTConfig } from './config.js';
|
|
3
|
+
/**
|
|
4
|
+
* The GPTArchitecture specifically defines a GPT forward pass, i.e.,
|
|
5
|
+
* what are the inputs, the successive transformer blocks and the outputs. It is then
|
|
6
|
+
* used to create a GPTModel
|
|
7
|
+
*
|
|
8
|
+
* @param conf GPTConfig
|
|
9
|
+
* @returns model, tf.LayersModel, which supports model(inputs), model.predict and model.apply
|
|
10
|
+
*/
|
|
11
|
+
export declare function GPTArchitecture(config: Required<GPTConfig>, disposalRefs: tf.TensorContainer[], peakMemory: {
|
|
12
|
+
value: number;
|
|
13
|
+
}): tf.LayersModel;
|