@epfml/discojs 2.1.2-p20240515133413.0 → 2.1.2-p20240531085945.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.js +1 -0
- package/dist/aggregator/mean.d.ts +10 -15
- package/dist/aggregator/mean.js +36 -50
- package/dist/aggregator/secure.d.ts +5 -7
- package/dist/aggregator/secure.js +56 -44
- package/dist/client/federated/messages.d.ts +1 -8
- package/dist/client/federated/messages.js +1 -10
- package/dist/client/messages.d.ts +1 -3
- package/dist/client/messages.js +0 -2
- package/dist/dataset/dataset_builder.d.ts +2 -11
- package/dist/dataset/dataset_builder.js +22 -46
- package/dist/default_tasks/cifar10.d.ts +2 -0
- package/dist/default_tasks/{cifar10/index.js → cifar10.js} +2 -2
- package/dist/default_tasks/index.d.ts +3 -2
- package/dist/default_tasks/index.js +3 -2
- package/dist/default_tasks/lus_covid.js +1 -1
- package/dist/default_tasks/simple_face.d.ts +2 -0
- package/dist/default_tasks/{simple_face/index.js → simple_face.js} +3 -3
- package/dist/default_tasks/skin_condition.d.ts +2 -0
- package/dist/default_tasks/skin_condition.js +79 -0
- package/dist/models/gpt/config.d.ts +32 -0
- package/dist/models/gpt/config.js +42 -0
- package/dist/models/gpt/evaluate.d.ts +7 -0
- package/dist/models/gpt/evaluate.js +44 -0
- package/dist/models/gpt/index.d.ts +35 -0
- package/dist/models/gpt/index.js +104 -0
- package/dist/models/gpt/layers.d.ts +13 -0
- package/dist/models/gpt/layers.js +272 -0
- package/dist/models/gpt/model.d.ts +43 -0
- package/dist/models/gpt/model.js +191 -0
- package/dist/models/gpt/optimizers.d.ts +4 -0
- package/dist/models/gpt/optimizers.js +95 -0
- package/dist/models/index.d.ts +5 -0
- package/dist/models/index.js +4 -0
- package/dist/{default_tasks/simple_face/model.js → models/mobileNetV2_35_alpha_2_classes.js} +2 -0
- package/dist/{default_tasks/cifar10/model.js → models/mobileNet_v1_025_224.js} +1 -0
- package/dist/models/model.d.ts +51 -0
- package/dist/models/model.js +8 -0
- package/dist/models/tfjs.d.ts +24 -0
- package/dist/models/tfjs.js +107 -0
- package/dist/models/tokenizer.d.ts +14 -0
- package/dist/models/tokenizer.js +22 -0
- package/dist/validation/validator.js +8 -7
- package/package.json +1 -1
- package/dist/default_tasks/cifar10/index.d.ts +0 -2
- package/dist/default_tasks/simple_face/index.d.ts +0 -2
- /package/dist/{default_tasks/simple_face/model.d.ts → models/mobileNetV2_35_alpha_2_classes.d.ts} +0 -0
- /package/dist/{default_tasks/cifar10/model.d.ts → models/mobileNet_v1_025_224.d.ts} +0 -0
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { data, models } from '../index.js';
|
|
3
|
+
// Side length (in pixels) used for both the image height and width.
const IMAGE_SIZE = 128;
// Class names; the output layer has one unit per entry.
const LABELS = ['Eczema', 'Allergic Contact Dermatitis', 'Urticaria'];
/**
 * Default task: skin condition classification on images.
 * Exposes the task description (`getTask`) and a freshly built model (`getModel`).
 */
export const skinCondition = {
    /**
     * Returns the static description of the task: display texts and training settings.
     */
    getTask() {
        const displayInformation = {
            taskTitle: 'Skin Condition Classification',
            summary: {
                preview: "Identify common skin conditions from volunteer image contributions. You can find a sample dataset of 400 images <a class='underline text-primary-dark dark:text-primary-light' href='https://storage.googleapis.com/deai-313515.appspot.com/scin_sample.zip'>here</a> or see the full <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/google-research-datasets/scin/tree/main'>SCIN dataset</a>. You can find how to download and preprocess the dataset <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/epfml/disco/blob/develop/docs/examples/scin_dataset.ipynb'>in this notebook</a>.",
                overview: "The <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/google-research-datasets/scin/tree/main'>SCIN (Skin Condition Image Network) open access dataset</a> aims to supplement publicly available dermatology datasets from health system sources with representative images from internet users. To this end, the SCIN dataset was collected from Google Search users in the United States through a voluntary, consented image donation application. The SCIN dataset is intended for health education and research, and to increase the diversity of dermatology images available for public use. The SCIN dataset contains 5,000+ volunteer contributions (10,000+ images) of common dermatology conditions. Contributions include Images, self-reported demographic, history, and symptom information, and self-reported Fitzpatrick skin type (sFST). In addition, dermatologist labels of the skin condition are provided for each contribution. You can find more information on the dataset and classification task <a class='underline text-primary-dark dark:text-primary-light' href='https://arxiv.org/abs/2402.18545'>here</a>."
            },
            dataFormatInformation: "There are hundreds of skin condition labels in the SCIN dataset. For the sake of simplicity, we only include the 3 most common conditions in the sample dataset: 'Eczema', 'Allergic Contact Dermatitis' and 'Urticaria'. Therefore, each image is expected to be labeled with one of these three categories.",
            sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/scin_sample.zip'
        };
        const trainingInformation = {
            modelID: 'skin-condition-model',
            epochs: 10,
            roundDuration: 2,
            validationSplit: 0.3,
            batchSize: 8,
            preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
            dataType: 'image',
            IMAGE_H: IMAGE_SIZE,
            IMAGE_W: IMAGE_SIZE,
            LABEL_LIST: LABELS,
            scheme: 'federated',
            noiseScale: undefined,
            clippingRadius: undefined
        };
        return {
            id: 'skin_condition',
            displayInformation,
            trainingInformation
        };
    },
    /**
     * Builds and compiles a sequential CNN (five conv / max-pool / dropout stages
     * followed by two dense layers) and wraps it in `models.TFJS`.
     */
    async getModel() {
        const channels = 3;
        const network = tf.sequential();
        // Filter counts of the successive convolution stages.
        const stageFilters = [8, 16, 32, 64, 128];
        stageFilters.forEach((filters, stage) => {
            network.add(tf.layers.conv2d({
                // Only the very first layer of the model declares the input shape.
                ...(stage === 0 ? { inputShape: [IMAGE_SIZE, IMAGE_SIZE, channels] } : {}),
                filters,
                kernelSize: 3,
                strides: 1,
                kernelInitializer: 'varianceScaling',
                activation: 'relu'
            }));
            network.add(tf.layers.maxPooling2d({ poolSize: [2, 2] }));
            network.add(tf.layers.dropout({ rate: 0.2 }));
        });
        network.add(tf.layers.flatten());
        network.add(tf.layers.dense({
            units: 64,
            kernelInitializer: 'varianceScaling',
            activation: 'relu',
        }));
        // One softmax output per label.
        network.add(tf.layers.dense({
            units: LABELS.length,
            kernelInitializer: 'varianceScaling',
            activation: 'softmax'
        }));
        network.compile({
            optimizer: tf.train.adam(),
            loss: 'categoricalCrossentropy',
            metrics: ['accuracy']
        });
        return new models.TFJS(network);
    }
};
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
/** Names of the model sizes accepted by `getModelSizes`. */
type ModelType = 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl' | 'gpt-mini' | 'gpt-micro' | 'gpt-nano';
/**
 * Configuration of a GPT model.
 * Only `lr`, `blockSize`, `vocabSize` and `modelType` are mandatory;
 * `DEFAULT_CONFIG` carries a value for every field.
 */
export interface GPTConfig {
    lr: number;
    blockSize: number;
    vocabSize: number;
    modelType: ModelType;
    name?: string;
    evaluate?: boolean;
    maxEvalBatches?: number;
    evaluateEvery?: number;
    maxIter?: number;
    weightDecay?: number;
    verbose?: 0 | 1;
    bias?: boolean;
    debug?: boolean;
    dropout?: number;
    residDrop?: number;
    embdDrop?: number;
    tokEmb?: boolean;
    lmHead?: boolean;
    nLayer?: number;
    nHead?: number;
    nEmbd?: number;
}
/** A fully populated configuration. */
export declare const DEFAULT_CONFIG: Required<GPTConfig>;
/** The three size fields that depend on the model type. */
export type ModelSize = {
    nLayer: number;
    nHead: number;
    nEmbd: number;
};
/** Maps a model type name to its `nLayer` / `nHead` / `nEmbd` values. */
export declare function getModelSizes(modelType: ModelType): Required<ModelSize>;
export {};
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
/**
 * Default GPT configuration.
 * For a benchmark of performance, see https://github.com/epfml/disco/pull/659
 */
export const DEFAULT_CONFIG = {
    // identification and optimisation
    name: 'transformer',
    lr: 1e-3,
    weightDecay: 0,
    maxIter: 5,
    verbose: 0,
    modelType: 'gpt-nano',
    // evaluation
    evaluate: true,
    maxEvalBatches: 12,
    evaluateEvery: 100,
    // input dimensions
    blockSize: 128,
    vocabSize: 50258,
    // layer options
    bias: true,
    debug: false,
    dropout: 0.2,
    residDrop: 0.2,
    embdDrop: 0.2,
    tokEmb: true,
    lmHead: true,
    // sizes; same values as the 'gpt-nano' entry of getModelSizes
    nLayer: 3,
    nHead: 3,
    nEmbd: 48,
};
|
|
25
|
+
/**
 * Returns the layer count, head count and embedding size for a model type.
 *
 * @param modelType one of the supported model type names
 * @returns a fresh `{ nLayer, nHead, nEmbd }` object, or `undefined` when the
 *          name is not one of the supported types
 */
export function getModelSizes(modelType) {
    // [nLayer, nHead, nEmbd] per model type
    const table = new Map([
        ['gpt2', [12, 12, 768]],
        ['gpt2-medium', [24, 16, 1024]],
        ['gpt2-large', [36, 20, 1280]],
        ['gpt2-xl', [48, 25, 1600]],
        ['gpt-mini', [6, 6, 192]],
        ['gpt-micro', [4, 4, 128]],
        ['gpt-nano', [3, 3, 48]],
    ]);
    const dims = table.get(modelType);
    if (dims === undefined) {
        return undefined;
    }
    const [nLayer, nHead, nEmbd] = dims;
    return { nLayer, nHead, nEmbd };
}
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
/** One batch of the evaluation dataset: inputs `xs` and targets `ys`. */
interface DataPoint extends tf.TensorContainerObject {
    xs: tf.Tensor2D;
    ys: tf.Tensor3D;
}
/**
 * Evaluates `model` on at most `maxEvalBatches` batches of `dataset`.
 *
 * @param model the model to evaluate
 * @param dataset batches of `{ xs, ys }`
 * @param maxEvalBatches maximum number of batches taken from the dataset
 * @returns the mean loss (`val_loss`), its exponential (`val_perplexity`) and the
 *          accuracy, reported under both `acc` and `val_acc`
 */
export default function evaluate(model: tf.LayersModel, dataset: tf.data.Dataset<DataPoint>, maxEvalBatches: number): Promise<Record<'acc' | 'val_acc' | 'val_loss' | 'val_perplexity', number>>;
export {};
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
/**
 * Evaluates `model` on at most `maxEvalBatches` batches of `dataset`.
 *
 * For each batch the model is applied to `xs`, then the softmax cross-entropy
 * loss and the categorical accuracy are computed against `ys`.
 *
 * @param model model exposing `apply(xs)`; it must return a single concrete tensor
 * @param dataset dataset of `{ xs, ys }` batches
 * @param maxEvalBatches maximum number of batches to evaluate
 * @returns `val_loss` (loss averaged over batches), `val_perplexity` (`exp(val_loss)`)
 *          and the accuracy under both `acc` and `val_acc`. If no batch was
 *          evaluated, all values are NaN (0 / 0).
 * @throws Error if the model outputs several tensors or a symbolic tensor, or if
 *         the loss / accuracy sum is not a scalar
 */
export default async function evaluate(model, dataset, maxEvalBatches) {
    let datasetSize = 0;
    let totalLoss = 0;
    // acc[0]: number of correct predictions, acc[1]: number of predictions
    const acc = [0, 0];
    await dataset.take(maxEvalBatches).map(({ xs, ys }) => {
        const logits = model.apply(xs);
        if (Array.isArray(logits)) {
            throw new Error('model output too many tensor');
        }
        if (logits instanceof tf.SymbolicTensor) {
            throw new Error('model output symbolic tensor');
        }
        xs.dispose();
        return { logits, ys };
    }).mapAsync(async ({ logits, ys }) => {
        // Every tensor this stage is responsible for; released in `finally`
        // so that nothing leaks when one of the checks below throws.
        const owned = [ys, logits];
        try {
            const lossTensor = tf.losses.softmaxCrossEntropy(ys, logits);
            owned.push(lossTensor);
            const loss = await lossTensor.array();
            if (typeof loss !== 'number') {
                throw new Error('got multiple loss');
            }
            const accTensor = tf.metrics.categoricalAccuracy(ys, logits);
            owned.push(accTensor);
            // Number of individual predictions in this batch.
            const accSize = accTensor.shape.reduce((l, r) => l * r, 1);
            const accSum = accTensor.sum();
            owned.push(accSum);
            const accSummed = await accSum.array();
            if (typeof accSummed !== 'number') {
                throw new Error('got multiple accuracy sum');
            }
            return { loss, accSummed, accSize };
        }
        finally {
            tf.dispose(owned);
        }
    }).forEachAsync(({ loss, accSummed, accSize }) => {
        datasetSize += 1;
        totalLoss += loss;
        acc[0] += accSummed;
        acc[1] += accSize;
    });
    const loss = totalLoss / datasetSize;
    const accuracy = acc[0] / acc[1];
    return {
        val_loss: loss,
        val_perplexity: Math.exp(loss),
        acc: accuracy,
        val_acc: accuracy
    };
}
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
|
|
3
|
+
**/
|
|
4
|
+
import { PreTrainedTokenizer } from '@xenova/transformers';
|
|
5
|
+
import { WeightsContainer } from '../../index.js';
|
|
6
|
+
import type { Dataset } from '../../dataset/index.js';
|
|
7
|
+
import { Model } from '../model.js';
|
|
8
|
+
import type { EpochLogs, Prediction, Sample } from '../model.js';
|
|
9
|
+
import type { GPTConfig } from './config.js';
|
|
10
|
+
/** GPT model exposed through the generic `Model` interface. */
export declare class GPT extends Model {
    private readonly model;
    /**
     * @param partialConfig configuration forwarded to the underlying model
     */
    constructor(partialConfig?: GPTConfig);
    /**
     * Trains the model one epoch at a time and yields the logs of each epoch.
     * Acting as a generator allows for getting logs and stopping training without callbacks.
     *
     * @param trainingData training dataset
     * @param validationData validation dataset
     * @param epochs the number of passes of the training dataset
     */
    train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs, void>;
    /** Runs the model on `input` and resolves with its single output. */
    predict(input: Sample): Promise<Prediction>;
    /**
     * Tokenizes `input`, generates tokens with the model and decodes the result.
     *
     * @param input the text to continue
     * @param tokenizer tokenizer used to encode `input` and decode the output
     * @param newTokens maximum number of new tokens to generate
     */
    generate(input: string, tokenizer: PreTrainedTokenizer, newTokens?: number): Promise<string>;
    /** The configuration of the underlying model. */
    get config(): Required<GPTConfig>;
    get weights(): WeightsContainer;
    set weights(ws: WeightsContainer);
    /** Rebuilds a model from the output of `serialize`. */
    static deserialize(data: GPTSerialization): Model;
    /** Returns the weights and configuration of this model. */
    serialize(): GPTSerialization;
    /** Releases the tensors held by the model. */
    [Symbol.dispose](): void;
}
/** Serialized form of a `GPT` model. */
export type GPTSerialization = {
    weights: WeightsContainer;
    config?: GPTConfig;
};
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
|
|
3
|
+
**/
|
|
4
|
+
import { WeightsContainer } from '../../index.js';
|
|
5
|
+
import { Model } from '../model.js';
|
|
6
|
+
import { GPTForCausalLM } from './model.js';
|
|
7
|
+
/**
 * GPT model exposed through the generic `Model` interface.
 * Wraps a `GPTForCausalLM` instance.
 */
export class GPT extends Model {
    model;
    /**
     * @param partialConfig configuration forwarded to `GPTForCausalLM`
     */
    constructor(partialConfig) {
        super();
        this.model = new GPTForCausalLM(partialConfig);
    }
    /**
     * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
     * This allows for getting logs and stopping training without callbacks.
     *
     * @param trainingData training dataset
     * @param validationData validation dataset
     * @param epochs the number of passes of the training dataset
     * @throws Error if an epoch produces no logs, an undefined/NaN training loss,
     *         or (when validation data is given) invalid validation logs
     */
    async *train(trainingData, validationData, epochs = 1) {
        this.model.compile();
        let logs;
        const trainingArgs = {
            epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
            validationData,
            callbacks: { onEpochEnd: (_, cur) => { logs = cur; } },
        };
        for (let epoch = 0; epoch < epochs; epoch++) {
            // Forget the previous epoch's logs: otherwise the check below could
            // only ever fire on the first epoch and stale logs would be yielded.
            logs = undefined;
            await this.model.fitDataset(trainingData, trainingArgs);
            if (logs === undefined) {
                throw new Error("Epoch didn't gave any logs");
            }
            const { loss, val_acc, val_loss, peakMemory } = logs;
            if (loss === undefined || isNaN(loss)) {
                throw new Error("Training loss is undefined or nan");
            }
            const structuredLogs = {
                epoch,
                peakMemory,
                training: {
                    loss: logs.loss
                }
            };
            if (validationData !== undefined) {
                if (val_loss === undefined || isNaN(val_loss) ||
                    val_acc === undefined || isNaN(val_acc)) {
                    throw new Error("Invalid validation logs");
                }
                structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss };
            }
            yield structuredLogs;
        }
    }
    /**
     * Runs the underlying model on `input`.
     *
     * @throws Error if the model returns several tensors
     */
    predict(input) {
        const ret = this.model.predict(input);
        if (Array.isArray(ret)) {
            throw new Error('prediction yield many Tensors but should have only returned one');
        }
        return Promise.resolve(ret);
    }
    /**
     * Tokenizes `input`, generates up to `newTokens` tokens without sampling
     * and decodes the first generated sequence.
     *
     * @param input the text to continue
     * @param tokenizer tokenizer used to encode `input` and decode the output
     * @param newTokens maximum number of new tokens
     */
    async generate(input, tokenizer, newTokens = 10) {
        const { input_ids: tokens } = await tokenizer(input, { return_tensor: false });
        const generationConfig = {
            maxNewTokens: newTokens,
            temperature: 1.0,
            doSample: false
        };
        const predictedTokens = await this.model.generate(tokens, generationConfig);
        const generatedWords = tokenizer.decode(predictedTokens[0]);
        return generatedWords;
    }
    /** The configuration of the underlying model. */
    get config() {
        return this.model.getGPTConfig;
    }
    /** The current value of every weight of the underlying model. */
    get weights() {
        return new WeightsContainer(this.model.weights.map((w) => w.read()));
    }
    set weights(ws) {
        this.model.setWeights(ws.weights);
    }
    /** Rebuilds a model from the output of `serialize`. */
    static deserialize(data) {
        const model = new GPT(data.config);
        model.weights = data.weights;
        return model;
    }
    /** Returns the weights and configuration of this model. */
    serialize() {
        return {
            weights: this.weights,
            config: this.config
        };
    }
    /** Disposes the optimizer (if any), the manually tracked tensors and the model. */
    [Symbol.dispose]() {
        console.log("Disposing model");
        if (this.model.optimizer !== undefined) {
            this.model.optimizer.dispose();
        }
        // Some tensors are not cleaned up when model.dispose is called
        // So we dispose them manually
        this.model.disposeRefs();
        this.model.dispose();
    }
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import type { GPTConfig } from './config.js';
|
|
3
|
+
/**
 * The GPTArchitecture specifically defines a GPT forward pass, i.e.,
 * what are the inputs, the successive transformer blocks and the outputs.
 *
 * @param config fully populated GPT configuration
 * @param disposalRefs list that receives the tensors created while building the
 *        model, so that they can be disposed later
 * @param peakMemory holder updated with the largest amount of allocated memory observed
 * @returns model, tf.LayersModel, which supports model(inputs), model.predict and model.apply
 */
export declare function GPTArchitecture(config: Required<GPTConfig>, disposalRefs: tf.TensorContainer[], peakMemory: {
    value: number;
}): tf.LayersModel;
|
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
/**
 * Layer producing the positions 0..T-1, shaped [1, T], where T is the second
 * dimension of its input. Used to create positional embeddings.
 */
class Range extends tf.layers.Layer {
    static className = 'Range';
    computeOutputShape(inputShape) {
        return inputShape;
    }
    call(input, kwargs) {
        return tf.tidy(() => {
            // TODO support multitensor
            const tensor = Array.isArray(input) ? input[0] : input;
            this.invokeCallHook(tensor, kwargs);
            const T = tensor.shape[1];
            if (T === undefined) {
                throw new Error('unexpected shape');
            }
            const positions = tf.range(0, T, 1, 'int32');
            return tf.reshape(positions, [1, T]);
        });
    }
}
tf.serialization.registerClass(Range);
|
|
25
|
+
/**
 * Pass-through layer: invokes the call hook and returns its (first) input unchanged.
 */
class LogLayer extends tf.layers.Layer {
    static className = 'LogLayer';
    computeOutputShape(inputShape) {
        return inputShape;
    }
    call(input, kwargs) {
        return tf.tidy(() => {
            const tensor = Array.isArray(input) ? input[0] : input;
            this.invokeCallHook(tensor, kwargs);
            return tensor;
        });
    }
}
tf.serialization.registerClass(LogLayer);
|
|
41
|
+
/**
 * Multi-head causal self-attention layer: each position attends only to
 * itself and to earlier positions.
 */
class CausalSelfAttention extends tf.layers.Layer {
    config;
    peakMemory;
    static className = 'CausalSelfAttention';
    nHead;
    nEmbd;
    dropout;
    bias;
    mask;
    cAttnKernel;
    cAttnBias;
    cProjKernel;
    cProjBias;
    /**
     * @param config layer configuration (reads nEmbd, nHead, dropout, bias and blockSize)
     * @param disposalRefs list receiving the mask tensor so it can be disposed later
     * @param peakMemory holder updated with the largest allocated memory (in GB) seen in `call`
     */
    constructor(config, disposalRefs, peakMemory) {
        super(config);
        this.config = config;
        this.peakMemory = peakMemory;
        this.nEmbd = config.nEmbd;
        this.nHead = config.nHead;
        this.dropout = config.dropout;
        this.bias = config.bias;
        // Lower triangular [blockSize, blockSize] matrix of ones:
        // bandPart(x, -1, 0) zeroes out everything above the diagonal.
        const allOnes = tf.ones([config.blockSize, config.blockSize]);
        this.mask = tf.linalg.bandPart(allOnes, -1, 0);
        // Keep a reference so that this matrix can be disposed later
        disposalRefs.push(this.mask);
    }
    build() {
        const width = this.nEmbd;
        this.cAttnKernel = this.addWeight('c_attn/kernel', [width, 3 * width], 'float32', tf.initializers.glorotNormal({}));
        this.cAttnBias = this.addWeight('c_attn/bias', [3 * width], 'float32', tf.initializers.zeros());
        this.cProjKernel = this.addWeight('c_proj/kernel', [width, width], 'float32', tf.initializers.glorotNormal({}));
        this.cProjBias = this.addWeight('c_proj/bias', [width], 'float32', tf.initializers.zeros());
    }
    computeOutputShape(inputShape) {
        return inputShape;
    }
    getConfig() {
        return Object.assign({}, super.getConfig(), this.config);
    }
    call(input, kwargs) {
        return tf.tidy(() => {
            const built = this.cAttnKernel !== undefined &&
                this.cAttnBias !== undefined &&
                this.cProjKernel !== undefined &&
                this.cProjBias !== undefined;
            if (!built) {
                throw new Error('not built');
            }
            if (Array.isArray(input)) {
                input = input[0];
            }
            this.invokeCallHook(input, kwargs);
            // Batched linear map: the kernel is repeated along the batch dimension
            // of `x`, and the bias is added only when enabled in the config.
            const linear = (x, kernel, bias) => {
                const batchedKernel = kernel.read().expandDims(0).tile([x.shape[0], 1, 1]);
                const product = x.matMul(batchedKernel);
                return this.bias ? tf.add(product, bias.read()) : product;
            };
            // One big projection, split afterwards into query, key and value
            const qkv = linear(input, this.cAttnKernel, this.cAttnBias);
            const [qFlat, kFlat, vFlat] = tf.split(qkv, 3, -1);
            const [B, T, C] = kFlat.shape;
            // [B, T, C] -> [B, nHead, T, C / nHead]
            const toHeads = (x) => tf.transpose(tf.reshape(x, [B, T, this.nHead, C / this.nHead]), [0, 2, 1, 3]);
            const q = toHeads(qFlat);
            const k = toHeads(kFlat);
            const v = toHeads(vFlat);
            // Scaled attention scores: query @ key^T / sqrt(size of one head)
            const scores = tf.matMul(q, k, false, true);
            const scale = tf.div(1, tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32')));
            let att = tf.mul(scores, scale);
            // (1 - mask) is 1 strictly above the diagonal; multiplied by -1e9 it pushes
            // the scores of future tokens to a very large negative value while leaving
            // the lower triangular part of `att` untouched.
            const causalMask = this.mask.slice([0, 0], [T, T]);
            att = tf.add(att, tf.mul(tf.sub(1, causalMask), -1e9));
            // The softmax then gives (almost) zero weight to future tokens and
            // normalises the weights of past tokens so that they sum to one.
            att = tf.softmax(att, -1);
            if (kwargs.training === true) {
                att = tf.dropout(att, this.dropout);
            }
            // Attention-weighted sum of the values, then merge the heads back: [B, T, C]
            let y = tf.matMul(att, v);
            y = tf.transpose(y, [0, 2, 1, 3]);
            y = tf.reshape(y, [B, T, C]);
            // Output projection
            y = linear(y, this.cProjKernel, this.cProjBias);
            if (kwargs.training === true) {
                y = tf.dropout(y, this.dropout);
            }
            // Track the largest amount of allocated memory, in GB
            const allocatedGB = tf.memory().numBytes / 1024 / 1024 / 1024;
            if (allocatedGB > this.peakMemory.value) {
                this.peakMemory.value = allocatedGB;
            }
            return y;
        });
    }
}
tf.serialization.registerClass(CausalSelfAttention);
|
|
146
|
+
/**
 * GELU activation layer, computed with the tanh formula
 * x * 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
 */
class GELU extends tf.layers.Layer {
    static className = 'GELU';
    constructor() {
        super({});
    }
    computeOutputShape(inputShape) {
        return inputShape;
    }
    call(input, kwargs) {
        return tf.tidy(() => {
            // TODO support multitensor
            const x = Array.isArray(input) ? input[0] : input;
            this.invokeCallHook(x, kwargs);
            const scale = tf.sqrt(tf.div(2, Math.PI));
            const cubicTerm = tf.mul(0.044715, tf.pow(x, 3));
            const inner = tf.mul(scale, tf.add(x, cubicTerm));
            const cdf = tf.mul(0.5, tf.add(1, tf.tanh(inner)));
            return tf.mul(x, cdf);
        });
    }
}
tf.serialization.registerClass(GELU);
|
|
167
|
+
/**
 * Feed-forward part of a transformer block: a dense layer expanding to
 * 4 * nEmbd units, GELU, a dense layer projecting back to nEmbd units, then dropout.
 *
 * @param config reads nEmbd, blockSize and residDrop
 * @returns a tf.Sequential model
 */
function MLP(config) {
    const { nEmbd, blockSize, residDrop } = config;
    const hiddenUnits = 4 * nEmbd;
    const expand = tf.layers.dense({
        name: 'mlp/c_fc',
        units: hiddenUnits,
        inputDim: nEmbd,
        inputShape: [blockSize, nEmbd]
    });
    const activation = new GELU();
    const project = tf.layers.dense({
        name: 'mlp/c_proj',
        units: nEmbd,
        inputDim: hiddenUnits,
        inputShape: [blockSize, hiddenUnits]
    });
    const drop = tf.layers.dropout({
        name: 'mlp/drop',
        rate: residDrop
    });
    return tf.sequential({ layers: [expand, activation, project, drop] });
}
|
|
188
|
+
/**
 * Builds one transformer block as a tf.LayersModel:
 * layer norm -> causal self-attention -> residual add,
 * then layer norm -> MLP -> residual add.
 *
 * @param conf block configuration; `name` defaults to 'h'
 * @param disposalRefs forwarded to the attention layer
 * @param peakMemory forwarded to the attention layer
 */
function TransformerBlock(conf, disposalRefs, peakMemory) {
    const config = Object.assign({ name: 'h' }, conf);
    const prefix = config.name;
    const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] });
    // input normalization
    let attended = tf.layers.layerNormalization({ name: prefix + '/ln_1', epsilon: 1e-5 }).apply(inputs);
    if (config.debug) {
        attended = new LogLayer({ name: prefix + '/ln_1_log' }).apply(attended);
    }
    // self attention layer
    const attention = new CausalSelfAttention(Object.assign({}, config, { name: prefix + '/attn' }), disposalRefs, peakMemory);
    attended = attention.apply(attended);
    // residual connection around the attention
    attended = tf.layers.add().apply([inputs, attended]);
    // normalization
    let output = tf.layers.layerNormalization({ name: prefix + '/ln_2', epsilon: 1e-5 }).apply(attended);
    // MLP
    const mlp = MLP(Object.assign({}, config, { name: prefix + '/mlp' }));
    output = mlp.apply(output);
    // residual connection around the MLP
    output = tf.layers.add().apply([attended, output]);
    return tf.model({ name: prefix, inputs, outputs: output });
}
|
|
212
|
+
/**
 * The GPTArchitecture specifically defines a GPT forward pass, i.e.,
 * what are the inputs, the successive transformer blocks and the outputs.
 *
 * @param config fully populated GPT configuration
 * @param disposalRefs list forwarded to every transformer block, collecting tensors to dispose later
 * @param peakMemory holder forwarded to every transformer block
 * @returns model, tf.LayersModel, which supports model(inputs), model.predict and model.apply
 */
export function GPTArchitecture(config, disposalRefs, peakMemory) {
    // Inserts a LogLayer after `tensor` only when debugging is enabled
    const withDebugLog = (name, tensor) => config.debug ? new LogLayer({ name }).apply(tensor) : tensor;
    const inputs = tf.input({ shape: [null] });
    // Token embedding (skipped when config.tokEmb is false)
    let tokEmb = inputs;
    if (config.tokEmb) {
        tokEmb = tf.layers.embedding({
            name: config.name + '/wte',
            inputDim: config.vocabSize,
            outputDim: config.nEmbd,
            embeddingsInitializer: 'zeros',
            embeddingsRegularizer: undefined,
            activityRegularizer: undefined
        }).apply(inputs);
    }
    // Positional embedding, looked up from the positions 0..T-1
    const positions = new Range({}).apply(inputs);
    const posEmb = withDebugLog('posEmb', tf.layers.embedding({
        name: config.name + '/wpe',
        inputDim: config.blockSize,
        outputDim: config.nEmbd,
        embeddingsInitializer: 'zeros'
    }).apply(positions));
    // token and positional embeddings are added together
    let x = tf.layers.add().apply([tokEmb, posEmb]);
    // dropout
    x = tf.layers.dropout({ name: 'drop', rate: config.embdDrop }).apply(x);
    x = withDebugLog('dropadd', x);
    // Successive transformer blocks
    for (let i = 0; i < config.nLayer; i++) {
        const blockConfig = Object.assign({}, config, { name: config.name + '/h/' + i });
        x = TransformerBlock(blockConfig, disposalRefs, peakMemory).apply(x);
    }
    // Final normalization
    x = tf.layers.layerNormalization({ name: config.name + '/ln_f', epsilon: 1e-5 }).apply(x);
    x = withDebugLog('fin/ln', x);
    // Append a language modeling head if specified
    if (config.lmHead) {
        x = tf.layers.dense({
            name: 'lm_head',
            units: config.vocabSize,
            inputDim: config.nEmbd,
            inputShape: [config.blockSize, config.nEmbd],
            useBias: false
        }).apply(x);
    }
    return tf.model({ inputs, outputs: x });
}
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import type { GPTConfig } from './config.js';
|
|
3
|
+
/**
 * tfjs does not export LazyIterator and Dataset, so minimal versions are declared here.
 */
declare abstract class LazyIterator<T> {
    abstract next(): Promise<IteratorResult<T>>;
}
export declare abstract class Dataset<T> {
    abstract iterator(): Promise<LazyIterator<T>>;
    size: number;
}
/**
 * GPTModel extends tf.LayersModel and overrides tfjs' default training loop
 */
declare class GPTModel extends tf.LayersModel {
    protected readonly config: Required<GPTConfig>;
    private readonly disposalRefs;
    protected peakMemory: {
        value: number;
    };
    constructor(partialConfig?: GPTConfig);
    /** Disposes the tensors referenced in `disposalRefs`. */
    disposeRefs(): void;
    /** The fully populated configuration of this model. */
    get getGPTConfig(): Required<GPTConfig>;
    compile(): void;
    fitDataset<T>(dataset: Dataset<T>, trainingArgs: tf.ModelFitDatasetArgs<T>): Promise<tf.History>;
}
/** Options accepted by `GPTForCausalLM.generate`. */
interface GenerateConfig {
    maxNewTokens: number;
    temperature: number;
    doSample: boolean;
}
/**
 * GPTForCausalLM stands for GPT model for Causal Language Modeling.
 * Causal because it only looks at past tokens and not future ones.
 * This class extends GPTModel and adds support for text generation.
 */
export declare class GPTForCausalLM extends GPTModel {
    generate(idxRaw: tf.TensorLike, conf: GenerateConfig): Promise<number[][]>;
    private generateOnce;
}
export {};
|