@epfml/discojs 3.0.1-p20241025115642.0 → 3.0.1-p20241107104659.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/get.d.ts +3 -3
- package/dist/client/client.d.ts +5 -5
- package/dist/client/decentralized/decentralized_client.d.ts +2 -2
- package/dist/client/federated/federated_client.d.ts +2 -2
- package/dist/client/utils.d.ts +2 -2
- package/dist/dataset/dataset.d.ts +9 -2
- package/dist/dataset/dataset.js +83 -36
- package/dist/dataset/image.d.ts +5 -0
- package/dist/dataset/image.js +6 -1
- package/dist/dataset/index.d.ts +0 -1
- package/dist/dataset/index.js +0 -1
- package/dist/dataset/types.d.ts +2 -0
- package/dist/default_tasks/cifar10.d.ts +1 -1
- package/dist/default_tasks/cifar10.js +2 -3
- package/dist/default_tasks/lus_covid.d.ts +1 -1
- package/dist/default_tasks/lus_covid.js +2 -3
- package/dist/default_tasks/mnist.d.ts +1 -1
- package/dist/default_tasks/mnist.js +3 -5
- package/dist/default_tasks/simple_face.d.ts +1 -1
- package/dist/default_tasks/simple_face.js +2 -3
- package/dist/default_tasks/titanic.d.ts +1 -1
- package/dist/default_tasks/titanic.js +3 -6
- package/dist/default_tasks/wikitext.d.ts +1 -1
- package/dist/default_tasks/wikitext.js +1 -2
- package/dist/index.d.ts +4 -5
- package/dist/index.js +4 -5
- package/dist/models/gpt/index.d.ts +13 -16
- package/dist/models/gpt/index.js +62 -43
- package/dist/models/gpt/model.d.ts +1 -15
- package/dist/models/gpt/model.js +1 -75
- package/dist/models/model.d.ts +7 -12
- package/dist/models/tfjs.d.ts +10 -8
- package/dist/models/tfjs.js +106 -44
- package/dist/models/tokenizer.d.ts +1 -1
- package/dist/privacy.js +1 -1
- package/dist/processing/image.d.ts +18 -0
- package/dist/processing/image.js +75 -0
- package/dist/processing/index.d.ts +8 -0
- package/dist/processing/index.js +106 -0
- package/dist/processing/tabular.d.ts +19 -0
- package/dist/processing/tabular.js +33 -0
- package/dist/processing/text.d.ts +11 -0
- package/dist/processing/text.js +33 -0
- package/dist/serialization/model.d.ts +3 -3
- package/dist/serialization/model.js +19 -6
- package/dist/task/task.d.ts +4 -3
- package/dist/task/task.js +5 -3
- package/dist/task/task_handler.d.ts +3 -3
- package/dist/task/task_provider.d.ts +4 -4
- package/dist/task/training_information.d.ts +25 -16
- package/dist/task/training_information.js +76 -72
- package/dist/training/disco.d.ts +20 -12
- package/dist/training/disco.js +32 -13
- package/dist/training/trainer.d.ts +6 -7
- package/dist/training/trainer.js +6 -6
- package/dist/types/data_format.d.ts +40 -0
- package/dist/types/index.d.ts +2 -0
- package/dist/types/index.js +1 -0
- package/dist/validator.d.ts +10 -0
- package/dist/validator.js +30 -0
- package/package.json +4 -2
- package/dist/dataset/data/data.d.ts +0 -47
- package/dist/dataset/data/data.js +0 -88
- package/dist/dataset/data/data_split.d.ts +0 -8
- package/dist/dataset/data/helpers.d.ts +0 -10
- package/dist/dataset/data/helpers.js +0 -97
- package/dist/dataset/data/image_data.d.ts +0 -11
- package/dist/dataset/data/image_data.js +0 -43
- package/dist/dataset/data/index.d.ts +0 -5
- package/dist/dataset/data/index.js +0 -5
- package/dist/dataset/data/preprocessing/base.d.ts +0 -16
- package/dist/dataset/data/preprocessing/base.js +0 -1
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +0 -13
- package/dist/dataset/data/preprocessing/image_preprocessing.js +0 -42
- package/dist/dataset/data/preprocessing/index.d.ts +0 -4
- package/dist/dataset/data/preprocessing/index.js +0 -3
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +0 -13
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +0 -45
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +0 -13
- package/dist/dataset/data/preprocessing/text_preprocessing.js +0 -100
- package/dist/dataset/data/tabular_data.d.ts +0 -11
- package/dist/dataset/data/tabular_data.js +0 -24
- package/dist/dataset/data/text_data.d.ts +0 -11
- package/dist/dataset/data/text_data.js +0 -14
- package/dist/processing.d.ts +0 -35
- package/dist/processing.js +0 -89
- package/dist/types.d.ts +0 -3
- package/dist/types.js +0 -1
- package/dist/validation/index.d.ts +0 -1
- package/dist/validation/index.js +0 -1
- package/dist/validation/validator.d.ts +0 -10
- package/dist/validation/validator.js +0 -113
- /package/dist/{dataset/data/data_split.js → types/data_format.js} +0 -0
package/dist/models/gpt/index.js
CHANGED
|
@@ -2,21 +2,27 @@
|
|
|
2
2
|
* this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
|
|
3
3
|
**/
|
|
4
4
|
import createDebug from "debug";
|
|
5
|
-
import { List } from
|
|
6
|
-
import * as tf from
|
|
7
|
-
import { WeightsContainer } from
|
|
5
|
+
import { List, Range } from "immutable";
|
|
6
|
+
import * as tf from "@tensorflow/tfjs";
|
|
7
|
+
import { WeightsContainer } from "../../index.js";
|
|
8
8
|
import { Model, EpochLogs } from "../index.js";
|
|
9
|
-
import {
|
|
10
|
-
import { DEFAULT_CONFIG } from
|
|
11
|
-
import evaluate from
|
|
9
|
+
import { GPTModel } from "./model.js";
|
|
10
|
+
import { DEFAULT_CONFIG } from "./config.js";
|
|
11
|
+
import evaluate from "./evaluate.js";
|
|
12
12
|
const debug = createDebug("discojs:models:gpt");
|
|
13
13
|
export class GPT extends Model {
|
|
14
14
|
model;
|
|
15
|
+
#blockSize;
|
|
15
16
|
#maxBatchCount;
|
|
17
|
+
#vocabSize;
|
|
16
18
|
constructor(partialConfig, layersModel) {
|
|
17
19
|
super();
|
|
18
|
-
|
|
20
|
+
const model = new GPTModel(partialConfig, layersModel);
|
|
21
|
+
model.compile();
|
|
22
|
+
this.model = model;
|
|
23
|
+
this.#blockSize = partialConfig?.blockSize ?? DEFAULT_CONFIG.blockSize;
|
|
19
24
|
this.#maxBatchCount = partialConfig?.maxIter ?? DEFAULT_CONFIG.maxIter;
|
|
25
|
+
this.#vocabSize = partialConfig?.vocabSize ?? DEFAULT_CONFIG.vocabSize;
|
|
20
26
|
}
|
|
21
27
|
/**
|
|
22
28
|
* The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
|
|
@@ -27,26 +33,20 @@ export class GPT extends Model {
|
|
|
27
33
|
* @param epochs the number of passes of the training dataset
|
|
28
34
|
* @param tracker
|
|
29
35
|
*/
|
|
30
|
-
async *train(
|
|
31
|
-
this.model.compile();
|
|
32
|
-
const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
|
|
36
|
+
async *train(trainingDataset, validationDataset) {
|
|
33
37
|
let batchesLogs = List();
|
|
34
|
-
for (
|
|
35
|
-
const iteration = await batches.next();
|
|
36
|
-
if (iteration.done)
|
|
37
|
-
break;
|
|
38
|
-
const batch = iteration.value;
|
|
38
|
+
for await (const [batch, _] of trainingDataset.zip(Range(0, this.#maxBatchCount))) {
|
|
39
39
|
const batchLogs = await this.#runBatch(batch);
|
|
40
|
-
tf.dispose(batch);
|
|
41
40
|
yield batchLogs;
|
|
42
41
|
batchesLogs = batchesLogs.push(batchLogs);
|
|
43
42
|
}
|
|
44
|
-
const validation =
|
|
43
|
+
const validation = validationDataset && (await this.#evaluate(validationDataset));
|
|
45
44
|
return new EpochLogs(batchesLogs, validation);
|
|
46
45
|
}
|
|
47
46
|
async #runBatch(batch) {
|
|
47
|
+
const tfBatch = this.#batchToTF(batch);
|
|
48
48
|
let logs;
|
|
49
|
-
await this.model.fitDataset(tf.data.array([
|
|
49
|
+
await this.model.fitDataset(tf.data.array([tfBatch]), {
|
|
50
50
|
epochs: 1,
|
|
51
51
|
verbose: 0, // don't pollute
|
|
52
52
|
callbacks: {
|
|
@@ -55,6 +55,7 @@ export class GPT extends Model {
|
|
|
55
55
|
},
|
|
56
56
|
},
|
|
57
57
|
});
|
|
58
|
+
tf.dispose(tfBatch);
|
|
58
59
|
if (logs === undefined)
|
|
59
60
|
throw new Error("batch didn't gave any logs");
|
|
60
61
|
const { loss, acc: accuracy } = logs;
|
|
@@ -67,38 +68,56 @@ export class GPT extends Model {
|
|
|
67
68
|
};
|
|
68
69
|
}
|
|
69
70
|
async #evaluate(dataset) {
|
|
70
|
-
const evaluation = await evaluate(this.model,
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
case undefined:
|
|
74
|
-
throw new Error("nullish value in dataset");
|
|
75
|
-
default:
|
|
76
|
-
// TODO unsafe cast
|
|
77
|
-
return t;
|
|
78
|
-
}
|
|
79
|
-
}), this.config.maxEvalBatches);
|
|
71
|
+
const evaluation = await evaluate(this.model, tf.data.generator(async function* () {
|
|
72
|
+
yield* dataset.map((batch) => this.#batchToTF(batch));
|
|
73
|
+
}.bind(this)), this.config.maxEvalBatches);
|
|
80
74
|
return {
|
|
81
75
|
accuracy: evaluation.val_acc,
|
|
82
76
|
loss: evaluation.val_loss,
|
|
83
77
|
};
|
|
84
78
|
}
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
79
|
+
#batchToTF(batch) {
|
|
80
|
+
return tf.tidy(() => ({
|
|
81
|
+
xs: tf.stack(batch.map(([line]) => tf.tensor1d(line.toArray(), "int32")).toArray()), // cast as stack doesn't type
|
|
82
|
+
ys: tf.stack(batch
|
|
83
|
+
.map(([line, next]) => tf.oneHot(line.shift().push(next).toArray(), this.#vocabSize))
|
|
84
|
+
.toArray()), // cast as oneHot/stack doesn't type
|
|
85
|
+
}));
|
|
91
86
|
}
|
|
92
|
-
async
|
|
93
|
-
const
|
|
94
|
-
const generationConfig = {
|
|
95
|
-
maxNewTokens: newTokens,
|
|
87
|
+
async predict(batch, options) {
|
|
88
|
+
const config = {
|
|
96
89
|
temperature: 1.0,
|
|
97
|
-
doSample: false
|
|
90
|
+
doSample: false,
|
|
91
|
+
...options,
|
|
98
92
|
};
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
93
|
+
return List(await Promise.all(batch.map((tokens) => this.#predictSingle(tokens, config))));
|
|
94
|
+
}
|
|
95
|
+
async #predictSingle(tokens, config) {
|
|
96
|
+
// slice input tokens if longer than context length
|
|
97
|
+
tokens = tokens.slice(-this.#blockSize);
|
|
98
|
+
const input = tf.tidy(() => tf.tensor1d(tokens.toArray(), "int32").expandDims(0));
|
|
99
|
+
const logits = tf.tidy(() => {
|
|
100
|
+
const output = this.model.predict(input);
|
|
101
|
+
if (Array.isArray(output))
|
|
102
|
+
throw new Error("The model outputs too multiple values");
|
|
103
|
+
if (output.rank !== 3)
|
|
104
|
+
throw new Error("The model outputs wrong shape");
|
|
105
|
+
return output.squeeze([0]);
|
|
106
|
+
});
|
|
107
|
+
input.dispose();
|
|
108
|
+
const probs = tf.tidy(() => logits
|
|
109
|
+
.slice([logits.shape[0] - 1])
|
|
110
|
+
.squeeze([0])
|
|
111
|
+
.div(config.temperature)
|
|
112
|
+
.softmax());
|
|
113
|
+
logits.dispose();
|
|
114
|
+
const next = tf.tidy(() => config.doSample
|
|
115
|
+
? tf.multinomial(probs, 1).squeeze([0])
|
|
116
|
+
: probs.argMax());
|
|
117
|
+
probs.dispose();
|
|
118
|
+
const ret = await next.array();
|
|
119
|
+
next.dispose();
|
|
120
|
+
return ret;
|
|
102
121
|
}
|
|
103
122
|
get config() {
|
|
104
123
|
return this.model.getGPTConfig;
|
|
@@ -117,7 +136,7 @@ export class GPT extends Model {
|
|
|
117
136
|
serialize() {
|
|
118
137
|
return {
|
|
119
138
|
weights: this.weights,
|
|
120
|
-
config: this.config
|
|
139
|
+
config: this.config,
|
|
121
140
|
};
|
|
122
141
|
}
|
|
123
142
|
extract() {
|
|
@@ -14,25 +14,11 @@ export declare abstract class Dataset<T> {
|
|
|
14
14
|
* GPTModel extends tf.LayersModel and overrides tfjs' default training loop
|
|
15
15
|
*
|
|
16
16
|
*/
|
|
17
|
-
declare class GPTModel extends tf.LayersModel {
|
|
17
|
+
export declare class GPTModel extends tf.LayersModel {
|
|
18
18
|
protected readonly config: Required<GPTConfig>;
|
|
19
19
|
constructor(partialConfig?: GPTConfig, layersModel?: tf.LayersModel);
|
|
20
20
|
get getGPTConfig(): Required<GPTConfig>;
|
|
21
21
|
compile(): void;
|
|
22
22
|
fitDataset<T>(dataset: Dataset<T>, trainingArgs: tf.ModelFitDatasetArgs<T>): Promise<tf.History>;
|
|
23
23
|
}
|
|
24
|
-
interface GenerateConfig {
|
|
25
|
-
maxNewTokens: number;
|
|
26
|
-
temperature: number;
|
|
27
|
-
doSample: boolean;
|
|
28
|
-
}
|
|
29
|
-
/**
|
|
30
|
-
* GPTForCausalLM stands for GPT model for Causal Language Modeling. Causal because it only looks at past tokens and not future ones
|
|
31
|
-
* This class extends GPTModel and adds supports for text generation
|
|
32
|
-
*
|
|
33
|
-
*/
|
|
34
|
-
export declare class GPTForCausalLM extends GPTModel {
|
|
35
|
-
generate(idxRaw: tf.TensorLike, conf: GenerateConfig): Promise<number[][]>;
|
|
36
|
-
private generateOnce;
|
|
37
|
-
}
|
|
38
24
|
export {};
|
package/dist/models/gpt/model.js
CHANGED
|
@@ -9,7 +9,7 @@ const debug = createDebug("discojs:models:gpt");
|
|
|
9
9
|
* GPTModel extends tf.LayersModel and overrides tfjs' default training loop
|
|
10
10
|
*
|
|
11
11
|
*/
|
|
12
|
-
class GPTModel extends tf.LayersModel {
|
|
12
|
+
export class GPTModel extends tf.LayersModel {
|
|
13
13
|
config;
|
|
14
14
|
constructor(partialConfig, layersModel) {
|
|
15
15
|
// Fill missing config parameters with default values
|
|
@@ -124,77 +124,3 @@ class GPTModel extends tf.LayersModel {
|
|
|
124
124
|
return new tf.History();
|
|
125
125
|
}
|
|
126
126
|
}
|
|
127
|
-
const defaultGenerateConfig = {
|
|
128
|
-
maxNewTokens: 20,
|
|
129
|
-
temperature: 1.0,
|
|
130
|
-
doSample: false
|
|
131
|
-
};
|
|
132
|
-
function prepareIdx(idx) {
|
|
133
|
-
return tf.tidy(() => {
|
|
134
|
-
let ret;
|
|
135
|
-
if (idx instanceof tf.Tensor) {
|
|
136
|
-
ret = idx.clone();
|
|
137
|
-
}
|
|
138
|
-
else {
|
|
139
|
-
ret = tf.tensor(idx);
|
|
140
|
-
}
|
|
141
|
-
if (ret.dtype !== 'int32') {
|
|
142
|
-
ret = ret.toInt();
|
|
143
|
-
}
|
|
144
|
-
switch (ret.shape.length) {
|
|
145
|
-
case 1:
|
|
146
|
-
return ret.expandDims(0);
|
|
147
|
-
case 2:
|
|
148
|
-
return ret;
|
|
149
|
-
default:
|
|
150
|
-
throw new Error('unexpected shape');
|
|
151
|
-
}
|
|
152
|
-
});
|
|
153
|
-
}
|
|
154
|
-
/**
|
|
155
|
-
* GPTForCausalLM stands for GPT model for Causal Language Modeling. Causal because it only looks at past tokens and not future ones
|
|
156
|
-
* This class extends GPTModel and adds supports for text generation
|
|
157
|
-
*
|
|
158
|
-
*/
|
|
159
|
-
export class GPTForCausalLM extends GPTModel {
|
|
160
|
-
async generate(idxRaw, conf) {
|
|
161
|
-
const config = Object.assign({}, defaultGenerateConfig, conf);
|
|
162
|
-
let idx = prepareIdx(idxRaw);
|
|
163
|
-
for (let step = 0; step < config.maxNewTokens; step++) {
|
|
164
|
-
const idxNext = this.generateOnce(this, idx, config);
|
|
165
|
-
const idxNew = idx.concat(idxNext, 1);
|
|
166
|
-
tf.dispose(idx);
|
|
167
|
-
idx = idxNew;
|
|
168
|
-
tf.dispose(idxNext);
|
|
169
|
-
}
|
|
170
|
-
const idxArr = await idx.array();
|
|
171
|
-
tf.dispose(idx);
|
|
172
|
-
return idxArr;
|
|
173
|
-
}
|
|
174
|
-
generateOnce(model, idx, config) {
|
|
175
|
-
const idxNext = tf.tidy(() => {
|
|
176
|
-
// slice input tokens if longer than context length
|
|
177
|
-
const blockSize = this.config.blockSize;
|
|
178
|
-
idx = idx.shape[1] <= blockSize
|
|
179
|
-
? idx : idx.slice([0, idx.shape[1] - blockSize]);
|
|
180
|
-
const output = model.predict(idx);
|
|
181
|
-
if (Array.isArray(output))
|
|
182
|
-
throw new Error('The model outputs too multiple values');
|
|
183
|
-
if (output.shape.length !== 3)
|
|
184
|
-
throw new Error('The model outputs wrong shape');
|
|
185
|
-
const logits = output;
|
|
186
|
-
const logitsScaled = logits
|
|
187
|
-
.slice([0, idx.shape[1] - 1, 0])
|
|
188
|
-
.reshape([logits.shape[0], logits.shape[2]])
|
|
189
|
-
.div(tf.scalar(config.temperature));
|
|
190
|
-
const probs = logitsScaled.softmax(-1);
|
|
191
|
-
if (config.doSample) {
|
|
192
|
-
return tf.multinomial(probs, 1);
|
|
193
|
-
}
|
|
194
|
-
else {
|
|
195
|
-
return probs.argMax(-1).expandDims(1);
|
|
196
|
-
}
|
|
197
|
-
});
|
|
198
|
-
return idxNext;
|
|
199
|
-
}
|
|
200
|
-
}
|
package/dist/models/model.d.ts
CHANGED
|
@@ -1,14 +1,11 @@
|
|
|
1
|
-
import type
|
|
2
|
-
import type { WeightsContainer } from "../index.js";
|
|
1
|
+
import type { Batched, Dataset, DataFormat, DataType, WeightsContainer } from "../index.js";
|
|
3
2
|
import type { BatchLogs, EpochLogs } from "./logs.js";
|
|
4
|
-
export type Prediction = tf.Tensor;
|
|
5
|
-
export type Sample = tf.Tensor;
|
|
6
3
|
/**
|
|
7
4
|
* Trainable predictor
|
|
8
5
|
*
|
|
9
6
|
* Allow for various implementation of models (various train function, tensor-library, ...)
|
|
10
7
|
**/
|
|
11
|
-
export declare abstract class Model implements Disposable {
|
|
8
|
+
export declare abstract class Model<D extends DataType> implements Disposable {
|
|
12
9
|
/** Return training state */
|
|
13
10
|
abstract get weights(): WeightsContainer;
|
|
14
11
|
/** Set training state */
|
|
@@ -16,15 +13,13 @@ export declare abstract class Model implements Disposable {
|
|
|
16
13
|
/**
|
|
17
14
|
* Improve predictor
|
|
18
15
|
*
|
|
19
|
-
* @param
|
|
20
|
-
* @param
|
|
21
|
-
* @
|
|
22
|
-
* @param tracker watch the various steps
|
|
23
|
-
* @yields on every epoch, training can be stop by `return`ing it
|
|
16
|
+
* @param trainingDataset dataset to optimize for
|
|
17
|
+
* @param validationDataset dataset to measure how well it is training
|
|
18
|
+
* @yields on every epoch, training can be stop by `return`ing or `throw`ing it
|
|
24
19
|
*/
|
|
25
|
-
abstract train(
|
|
20
|
+
abstract train(trainingDataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded[D]>>): AsyncGenerator<BatchLogs, EpochLogs>;
|
|
26
21
|
/** Predict likely values */
|
|
27
|
-
abstract predict(
|
|
22
|
+
abstract predict(batch: Batched<DataFormat.ModelEncoded[D][0]>): Promise<Batched<DataFormat.ModelEncoded[D][1]>>;
|
|
28
23
|
/**
|
|
29
24
|
* This method is automatically called to cleanup the memory occupied by the model
|
|
30
25
|
* when leaving the definition scope if the instance has been defined with the `using` keyword.
|
package/dist/models/tfjs.d.ts
CHANGED
|
@@ -1,21 +1,22 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { WeightsContainer } from
|
|
2
|
+
import { Batched, Dataset, DataFormat, DataType, WeightsContainer } from "../index.js";
|
|
3
3
|
import { BatchLogs } from './index.js';
|
|
4
4
|
import { Model } from './index.js';
|
|
5
|
-
import { Prediction, Sample } from './model.js';
|
|
6
5
|
import { EpochLogs } from './logs.js';
|
|
6
|
+
type Serialized<D extends DataType> = [D, tf.io.ModelArtifacts];
|
|
7
7
|
/** TensorFlow JavaScript model with standard training */
|
|
8
|
-
export declare class TFJS extends Model {
|
|
8
|
+
export declare class TFJS<D extends "image" | "tabular"> extends Model<D> {
|
|
9
9
|
#private;
|
|
10
|
+
readonly datatype: D;
|
|
10
11
|
private readonly model;
|
|
11
12
|
/** Wrap the given trainable model */
|
|
12
|
-
constructor(model: tf.LayersModel);
|
|
13
|
+
constructor(datatype: D, model: tf.LayersModel);
|
|
13
14
|
get weights(): WeightsContainer;
|
|
14
15
|
set weights(ws: WeightsContainer);
|
|
15
|
-
train(
|
|
16
|
-
predict(
|
|
17
|
-
static deserialize(
|
|
18
|
-
serialize(): Promise<
|
|
16
|
+
train(trainingDataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded[D]>>): AsyncGenerator<BatchLogs, EpochLogs>;
|
|
17
|
+
predict(batch: Batched<DataFormat.ModelEncoded[D][0]>): Promise<Batched<DataFormat.ModelEncoded[D][1]>>;
|
|
18
|
+
static deserialize<D extends "image" | "tabular">([datatype, artifacts,]: Serialized<D>): Promise<TFJS<D>>;
|
|
19
|
+
serialize(): Promise<Serialized<D>>;
|
|
19
20
|
[Symbol.dispose](): void;
|
|
20
21
|
/**
|
|
21
22
|
* extract wrapped model
|
|
@@ -24,3 +25,4 @@ export declare class TFJS extends Model {
|
|
|
24
25
|
*/
|
|
25
26
|
extract(): tf.LayersModel;
|
|
26
27
|
}
|
|
28
|
+
export {};
|
package/dist/models/tfjs.js
CHANGED
|
@@ -1,18 +1,22 @@
|
|
|
1
|
-
import { List, Map } from
|
|
1
|
+
import { List, Map, Range } from "immutable";
|
|
2
2
|
import * as tf from '@tensorflow/tfjs';
|
|
3
|
-
import { WeightsContainer } from
|
|
3
|
+
import { WeightsContainer, } from "../index.js";
|
|
4
4
|
import { Model } from './index.js';
|
|
5
5
|
import { EpochLogs } from './logs.js';
|
|
6
6
|
/** TensorFlow JavaScript model with standard training */
|
|
7
7
|
export class TFJS extends Model {
|
|
8
|
+
datatype;
|
|
8
9
|
model;
|
|
9
10
|
/** Wrap the given trainable model */
|
|
10
|
-
constructor(model) {
|
|
11
|
+
constructor(datatype, model) {
|
|
11
12
|
super();
|
|
13
|
+
this.datatype = datatype;
|
|
12
14
|
this.model = model;
|
|
13
15
|
if (model.loss === undefined) {
|
|
14
16
|
throw new Error('TFJS models need to be compiled to be used');
|
|
15
17
|
}
|
|
18
|
+
if (model.outputs.length !== 1)
|
|
19
|
+
throw new Error("only support single output model");
|
|
16
20
|
}
|
|
17
21
|
get weights() {
|
|
18
22
|
return new WeightsContainer(this.model.weights.map((w) => w.read()));
|
|
@@ -20,57 +24,43 @@ export class TFJS extends Model {
|
|
|
20
24
|
set weights(ws) {
|
|
21
25
|
this.model.setWeights(ws.weights);
|
|
22
26
|
}
|
|
23
|
-
async *train(
|
|
24
|
-
const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
|
|
27
|
+
async *train(trainingDataset, validationDataset) {
|
|
25
28
|
let batchesLogs = List();
|
|
26
|
-
for (
|
|
27
|
-
const iteration = await batches.next();
|
|
28
|
-
if (iteration.done)
|
|
29
|
-
break;
|
|
30
|
-
const batch = iteration.value;
|
|
29
|
+
for await (const [batch, batchNumber] of trainingDataset.zip(Range())) {
|
|
31
30
|
const batchLogs = {
|
|
32
31
|
batch: batchNumber,
|
|
33
32
|
...(await this.#runBatch(batch)),
|
|
34
33
|
};
|
|
35
|
-
tf.dispose(batch);
|
|
36
34
|
yield batchLogs;
|
|
37
35
|
batchesLogs = batchesLogs.push(batchLogs);
|
|
38
36
|
}
|
|
39
|
-
const validation =
|
|
37
|
+
const validation = validationDataset && (await this.#evaluate(validationDataset));
|
|
40
38
|
return new EpochLogs(batchesLogs, validation);
|
|
41
39
|
}
|
|
42
40
|
async #runBatch(batch) {
|
|
43
|
-
|
|
44
|
-
await this.model.
|
|
41
|
+
const { xs, ys } = this.#batchToTF(batch);
|
|
42
|
+
const { history } = await this.model.fit(xs, ys, {
|
|
45
43
|
epochs: 1,
|
|
46
44
|
verbose: 0, // don't pollute
|
|
47
|
-
callbacks: {
|
|
48
|
-
onEpochEnd: (_, cur) => {
|
|
49
|
-
logs = cur;
|
|
50
|
-
},
|
|
51
|
-
},
|
|
52
45
|
});
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
46
|
+
const { loss: losses, acc: accuracies } = history;
|
|
47
|
+
if (losses === undefined ||
|
|
48
|
+
accuracies === undefined ||
|
|
49
|
+
typeof losses[0] !== "number" ||
|
|
50
|
+
typeof accuracies[0] !== "number" ||
|
|
51
|
+
isNaN(losses[0]) ||
|
|
52
|
+
isNaN(accuracies[0]))
|
|
53
|
+
throw new Error("training loss or accuracy is undefined or NaN");
|
|
58
54
|
return {
|
|
59
|
-
accuracy,
|
|
60
|
-
loss,
|
|
55
|
+
accuracy: accuracies[0],
|
|
56
|
+
loss: losses[0],
|
|
61
57
|
memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
|
|
62
58
|
};
|
|
63
59
|
}
|
|
64
60
|
async #evaluate(dataset) {
|
|
65
|
-
const evaluation = await this.model.evaluateDataset(
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
case undefined:
|
|
69
|
-
throw new Error("nullish value in dataset");
|
|
70
|
-
default:
|
|
71
|
-
return t;
|
|
72
|
-
}
|
|
73
|
-
}));
|
|
61
|
+
const evaluation = await this.model.evaluateDataset(tf.data.generator(async function* () {
|
|
62
|
+
yield* dataset.map((batch) => this.#batchToTF(batch));
|
|
63
|
+
}.bind(this)));
|
|
74
64
|
const metricToValue = Map(List(this.model.metricsNames).zip(Array.isArray(evaluation)
|
|
75
65
|
? List(await Promise.all(evaluation.map((t) => t.data())))
|
|
76
66
|
: List.of(await evaluation.data()))).map((values) => {
|
|
@@ -87,16 +77,39 @@ export class TFJS extends Model {
|
|
|
87
77
|
throw new Error("some needed metrics are missing");
|
|
88
78
|
return { accuracy, loss };
|
|
89
79
|
}
|
|
90
|
-
predict(
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
80
|
+
async predict(batch) {
|
|
81
|
+
async function cleanupPredicted(y) {
|
|
82
|
+
if (y.shape[0] === 1) {
|
|
83
|
+
// Binary classification
|
|
84
|
+
const threshold = tf.scalar(0.5);
|
|
85
|
+
const binaryTensor = y.greaterEqual(threshold);
|
|
86
|
+
const binaryArray = await binaryTensor.data();
|
|
87
|
+
tf.dispose([y, binaryTensor, threshold]);
|
|
88
|
+
return binaryArray[0];
|
|
89
|
+
}
|
|
90
|
+
// Multi-class classification
|
|
91
|
+
const indexTensor = y.argMax();
|
|
92
|
+
const indexArray = await indexTensor.data();
|
|
93
|
+
tf.dispose([y, indexTensor]);
|
|
94
|
+
return indexArray[0];
|
|
95
|
+
// Multi-label classification is not supported
|
|
94
96
|
}
|
|
95
|
-
|
|
97
|
+
const xs = this.#batchWithoutLabelToTF(batch);
|
|
98
|
+
const prediction = this.model.predict(xs);
|
|
99
|
+
if (Array.isArray(prediction))
|
|
100
|
+
throw new Error("prediction yield many Tensors but should have only returned one");
|
|
101
|
+
tf.dispose(xs);
|
|
102
|
+
if (prediction.rank !== 2)
|
|
103
|
+
throw new Error("unexpected batched prediction shape");
|
|
104
|
+
const ret = List(await Promise.all(tf.unstack(prediction).map((y) => cleanupPredicted(
|
|
105
|
+
// cast as unstack reduce by one the rank
|
|
106
|
+
y))));
|
|
107
|
+
prediction.dispose();
|
|
108
|
+
return ret;
|
|
96
109
|
}
|
|
97
|
-
static async deserialize(
|
|
98
|
-
return new this(await tf.loadLayersModel({
|
|
99
|
-
load: () => Promise.resolve(
|
|
110
|
+
static async deserialize([datatype, artifacts,]) {
|
|
111
|
+
return new this(datatype, await tf.loadLayersModel({
|
|
112
|
+
load: () => Promise.resolve(artifacts),
|
|
100
113
|
}));
|
|
101
114
|
}
|
|
102
115
|
async serialize() {
|
|
@@ -115,7 +128,7 @@ export class TFJS extends Model {
|
|
|
115
128
|
}, {
|
|
116
129
|
includeOptimizer: true // keep model compiled
|
|
117
130
|
});
|
|
118
|
-
return await ret;
|
|
131
|
+
return [this.datatype, await ret];
|
|
119
132
|
}
|
|
120
133
|
[Symbol.dispose]() {
|
|
121
134
|
this.model.dispose();
|
|
@@ -128,4 +141,53 @@ export class TFJS extends Model {
|
|
|
128
141
|
extract() {
|
|
129
142
|
return this.model;
|
|
130
143
|
}
|
|
144
|
+
#batchToTF(batch) {
|
|
145
|
+
const outputSize = tf.util.sizeFromShape(this.model.outputShape.map((dim) => {
|
|
146
|
+
if (Array.isArray(dim))
|
|
147
|
+
throw new Error("TODO support multiple outputs");
|
|
148
|
+
return dim ?? 1;
|
|
149
|
+
}));
|
|
150
|
+
switch (this.datatype) {
|
|
151
|
+
case "image": {
|
|
152
|
+
// cast as typescript doesn't reduce generic type
|
|
153
|
+
const b = batch;
|
|
154
|
+
return tf.tidy(() => ({
|
|
155
|
+
xs: tf.stack(b
|
|
156
|
+
.map(([image]) => tf.tensor3d(image.data, [image.width, image.height, 3], "float32"))
|
|
157
|
+
.toArray()),
|
|
158
|
+
ys: tf.stack(b
|
|
159
|
+
.map(([_, label]) => tf.oneHot(label, outputSize, 1, 0, "int32"))
|
|
160
|
+
.toArray()),
|
|
161
|
+
}));
|
|
162
|
+
}
|
|
163
|
+
case "tabular": {
|
|
164
|
+
// cast as typescript doesn't reduce generic type
|
|
165
|
+
const b = batch;
|
|
166
|
+
return tf.tidy(() => ({
|
|
167
|
+
xs: tf.stack(b.map(([inputs, _]) => tf.tensor1d(inputs.toArray())).toArray()),
|
|
168
|
+
ys: tf.stack(b.map(([_, output]) => tf.tensor1d([output])).toArray()),
|
|
169
|
+
}));
|
|
170
|
+
}
|
|
171
|
+
}
|
|
172
|
+
const _ = this.datatype;
|
|
173
|
+
throw new Error("should never happen");
|
|
174
|
+
}
|
|
175
|
+
#batchWithoutLabelToTF(batch) {
|
|
176
|
+
switch (this.datatype) {
|
|
177
|
+
case "image": {
|
|
178
|
+
// cast as typescript doesn't reduce generic type
|
|
179
|
+
const b = batch;
|
|
180
|
+
return tf.tidy(() => tf.stack(b
|
|
181
|
+
.map((image) => tf.tensor3d(image.data, [image.width, image.height, 3], "float32"))
|
|
182
|
+
.toArray()));
|
|
183
|
+
}
|
|
184
|
+
case "tabular": {
|
|
185
|
+
// cast as typescript doesn't reduce generic type
|
|
186
|
+
const b = batch;
|
|
187
|
+
return tf.tidy(() => tf.stack(b.map((inputs) => tf.tensor1d(inputs.toArray())).toArray()));
|
|
188
|
+
}
|
|
189
|
+
}
|
|
190
|
+
const _ = this.datatype;
|
|
191
|
+
throw new Error("should never happen");
|
|
192
|
+
}
|
|
131
193
|
}
|
|
@@ -11,4 +11,4 @@ import { PreTrainedTokenizer } from '@xenova/transformers';
|
|
|
11
11
|
* @param task the task object specifying which tokenizer to use
|
|
12
12
|
* @returns an initialized tokenizer object
|
|
13
13
|
*/
|
|
14
|
-
export declare function getTaskTokenizer(task: Task): Promise<PreTrainedTokenizer>;
|
|
14
|
+
export declare function getTaskTokenizer(task: Task<'text'>): Promise<PreTrainedTokenizer>;
|
package/dist/privacy.js
CHANGED
|
@@ -5,7 +5,7 @@ async function frobeniusNorm(weights) {
|
|
|
5
5
|
.reduce((a, b) => a.add(b))
|
|
6
6
|
.data();
|
|
7
7
|
if (squared.length !== 1)
|
|
8
|
-
throw new Error("
|
|
8
|
+
throw new Error("unexpected weights shape");
|
|
9
9
|
return Math.sqrt(squared[0]);
|
|
10
10
|
}
|
|
11
11
|
/** Scramble weights */
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import { Image } from "../index.js";
|
|
2
|
+
/** Image where intensity is represented in the range 0..1 */
|
|
3
|
+
export declare class NormalizedImage<D extends 1 | 3 | 4 = 1 | 3 | 4, W extends number = number, H extends number = number> {
|
|
4
|
+
readonly data: Readonly<Float32Array>;
|
|
5
|
+
readonly width: W;
|
|
6
|
+
readonly height: H;
|
|
7
|
+
readonly depth: D;
|
|
8
|
+
private constructor();
|
|
9
|
+
static from<D extends 1 | 3 | 4 = 1 | 3 | 4, W extends number = number, H extends number = number>(image: Image<D, W, H>): NormalizedImage<D, W, H>;
|
|
10
|
+
}
|
|
11
|
+
/** Remove the alpha channel of an image */
|
|
12
|
+
export declare function removeAlpha<W extends number, H extends number>(image: Image<4, W, H>): Image<3, W, H>;
|
|
13
|
+
export declare function removeAlpha<D extends 1 | 3, W extends number, H extends number>(image: Image<D | 4, W, H>): Image<D, W, H>;
|
|
14
|
+
/** Convert monochrome images to multicolor */
|
|
15
|
+
export declare function expandToMulticolor<W extends number, H extends number>(image: Image<1, W, H>): Image<3, W, H>;
|
|
16
|
+
export declare function expandToMulticolor<D extends 3 | 4, W extends number, H extends number>(image: Image<1 | D, W, H>): Image<D, W, H>;
|
|
17
|
+
export declare function resize<D extends 1 | 3 | 4, W extends number, H extends number>(width: W, height: H, image: Image<D, number, number>): Image<4, W, H>;
|
|
18
|
+
export declare function normalize<D extends 1 | 3 | 4, W extends number, H extends number>(image: Image<D, W, H>): NormalizedImage<D, W, H>;
|