@epfml/discojs 2.1.2-p20240513140724.0 → 2.1.2-p20240515132210.0
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- package/dist/index.d.ts +1 -1
- package/dist/index.js +1 -1
- package/dist/memory/base.d.ts +6 -19
- package/dist/memory/empty.d.ts +2 -2
- package/dist/memory/empty.js +2 -2
- package/dist/memory/index.d.ts +1 -1
- package/dist/memory/index.js +1 -1
- package/dist/memory/model_type.d.ts +1 -1
- package/dist/memory/model_type.js +5 -5
- package/dist/models/gpt/config.d.ts +32 -0
- package/dist/models/gpt/config.js +42 -0
- package/dist/models/gpt/evaluate.d.ts +7 -0
- package/dist/models/gpt/evaluate.js +44 -0
- package/dist/models/gpt/index.d.ts +37 -0
- package/dist/models/gpt/index.js +107 -0
- package/dist/models/gpt/layers.d.ts +13 -0
- package/dist/models/gpt/layers.js +272 -0
- package/dist/models/gpt/model.d.ts +43 -0
- package/dist/models/gpt/model.js +191 -0
- package/dist/models/gpt/optimizers.d.ts +4 -0
- package/dist/models/gpt/optimizers.js +95 -0
- package/dist/models/index.d.ts +5 -0
- package/dist/models/index.js +4 -0
- package/dist/models/model.d.ts +51 -0
- package/dist/models/model.js +8 -0
- package/dist/models/tfjs.d.ts +24 -0
- package/dist/models/tfjs.js +107 -0
- package/dist/models/tokenizer.d.ts +14 -0
- package/dist/models/tokenizer.js +23 -0
- package/dist/training/trainer/trainer_builder.js +2 -2
- package/package.json +1 -1
package/dist/index.d.ts
CHANGED
@@ -8,7 +8,7 @@ export * as aggregator from './aggregator/index.js';
 export { WeightsContainer, aggregation } from './weights/index.js';
 export { AsyncInformant } from './async_informant.js';
 export { Logger, ConsoleLogger } from './logging/index.js';
-export { Memory,
+export { Memory, StoredModelType, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
 export { Disco, RoundLogs } from './training/index.js';
 export { Validator } from './validation/index.js';
 export { Model, EpochLogs } from './models/index.js';
package/dist/index.js
CHANGED
@@ -8,7 +8,7 @@ export * as aggregator from './aggregator/index.js';
 export { WeightsContainer, aggregation } from './weights/index.js';
 export { AsyncInformant } from './async_informant.js';
 export { ConsoleLogger } from './logging/index.js';
-export { Memory,
+export { Memory, StoredModelType, Empty as EmptyMemory } from './memory/index.js';
 export { Disco } from './training/index.js';
 export { Validator } from './validation/index.js';
 export { Model } from './models/index.js';
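
Both entry points rename the re-exported enum from `ModelType` to `StoredModelType` and expose `Empty` under the alias `EmptyMemory`. A minimal consumer-side sketch, assuming only the root exports shown above:

```ts
// Sketch, not part of the package: an import site updated for the rename.
import { StoredModelType, EmptyMemory } from '@epfml/discojs'

// EmptyMemory is a placeholder Memory whose methods reject or throw 'empty'.
const memory = new EmptyMemory()

console.log(StoredModelType.WORKING) // "working"
```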
package/dist/memory/base.d.ts
CHANGED
@@ -1,5 +1,5 @@
 import type { Model, TaskID } from '../index.js';
-import type { ModelType } from './model_type.js';
+import type { StoredModelType } from './model_type.js';
 /**
  * Model path which uniquely identifies a model in memory.
  */
@@ -8,22 +8,9 @@ export type Path = string;
  * Model information which uniquely identifies a model in memory.
  */
 export interface ModelInfo {
-    /**
-     * The model's type: "working" or "saved" model.
-     */
-    type?: ModelType;
-    /**
-     * The model's version, to allow for multiple saved models of a same task without
-     * causing id conflicts
-     */
+    type?: StoredModelType;
     version?: number;
-    /**
-     * The model's corresponding task
-     */
     taskID: TaskID;
-    /**
-     * The model's name
-     */
     name: string;
 }
 /**
@@ -95,21 +82,21 @@ export declare abstract class Memory {
     /**
      * Computes the path in memory corresponding to the given model source, be it a path or model information.
      * This is used to easily switch between model path and information, which are both unique model identifiers
-     * with a one-to-one
+     * with a one-to-one equivalence. Returns undefined instead if no path could be inferred from the given
      * model source.
      * @param source The model source
      * @returns The model path
      */
-    abstract
+    abstract getModelMemoryPath(source: ModelSource): Path | undefined;
     /**
      * Computes the model information corresponding to the given model source, be it a path or model information.
      * This is used to easily switch between model path and information, which are both unique model identifiers
-     * with a one-to-one
+     * with a one-to-one equivalence. Returns undefined instead if no unique model information could be inferred
      * from the given model source.
      * @param source The model source
      * @returns The model information
      */
-    abstract
+    abstract getModelInfo(source: ModelSource): ModelInfo | undefined;
     /**
      * Computes the lowest version a model source can have without conflicting with model versions currently in memory.
      * @param source The model source
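
The interface keeps the same shape with the per-field JSDoc removed; only the type of `type` changes. A sketch of a conforming value (field values are illustrative, and `TaskID` is assumed to be a string alias):

```ts
import { StoredModelType, type ModelInfo } from '@epfml/discojs'

const info: ModelInfo = {
  type: StoredModelType.SAVED, // the model's type: "working" or "saved"
  version: 2,                  // allows several saved models of one task without id conflicts
  taskID: 'titanic',           // illustrative task id
  name: 'titanic-model',
}
```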
package/dist/memory/empty.d.ts
CHANGED
@@ -14,7 +14,7 @@ export declare class Empty extends Memory {
     saveModel(): Promise<undefined>;
     deleteModel(): Promise<void>;
     downloadModel(): Promise<void>;
-
-
+    getModelMemoryPath(): Path;
+    getModelInfo(): ModelInfo;
     duplicateSource(): Promise<undefined>;
 }
package/dist/memory/empty.js
CHANGED
@@ -31,10 +31,10 @@ export class Empty extends Memory {
     downloadModel() {
         return Promise.reject(new Error('empty'));
     }
-
+    getModelMemoryPath() {
         throw new Error('empty');
     }
-
+    getModelInfo() {
         throw new Error('empty');
     }
     duplicateSource() {
package/dist/memory/index.d.ts
CHANGED
package/dist/memory/index.js
CHANGED
package/dist/memory/model_type.d.ts
CHANGED
@@ -3,7 +3,7 @@
  * being trained ("working model") or a regular model saved in memory ("saved model").
  * There can only be a single working model for a given task.
  */
-export declare enum ModelType {
+export declare enum StoredModelType {
     WORKING = "working",
     SAVED = "saved"
 }

package/dist/memory/model_type.js
CHANGED
@@ -3,8 +3,8 @@
  * being trained ("working model") or a regular model saved in memory ("saved model").
  * There can only be a single working model for a given task.
  */
-export var ModelType;
-(function (ModelType) {
-    ModelType["WORKING"] = "working";
-    ModelType["SAVED"] = "saved";
-})(ModelType || (ModelType = {}));
+export var StoredModelType;
+(function (StoredModelType) {
+    StoredModelType["WORKING"] = "working";
+    StoredModelType["SAVED"] = "saved";
+})(StoredModelType || (StoredModelType = {}));
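
The emitted IIFE is the standard compilation of a TypeScript string enum, so the members are plain string values and no reverse mapping is generated. A quick illustration:

```ts
import { StoredModelType } from '@epfml/discojs'

console.log(StoredModelType.WORKING)        // "working"
console.log(Object.values(StoredModelType)) // ["working", "saved"], no reverse-mapping entries
```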
package/dist/models/gpt/config.d.ts
ADDED
@@ -0,0 +1,32 @@
+type GPTModelType = 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl' | 'gpt-mini' | 'gpt-micro' | 'gpt-nano';
+export interface GPTConfig {
+    lr: number;
+    blockSize: number;
+    vocabSize: number;
+    modelType: GPTModelType;
+    name?: string;
+    evaluate?: boolean;
+    maxEvalBatches?: number;
+    evaluateEvery?: number;
+    maxIter?: number;
+    weightDecay?: number;
+    verbose?: 0 | 1;
+    bias?: boolean;
+    debug?: boolean;
+    dropout?: number;
+    residDrop?: number;
+    embdDrop?: number;
+    tokEmb?: boolean;
+    lmHead?: boolean;
+    nLayer?: number;
+    nHead?: number;
+    nEmbd?: number;
+}
+export declare const DEFAULT_CONFIG: Required<GPTConfig>;
+export type ModelSize = {
+    nLayer: number;
+    nHead: number;
+    nEmbd: number;
+};
+export declare function getModelSizes(modelType: GPTModelType): Required<ModelSize>;
+export {};
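
A value conforming to the new interface; a sketch only, since this diff does not show where `GPTConfig` is re-exported (the deep import path below is an assumption):

```ts
// Assumed import path, for illustration only.
import type { GPTConfig } from '@epfml/discojs/dist/models/gpt/config.js'

const config: GPTConfig = {
  // required fields
  modelType: 'gpt-nano',
  lr: 0.001,
  blockSize: 128,
  vocabSize: 50258,
  // a few optional knobs; unset ones presumably fall back to DEFAULT_CONFIG below
  maxIter: 1000,
  dropout: 0.1,
  evaluateEvery: 100,
}
```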
package/dist/models/gpt/config.js
ADDED
@@ -0,0 +1,42 @@
+// for a benchmark of performance, see https://github.com/epfml/disco/pull/659
+export const DEFAULT_CONFIG = {
+    name: 'transformer',
+    lr: 0.001,
+    weightDecay: 0,
+    maxIter: 5,
+    verbose: 0,
+    modelType: 'gpt-nano',
+    evaluate: true,
+    maxEvalBatches: 12,
+    evaluateEvery: 100,
+    blockSize: 128,
+    vocabSize: 50258,
+    bias: true,
+    debug: false,
+    dropout: 0.2,
+    residDrop: 0.2,
+    embdDrop: 0.2,
+    tokEmb: true,
+    lmHead: true,
+    nLayer: 3,
+    nHead: 3,
+    nEmbd: 48,
+};
+export function getModelSizes(modelType) {
+    switch (modelType) {
+        case 'gpt2':
+            return { nLayer: 12, nHead: 12, nEmbd: 768 };
+        case 'gpt2-medium':
+            return { nLayer: 24, nHead: 16, nEmbd: 1024 };
+        case 'gpt2-large':
+            return { nLayer: 36, nHead: 20, nEmbd: 1280 };
+        case 'gpt2-xl':
+            return { nLayer: 48, nHead: 25, nEmbd: 1600 };
+        case 'gpt-mini':
+            return { nLayer: 6, nHead: 6, nEmbd: 192 };
+        case 'gpt-micro':
+            return { nLayer: 4, nHead: 4, nEmbd: 128 };
+        case 'gpt-nano':
+            return { nLayer: 3, nHead: 3, nEmbd: 48 };
+    }
+}
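
`getModelSizes` maps each preset name to fixed transformer dimensions, for example (same assumed deep import path as above):

```ts
import { getModelSizes } from '@epfml/discojs/dist/models/gpt/config.js'

const { nLayer, nHead, nEmbd } = getModelSizes('gpt2-medium')
console.log(nLayer, nHead, nEmbd) // 24 16 1024
```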
package/dist/models/gpt/evaluate.d.ts
ADDED
@@ -0,0 +1,7 @@
+import * as tf from '@tensorflow/tfjs';
+interface DataPoint extends tf.TensorContainerObject {
+    xs: tf.Tensor2D;
+    ys: tf.Tensor3D;
+}
+export default function evaluate(model: tf.LayersModel, dataset: tf.data.Dataset<DataPoint>, maxEvalBatches: number): Promise<Record<'acc' | 'val_acc' | 'val_loss' | 'val_perplexity', number>>;
+export {};
package/dist/models/gpt/evaluate.js
ADDED
@@ -0,0 +1,44 @@
+import * as tf from '@tensorflow/tfjs';
+export default async function evaluate(model, dataset, maxEvalBatches) {
+    let datasetSize = 0;
+    let totalLoss = 0;
+    const acc = [0, 0];
+    await dataset.take(maxEvalBatches).map(({ xs, ys }) => {
+        const logits = model.apply(xs);
+        if (Array.isArray(logits)) {
+            throw new Error('model output too many tensor');
+        }
+        if (logits instanceof tf.SymbolicTensor) {
+            throw new Error('model output symbolic tensor');
+        }
+        xs.dispose();
+        return { logits, ys };
+    }).mapAsync(async ({ logits, ys }) => {
+        const lossTensor = tf.losses.softmaxCrossEntropy(ys, logits);
+        const loss = await lossTensor.array();
+        if (typeof loss !== 'number') {
+            throw new Error('got multiple loss');
+        }
+        const accTensor = tf.metrics.categoricalAccuracy(ys, logits);
+        const accSize = accTensor.shape.reduce((l, r) => l * r, 1);
+        const accSum = accTensor.sum();
+        const accSummed = await accSum.array();
+        if (typeof accSummed !== 'number') {
+            throw new Error('got multiple accuracy sum');
+        }
+        tf.dispose([ys, logits, accTensor, accSum, lossTensor]);
+        return { loss, accSummed, accSize };
+    }).forEachAsync(({ loss, accSummed, accSize }) => {
+        datasetSize += 1;
+        totalLoss += loss;
+        acc[0] += accSummed;
+        acc[1] += accSize;
+    });
+    const loss = totalLoss / datasetSize;
+    return {
+        val_loss: loss,
+        val_perplexity: Math.exp(loss),
+        acc: acc[0] / acc[1],
+        val_acc: acc[0] / acc[1]
+    };
+}
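
`evaluate` consumes at most `maxEvalBatches` batches, averages the per-batch softmax cross-entropy, and reports `val_perplexity` as the exponential of that mean loss. A hedged call sketch; the model, the dataset, and the deep import path are assumptions:

```ts
import * as tf from '@tensorflow/tfjs'
// Assumed import path; evaluate is the default export shown above.
import evaluate from '@epfml/discojs/dist/models/gpt/evaluate.js'

declare const model: tf.LayersModel
declare const validationSet: tf.data.Dataset<{ xs: tf.Tensor2D; ys: tf.Tensor3D }>

const metrics = await evaluate(model, validationSet, 12)
// val_perplexity === Math.exp(val_loss): a mean loss of 2.3 gives a perplexity of about 9.97
console.log(metrics.val_acc, metrics.val_loss, metrics.val_perplexity)
```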
package/dist/models/gpt/index.d.ts
ADDED
@@ -0,0 +1,37 @@
+/**
+ * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
+ **/
+import * as tf from '@tensorflow/tfjs';
+import { PreTrainedTokenizer } from '@xenova/transformers';
+import { WeightsContainer } from '../../index.js';
+import type { Dataset } from '../../dataset/index.js';
+import { Model } from '../model.js';
+import type { EpochLogs, Prediction, Sample } from '../model.js';
+import type { GPTConfig } from './config.js';
+export declare class GPT extends Model {
+    private readonly model;
+    constructor(partialConfig?: GPTConfig);
+    /**
+     * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
+     * This allows for getting logs and stopping training without callbacks.
+     *
+     * @param trainingData training dataset
+     * @param validationData validation dataset
+     * @param epochs the number of passes of the training dataset
+     * @param tracker
+     */
+    train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs, void>;
+    predict(input: Sample): Promise<Prediction>;
+    generate(input: string, tokenizer: PreTrainedTokenizer, newTokens?: number): Promise<string>;
+    get config(): Required<GPTConfig>;
+    get weights(): WeightsContainer;
+    set weights(ws: WeightsContainer);
+    static deserialize(data: GPTSerialization): Model;
+    serialize(): GPTSerialization;
+    extract(): tf.LayersModel;
+    [Symbol.dispose](): void;
+}
+export type GPTSerialization = {
+    weights: WeightsContainer;
+    config?: GPTConfig;
+};
package/dist/models/gpt/index.js
ADDED
@@ -0,0 +1,107 @@
+/**
+ * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
+ **/
+import { WeightsContainer } from '../../index.js';
+import { Model } from '../model.js';
+import { GPTForCausalLM } from './model.js';
+export class GPT extends Model {
+    model;
+    constructor(partialConfig) {
+        super();
+        this.model = new GPTForCausalLM(partialConfig);
+    }
+    /**
+     * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
+     * This allows for getting logs and stopping training without callbacks.
+     *
+     * @param trainingData training dataset
+     * @param validationData validation dataset
+     * @param epochs the number of passes of the training dataset
+     * @param tracker
+     */
+    async *train(trainingData, validationData, epochs = 1) {
+        this.model.compile();
+        let logs;
+        const trainingArgs = {
+            epochs: 1, // force fitDataset to do only one epoch because it is wrapped in a for loop
+            validationData,
+            callbacks: { onEpochEnd: (_, cur) => { logs = cur; } },
+        };
+        for (let epoch = 0; epoch < epochs; epoch++) {
+            await this.model.fitDataset(trainingData, trainingArgs);
+            if (logs === undefined) {
+                throw new Error("Epoch didn't gave any logs");
+            }
+            const { loss, val_acc, val_loss, peakMemory } = logs;
+            if (loss === undefined || isNaN(loss)) {
+                throw new Error("Training loss is undefined or nan");
+            }
+            const structuredLogs = {
+                epoch,
+                peakMemory,
+                training: {
+                    loss: logs.loss
+                }
+            };
+            if (validationData !== undefined) {
+                if (val_loss === undefined || isNaN(val_loss) ||
+                    val_acc === undefined || isNaN(val_acc)) {
+                    throw new Error("Invalid validation logs");
+                }
+                structuredLogs.validation = { accuracy: logs.val_acc, loss: logs.val_loss };
+            }
+            yield structuredLogs;
+        }
+    }
+    predict(input) {
+        const ret = this.model.predict(input);
+        if (Array.isArray(ret)) {
+            throw new Error('prediction yield many Tensors but should have only returned one');
+        }
+        return Promise.resolve(ret);
+    }
+    async generate(input, tokenizer, newTokens = 10) {
+        const { input_ids: tokens } = await tokenizer(input, { return_tensor: false });
+        const generationConfig = {
+            maxNewTokens: newTokens,
+            temperature: 1.0,
+            doSample: false
+        };
+        const predictedTokens = await this.model.generate(tokens, generationConfig);
+        const generatedWords = tokenizer.decode(predictedTokens[0]);
+        return generatedWords;
+    }
+    get config() {
+        return this.model.getGPTConfig;
+    }
+    get weights() {
+        return new WeightsContainer(this.model.weights.map((w) => w.read()));
+    }
+    set weights(ws) {
+        this.model.setWeights(ws.weights);
+    }
+    static deserialize(data) {
+        const model = new GPT(data.config);
+        model.weights = data.weights;
+        return model;
+    }
+    serialize() {
+        return {
+            weights: this.weights,
+            config: this.config
+        };
+    }
+    extract() {
+        return this.model;
+    }
+    [Symbol.dispose]() {
+        console.log("Disposing model");
+        if (this.model.optimizer !== undefined) {
+            this.model.optimizer.dispose();
+        }
+        // Some tensors are not cleaned up when model.dispose is called
+        // So we dispose them manually
+        this.model.disposeRefs();
+        this.model.dispose();
+    }
+}
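
Put together, the class is used roughly as follows. A sketch under stated assumptions: the import path of `GPT`, the dataset values, and the tokenizer checkpoint are all illustrative, not confirmed by this diff:

```ts
import { AutoTokenizer } from '@xenova/transformers'
// Assumed deep imports; the package may re-export these elsewhere.
import { GPT } from '@epfml/discojs/dist/models/gpt/index.js'
import type { Dataset } from '@epfml/discojs/dist/dataset/index.js'

declare const trainingData: Dataset
declare const validationData: Dataset

const model = new GPT({ modelType: 'gpt-nano', lr: 1e-3, blockSize: 128, vocabSize: 50258 })

// train() is an async generator: one log record per epoch, no tf.js callbacks needed.
for await (const logs of model.train(trainingData, validationData, 5)) {
  console.log(logs.epoch, logs.training.loss, logs.validation?.loss)
}

// generate() decodes greedily (doSample: false) for the requested number of new tokens.
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2') // illustrative checkpoint
console.log(await model.generate('Hello', tokenizer, 10))
```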
package/dist/models/gpt/layers.d.ts
ADDED
@@ -0,0 +1,13 @@
+import * as tf from '@tensorflow/tfjs';
+import type { GPTConfig } from './config.js';
+/**
+ * The GPTArchitecture specifically defines a GPT forward pass, i.e.,
+ * what are the inputs, the successive transformer blocks and the outputs. It is then
+ * used to create a GPTModel
+ *
+ * @param conf GPTConfig
+ * @returns model, tf.LayersModel, which supports model(inputs), model.predict and model.apply
+ */
+export declare function GPTArchitecture(config: Required<GPTConfig>, disposalRefs: tf.TensorContainer[], peakMemory: {
+    value: number;
+}): tf.LayersModel;
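
Per this declaration, callers hand the factory a fully resolved config plus two out-parameters: an array that collects tensors the eventual `model.dispose()` would miss, and a mutable peak-memory counter. A wiring sketch with the same assumed deep import paths as above:

```ts
import * as tf from '@tensorflow/tfjs'
import { DEFAULT_CONFIG } from '@epfml/discojs/dist/models/gpt/config.js'
import { GPTArchitecture } from '@epfml/discojs/dist/models/gpt/layers.js'

const disposalRefs: tf.TensorContainer[] = [] // tensors to free manually later (see disposeRefs above)
const peakMemory = { value: 0 }               // presumably updated while the model runs

const model: tf.LayersModel = GPTArchitecture(DEFAULT_CONFIG, disposalRefs, peakMemory)
// model behaves like any tf.LayersModel: model.predict(...), model.apply(...)
```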