@epfml/discojs 3.0.1-p20241119093954.0 → 3.0.1-p20241206133538.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/client/client.js +2 -0
- package/dist/client/federated/federated_client.js +2 -2
- package/dist/dataset/dataset.d.ts +18 -5
- package/dist/dataset/dataset.js +58 -23
- package/dist/dataset/types.d.ts +1 -0
- package/dist/default_tasks/index.d.ts +1 -0
- package/dist/default_tasks/index.js +1 -0
- package/dist/default_tasks/tinder_dog.d.ts +2 -0
- package/dist/default_tasks/tinder_dog.js +72 -0
- package/dist/default_tasks/wikitext.js +5 -3
- package/dist/models/gpt/config.d.ts +11 -6
- package/dist/models/gpt/config.js +11 -7
- package/dist/models/gpt/index.d.ts +5 -9
- package/dist/models/gpt/index.js +36 -15
- package/dist/models/gpt/layers.js +260 -82
- package/dist/models/gpt/model.d.ts +1 -1
- package/dist/models/gpt/model.js +4 -4
- package/dist/processing/index.js +8 -9
- package/dist/processing/text.d.ts +16 -6
- package/dist/processing/text.js +29 -26
- package/dist/task/task_handler.js +5 -1
- package/dist/task/training_information.d.ts +1 -1
- package/dist/task/training_information.js +3 -4
- package/dist/training/disco.js +6 -3
- package/dist/types/data_format.d.ts +2 -2
- package/dist/validator.js +2 -2
- package/package.json +1 -1
package/dist/client/client.js
CHANGED
@@ -149,6 +149,8 @@ export class Client extends EventEmitter {
         }
         url.pathname += `tasks/${this.task.id}/model.json`;
         const response = await fetch(url);
+        if (!response.ok)
+            throw new Error(`fetch: HTTP status ${response.status}`);
         const encoded = new Uint8Array(await response.arrayBuffer());
         return await serialization.model.decode(encoded);
     }
package/dist/client/federated/federated_client.js
CHANGED
@@ -2,7 +2,7 @@ import createDebug from "debug";
 import { serialization } from "../../index.js";
 import { Client, shortenId } from "../client.js";
 import { type } from "../messages.js";
-import { waitMessage,
+import { waitMessage, WebSocketServer, } from "../event_connection.js";
 import * as messages from "./messages.js";
 const debug = createDebug("discojs:client:federated");
 /**
@@ -53,7 +53,7 @@ export class FederatedClient extends Client {
             type: type.ClientConnected,
         };
         this.server.send(msg);
-        const { id, waitForMoreParticipants, payload, round, nbOfParticipants } = await
+        const { id, waitForMoreParticipants, payload, round, nbOfParticipants } = await waitMessage(this.server, type.NewFederatedNodeInfo);
         // This should come right after receiving the message to make sure
         // we don't miss a subsequent message from the server
         // We check if the server is telling us to wait for more participants
package/dist/dataset/dataset.d.ts
CHANGED
@@ -25,15 +25,22 @@ export declare class Dataset<T> implements AsyncIterable<T> {
     * @param ratio between 0 (all on left) and 1 (all on right)
     */
    split(ratio: number): [Dataset<T>, Dataset<T>];
-    /**
+    /** Create batches of `size` elements with potential overlap.
+     * Last batch is smaller if dataset isn't perfectly divisible
+     *
+     * If overlap is set to a positive integer, the last `overlap` elements of a batch
+     * are the first `overlap` elements of the next batch.
     *
-     *
+     * This method is tailored to create text sequences where each token's label is the following token.
+     * In order to have a label for the last token of the input sequence, we include the first token
+     * of the next sequence (i.e. with an overlap of 1).
     *
     * @param size count of element per chunk
+     * @param overlap number of elements overlapping between two consecutive batches
     */
-    batch(size: number): Dataset<Batched<T>>;
-    /** Flatten
-
+    batch(size: number, overlap?: number): Dataset<Batched<T>>;
+    /** Flatten batches/arrays of elements */
+    flatten<U>(this: Dataset<DatasetLike<U>>): Dataset<U>;
    /** Join side-by-side
     *
     * Stops as soon as one runs out
@@ -41,6 +48,12 @@ export declare class Dataset<T> implements AsyncIterable<T> {
     * @param other right side
     **/
    zip<U>(other: Dataset<U> | DatasetLike<U>): Dataset<[T, U]>;
+    /**
+     * Repeat the dataset `times` times
+     * @param times number of times to repeat the dataset, if undefined, the dataset is repeated indefinitely
+     * @returns a dataset repeated `times` times
+     */
+    repeat(times?: number): Dataset<T>;
    /** Compute size
     *
     * This is a costly operation as we need to go through the whole Dataset.
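Note: a minimal usage sketch of the new overlap parameter (assuming the public Dataset export; the token values are made up):

    import { Dataset } from "@epfml/discojs";

    const tokens = new Dataset([1, 2, 3, 4, 5, 6, 7]);
    // With size=4 and overlap=1, consecutive batches share one element,
    // so every token in a batch has a following token to use as its label.
    for await (const seq of tokens.batch(4, 1))
        console.log(seq.toArray()); // [1, 2, 3, 4] then [4, 5, 6, 7]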
package/dist/dataset/dataset.js
CHANGED
@@ -1,6 +1,22 @@
 import createDebug from "debug";
 import { List, Range } from "immutable";
 const debug = createDebug("discojs:dataset");
+/** Convert a DatasetLike object to an async generator */
+async function* datasetLikeToGenerator(content) {
+    let iter;
+    if (typeof content === "function")
+        iter = content();
+    else if (Symbol.asyncIterator in content)
+        iter = content[Symbol.asyncIterator]();
+    else
+        iter = content[Symbol.iterator]();
+    while (true) {
+        const result = await iter.next();
+        if (result.done === true)
+            break;
+        yield result.value;
+    }
+}
 /** Immutable series of data */
 export class Dataset {
     #content;
@@ -11,19 +27,7 @@ export class Dataset {
     */
    constructor(content) {
        this.#content = async function* () {
-
-            if (typeof content === "function")
-                iter = content();
-            else if (Symbol.asyncIterator in content)
-                iter = content[Symbol.asyncIterator]();
-            else
-                iter = content[Symbol.iterator]();
-            while (true) {
-                const result = await iter.next();
-                if (result.done === true)
-                    break;
-                yield result.value;
-            }
+            yield* datasetLikeToGenerator(content);
        };
    }
    [Symbol.asyncIterator]() {
@@ -87,19 +91,31 @@ export class Dataset {
        }.bind(this)),
        ];
    }
-    /**
+    /** Create batches of `size` elements with potential overlap.
+     * Last batch is smaller if dataset isn't perfectly divisible
+     *
+     * If overlap is set to a positive integer, the last `overlap` elements of a batch
+     * are the first `overlap` elements of the next batch.
     *
-     *
+     * This method is tailored to create text sequences where each token's label is the following token.
+     * In order to have a label for the last token of the input sequence, we include the first token
+     * of the next sequence (i.e. with an overlap of 1).
     *
     * @param size count of element per chunk
+     * @param overlap number of elements overlapping between two consecutive batches
     */
-    batch(size) {
+    batch(size, overlap = 0) {
        if (size <= 0 || !Number.isInteger(size))
            throw new Error("invalid size");
+        if (overlap >= size || !Number.isInteger(overlap))
+            throw new Error("invalid overlap");
        return new Dataset(async function* () {
            const iter = this[Symbol.asyncIterator]();
+            let overlapped = List();
            for (;;) {
-                const batch = List(
+                const batch = List(
+                // get the first elements of the next batch
+                await Promise.all(Range(overlapped.size, size).map(() => iter.next()))).flatMap((res) => {
                    if (res.done)
                        return [];
                    else
@@ -107,18 +123,21 @@ export class Dataset {
                });
                if (batch.isEmpty())
                    break;
-                yield batch
+                // yield the current batch with the first elements of the next batch
+                yield overlapped.concat(batch);
+                overlapped = batch.takeLast(overlap);
                // iterator couldn't generate more
-                if (batch.size < size)
+                if (batch.size < size - overlap)
                    break;
            }
        }.bind(this));
    }
-    /** Flatten
-
+    /** Flatten batches/arrays of elements */
+    flatten() {
        return new Dataset(async function* () {
-            for await (const batch of this)
-                yield* batch;
+            for await (const batch of this) {
+                yield* datasetLikeToGenerator(batch);
+            }
        }.bind(this));
    }
    /** Join side-by-side
@@ -141,6 +160,22 @@ export class Dataset {
            }
        }.bind(this));
    }
+    /**
+     * Repeat the dataset `times` times
+     * @param times number of times to repeat the dataset, if undefined, the dataset is repeated indefinitely
+     * @returns a dataset repeated `times` times
+     */
+    repeat(times) {
+        if (times !== undefined && (!Number.isInteger(times) || times < 1))
+            throw new Error("times needs to be a positive integer or undefined");
+        return new Dataset(async function* () {
+            let loop = 0;
+            do {
+                yield* this;
+                loop++;
+            } while (times === undefined || loop < times);
+        }.bind(this));
+    }
    /** Compute size
     *
     * This is a costly operation as we need to go through the whole Dataset.
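Note: a short sketch of the new repeat and flatten behaviors (same assumed Dataset export; values are made up):

    const pair = new Dataset([1, 2]);
    // repeat(2) re-runs the underlying generator twice: 1, 2, 1, 2.
    // Calling repeat() with no argument loops indefinitely, e.g. to
    // keep a training stream alive across many epochs.
    for await (const x of pair.repeat(2)) console.log(x);

    // flatten() now routes each element through datasetLikeToGenerator,
    // so batches can be any DatasetLike, e.g. plain arrays: yields 1, 2, 3.
    for await (const x of new Dataset([[1], [2, 3]]).flatten()) console.log(x);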
package/dist/default_tasks/tinder_dog.js
ADDED
@@ -0,0 +1,72 @@
+import * as tf from '@tensorflow/tfjs';
+import { models } from '../index.js';
+export const tinderDog = {
+    getTask() {
+        return {
+            id: 'tinder_dog',
+            displayInformation: {
+                taskTitle: 'GDHF 2024 | TinderDog',
+                summary: {
+                    preview: 'Which dog is the cutest....or not?',
+                    overview: "Binary classification model for dog cuteness."
+                },
+                model: 'The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 64x64 pixels and normalizes values between 0 and 1',
+                dataFormatInformation: 'Accepted image formats are .png .jpg and .jpeg.',
+                dataExampleText: '',
+                dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog_preview.png',
+                sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog.zip',
+                sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, pick one of the data splits (the folder 0 for example) and use the CSV option below to select the file named "labels.csv". You can now connect the images located in the same folder.'
+            },
+            trainingInformation: {
+                epochs: 10,
+                roundDuration: 2,
+                validationSplit: 0, // nicer plot for GDHF demo
+                batchSize: 10,
+                dataType: 'image',
+                IMAGE_H: 64,
+                IMAGE_W: 64,
+                LABEL_LIST: ['Cute dogs', 'Less cute dogs'],
+                scheme: 'federated',
+                aggregationStrategy: 'mean',
+                minNbOfParticipants: 3,
+                tensorBackend: 'tfjs'
+            }
+        };
+    },
+    async getModel() {
+        const seed = 42; // set a seed to ensure reproducibility during GDHF demo
+        const imageHeight = this.getTask().trainingInformation.IMAGE_H;
+        const imageWidth = this.getTask().trainingInformation.IMAGE_W;
+        const imageChannels = 3;
+        const model = tf.sequential();
+        model.add(tf.layers.conv2d({
+            inputShape: [imageHeight, imageWidth, imageChannels],
+            kernelSize: 5,
+            filters: 8,
+            activation: 'relu',
+            kernelInitializer: tf.initializers.heNormal({ seed })
+        }));
+        model.add(tf.layers.conv2d({
+            kernelSize: 5, filters: 16, activation: 'relu',
+            kernelInitializer: tf.initializers.heNormal({ seed })
+        }));
+        model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
+        model.add(tf.layers.dropout({ rate: 0.25, seed }));
+        model.add(tf.layers.flatten());
+        model.add(tf.layers.dense({
+            units: 32, activation: 'relu',
+            kernelInitializer: tf.initializers.heNormal({ seed })
+        }));
+        model.add(tf.layers.dropout({ rate: 0.25, seed }));
+        model.add(tf.layers.dense({
+            units: 2, activation: 'softmax',
+            kernelInitializer: tf.initializers.heNormal({ seed })
+        }));
+        model.compile({
+            optimizer: tf.train.adam(0.0005),
+            loss: 'categoricalCrossentropy',
+            metrics: ['accuracy']
+        });
+        return Promise.resolve(new models.TFJS('image', model));
+    }
+};
package/dist/default_tasks/wikitext.js
CHANGED
@@ -31,14 +31,16 @@ export const wikitext = {
                // But if set to 0 then the webapp doesn't display the validation metrics
                validationSplit: 0.1,
                roundDuration: 2,
-                batchSize:
+                batchSize: 8, // If set too high firefox raises a WebGL error
                tokenizer: 'Xenova/gpt2',
-
+                contextLength: 64,
                tensorBackend: 'gpt'
            }
        };
    },
    getModel() {
-        return Promise.resolve(new models.GPT(
+        return Promise.resolve(new models.GPT({
+            contextLength: this.getTask().trainingInformation.contextLength,
+        }));
    }
};
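Note: with this change the context length flows from the task definition into the model instead of relying on the GPT default. A minimal sketch of the equivalent direct construction (assuming the public models export):

    import { models } from "@epfml/discojs";

    // Only contextLength is overridden; every other field falls back to
    // DefaultGPTConfig (e.g. nLayer: 3, vocabSize: 50257, see config.js below).
    const gpt = new models.GPT({ contextLength: 64 });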
package/dist/models/gpt/config.d.ts
CHANGED
@@ -1,8 +1,8 @@
 type GPTModelType = 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl' | 'gpt-mini' | 'gpt-micro' | 'gpt-nano';
 export interface GPTConfig {
     lr: number;
-
-    vocabSize
+    contextLength: number;
+    vocabSize?: number;
     modelType: GPTModelType;
     name?: string;
     evaluate?: boolean;
@@ -11,22 +11,27 @@ export interface GPTConfig {
     maxIter?: number;
     weightDecay?: number;
     verbose?: 0 | 1;
-    bias?: boolean;
     debug?: boolean;
     dropout?: number;
     residDrop?: number;
     embdDrop?: number;
-    tokEmb?: boolean;
-    lmHead?: boolean;
     nLayer?: number;
     nHead?: number;
     nEmbd?: number;
+    seed?: number;
 }
-export declare const
+export declare const DefaultGPTConfig: Required<GPTConfig>;
 export type ModelSize = {
     nLayer: number;
     nHead: number;
     nEmbd: number;
 };
 export declare function getModelSizes(modelType: GPTModelType): Required<ModelSize>;
+export interface GenerationConfig {
+    doSample: boolean;
+    temperature: number;
+    topk: number;
+    seed: number;
+}
+export declare const DefaultGenerationConfig: Required<GenerationConfig>;
 export {};
package/dist/models/gpt/config.js
CHANGED
@@ -1,6 +1,6 @@
 // for a benchmark of performance, see https://github.com/epfml/disco/pull/659
-export const
-    name: 'transformer',
+export const DefaultGPTConfig = {
+    name: 'transformer', // prefix for the model layer names
     lr: 0.001,
     weightDecay: 0,
     maxIter: 10,
@@ -9,18 +9,16 @@ export const DEFAULT_CONFIG = {
     evaluate: true,
     maxEvalBatches: 12,
     evaluateEvery: 100,
-
-    vocabSize:
-    bias: true,
+    contextLength: 128,
+    vocabSize: 50257,
     debug: false,
     dropout: 0.2,
     residDrop: 0.2,
     embdDrop: 0.2,
-    tokEmb: true,
-    lmHead: true,
     nLayer: 3,
     nHead: 3,
     nEmbd: 48,
+    seed: Math.random(),
 };
 export function getModelSizes(modelType) {
     switch (modelType) {
@@ -40,3 +38,9 @@ export function getModelSizes(modelType) {
         return { nLayer: 3, nHead: 3, nEmbd: 48 };
     }
 }
+export const DefaultGenerationConfig = {
+    temperature: 1.0,
+    doSample: false,
+    seed: Math.random(),
+    topk: 50
+};
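Note: a sketch of how these generation defaults are meant to combine with user options, mirroring the Object.assign call in index.js below (the relative import assumes code living next to config.js; field names are from GenerationConfig):

    import { DefaultGenerationConfig } from "./config.js";

    // User options override defaults field by field; anything omitted
    // (topk: 50, seed) keeps its default. The default seed is drawn once
    // with Math.random() when the module loads.
    const config = { ...DefaultGenerationConfig, doSample: true, temperature: 0.8 };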
package/dist/models/gpt/index.d.ts
CHANGED
@@ -1,23 +1,20 @@
 /**
- *
+ * Source: https://github.com/zemlyansky/gpt-tfjs and https://github.com/karpathy/build-nanogpt
+ * With modifications from @peacefulotter, @lukemovement and the Disco team
 **/
 import * as tf from "@tensorflow/tfjs";
 import type { Batched, Dataset, DataFormat } from "../../index.js";
 import { WeightsContainer } from "../../index.js";
 import { BatchLogs, Model, EpochLogs } from "../index.js";
-import {
+import type { GPTConfig, GenerationConfig } from './config.js';
 export type GPTSerialization = {
     weights: WeightsContainer;
     config?: GPTConfig;
 };
-interface PredictConfig {
-    temperature: number;
-    doSample: boolean;
-}
 export declare class GPT extends Model<"text"> {
     #private;
     private readonly model;
-    constructor(partialConfig?: GPTConfig
+    constructor(partialConfig?: Partial<GPTConfig>, layersModel?: tf.LayersModel);
     /**
     * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
     * This allows for getting logs and stopping training without callbacks.
@@ -28,7 +25,7 @@ export declare class GPT extends Model<"text"> {
     * @param tracker
     */
    train(trainingDataset: Dataset<Batched<DataFormat.ModelEncoded["text"]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded["text"]>>): AsyncGenerator<BatchLogs, EpochLogs>;
-    predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<
+    predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<GenerationConfig>): Promise<Batched<DataFormat.ModelEncoded["text"][1]>>;
     get config(): Required<GPTConfig>;
     get weights(): WeightsContainer;
     set weights(ws: WeightsContainer);
@@ -37,4 +34,3 @@ export declare class GPT extends Model<"text"> {
     extract(): tf.LayersModel;
     [Symbol.dispose](): void;
 }
-export {};
package/dist/models/gpt/index.js
CHANGED
@@ -1,5 +1,6 @@
 /**
- *
+ * Source: https://github.com/zemlyansky/gpt-tfjs and https://github.com/karpathy/build-nanogpt
+ * With modifications from @peacefulotter, @lukemovement and the Disco team
 **/
 import createDebug from "debug";
 import { List, Range } from "immutable";
@@ -7,12 +8,12 @@ import * as tf from "@tensorflow/tfjs";
 import { WeightsContainer } from "../../index.js";
 import { Model, EpochLogs } from "../index.js";
 import { GPTModel } from "./model.js";
-import { DEFAULT_CONFIG } from "./config.js";
 import evaluate from "./evaluate.js";
+import { DefaultGPTConfig, DefaultGenerationConfig } from './config.js';
 const debug = createDebug("discojs:models:gpt");
 export class GPT extends Model {
     model;
-    #
+    #contextLength;
     #maxBatchCount;
     #vocabSize;
     constructor(partialConfig, layersModel) {
@@ -20,9 +21,9 @@ export class GPT extends Model {
         const model = new GPTModel(partialConfig, layersModel);
         model.compile();
         this.model = model;
-        this.#
-        this.#maxBatchCount = partialConfig?.maxIter ??
-        this.#vocabSize = partialConfig?.vocabSize ??
+        this.#contextLength = partialConfig?.contextLength ?? DefaultGPTConfig.contextLength;
+        this.#maxBatchCount = partialConfig?.maxIter ?? DefaultGPTConfig.maxIter;
+        this.#vocabSize = partialConfig?.vocabSize ?? DefaultGPTConfig.vocabSize;
     }
     /**
     * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
@@ -85,16 +86,21 @@ export class GPT extends Model {
         }));
     }
     async predict(batch, options) {
-
-
-            doSample: false,
-            ...options,
-        };
+        // overwrite default with user config
+        const config = Object.assign({}, DefaultGenerationConfig, options);
         return List(await Promise.all(batch.map((tokens) => this.#predictSingle(tokens, config))));
     }
+    /**
+     * Generate the next token after the input sequence.
+     * In other words, takes an input tensor of shape (prompt length T) and returns a tensor of shape (T+1)
+     *
+     * @param token input tokens of shape (T,). T is truncated to the model's context length
+     * @param config generation config: temperature, doSample, topk
+     * @returns the next token predicted by the model
+     */
     async #predictSingle(tokens, config) {
         // slice input tokens if longer than context length
-        tokens = tokens.slice(-this.#
+        tokens = tokens.slice(-this.#contextLength);
         const input = tf.tidy(() => tf.tensor1d(tokens.toArray(), "int32").expandDims(0));
         const logits = tf.tidy(() => {
             const output = this.model.predict(input);
@@ -111,9 +117,24 @@ export class GPT extends Model {
             .div(config.temperature)
             .softmax());
         logits.dispose();
-        const next = tf.tidy(() =>
-
-
+        const next = tf.tidy(() => {
+            if (config.doSample) {
+                // returns topk biggest values among the `vocab_size` probabilities and the corresponding tokens indices
+                // both shapes are (config.topk,)
+                const { values: topkProbs, indices: topkTokens } = tf.topk(probs, config.topk);
+                // sample an index from the top-k probabilities
+                // e.g. [[0.1, 0.4, 0.3], [0.1, 0.2, 0.5]] -> [[1], [2]]
+                // note: multinomial does not need the input to sum to 1
+                const selectedIndices = tf.multinomial(topkProbs, 1, config.seed, false); // (B, )
+                // return the corresponding token from the sampled indices (one per sequence in the batch).
+                // if for some reason the probabilities are NaN, selectedIndices will be out of bounds
+                return topkTokens.gather(selectedIndices).squeeze([0]); // (1)
+            }
+            else {
+                // greedy decoding: return the token with the highest probability
+                return probs.argMax();
+            }
+        });
         probs.dispose();
         const ret = await next.array();
         next.dispose();
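Note: a self-contained sketch of the top-k sampling step added above, using the same tf.topk / tf.multinomial / gather chain on made-up probabilities:

    import * as tf from "@tensorflow/tfjs";

    const probs = tf.tensor1d([0.05, 0.6, 0.3, 0.05]);
    // Keep the 2 most likely tokens and their vocabulary indices.
    const { values: topkProbs, indices: topkTokens } = tf.topk(probs, 2);
    // Sample one position among the top-k (the input need not sum to 1),
    // then map it back to a vocabulary token id.
    const sampled = tf.multinomial(topkProbs, 1, 42, false);
    const next = topkTokens.gather(sampled);
    next.array().then((t) => console.log(t)); // e.g. [1]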