@epfml/discojs 3.0.1-p20240821133014.0 → 3.0.1-p20240826092658.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/dataset/data/data.d.ts +6 -7
- package/dist/dataset/data/data.js +12 -7
- package/dist/dataset/data/helpers.d.ts +10 -0
- package/dist/dataset/data/helpers.js +97 -0
- package/dist/dataset/data/image_data.d.ts +3 -3
- package/dist/dataset/data/image_data.js +7 -2
- package/dist/dataset/data/index.d.ts +0 -1
- package/dist/dataset/data/preprocessing/text_preprocessing.js +23 -9
- package/dist/dataset/data/tabular_data.d.ts +3 -3
- package/dist/dataset/data/text_data.d.ts +3 -3
- package/dist/dataset/dataset.d.ts +48 -5
- package/dist/dataset/dataset.js +155 -1
- package/dist/dataset/image.d.ts +14 -0
- package/dist/dataset/image.js +21 -0
- package/dist/dataset/index.d.ts +3 -5
- package/dist/dataset/index.js +3 -3
- package/dist/dataset/types.d.ts +4 -0
- package/dist/dataset/types.js +2 -0
- package/dist/index.d.ts +4 -0
- package/dist/index.js +4 -0
- package/dist/models/gpt/model.js +2 -0
- package/dist/models/model.d.ts +1 -2
- package/dist/models/tfjs.d.ts +4 -4
- package/dist/models/tfjs.js +2 -1
- package/dist/processing.d.ts +35 -0
- package/dist/processing.js +89 -0
- package/dist/training/disco.d.ts +7 -7
- package/dist/training/disco.js +21 -19
- package/dist/types.d.ts +3 -0
- package/dist/types.js +1 -0
- package/dist/validation/validator.d.ts +7 -23
- package/dist/validation/validator.js +99 -105
- package/package.json +1 -1
- package/dist/dataset/data_loader/data_loader.d.ts +0 -13
- package/dist/dataset/data_loader/data_loader.js +0 -2
- package/dist/dataset/data_loader/image_loader.d.ts +0 -21
- package/dist/dataset/data_loader/image_loader.js +0 -101
- package/dist/dataset/data_loader/index.d.ts +0 -5
- package/dist/dataset/data_loader/index.js +0 -4
- package/dist/dataset/data_loader/tabular_loader.d.ts +0 -35
- package/dist/dataset/data_loader/tabular_loader.js +0 -76
- package/dist/dataset/data_loader/text_loader.d.ts +0 -14
- package/dist/dataset/data_loader/text_loader.js +0 -25
- package/dist/dataset/dataset_builder.d.ts +0 -51
- package/dist/dataset/dataset_builder.js +0 -118
package/dist/models/tfjs.d.ts
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
2
|
import { WeightsContainer } from '../index.js';
|
|
3
|
-
import
|
|
4
|
-
import { BatchLogs, EpochLogs } from './index.js';
|
|
3
|
+
import { BatchLogs } from './index.js';
|
|
5
4
|
import { Model } from './index.js';
|
|
6
|
-
import
|
|
5
|
+
import { Prediction, Sample } from './model.js';
|
|
6
|
+
import { EpochLogs } from './logs.js';
|
|
7
7
|
/** TensorFlow JavaScript model with standard training */
|
|
8
8
|
export declare class TFJS extends Model {
|
|
9
9
|
#private;
|
|
@@ -12,7 +12,7 @@ export declare class TFJS extends Model {
|
|
|
12
12
|
constructor(model: tf.LayersModel);
|
|
13
13
|
get weights(): WeightsContainer;
|
|
14
14
|
set weights(ws: WeightsContainer);
|
|
15
|
-
train(trainingData: Dataset
|
|
15
|
+
train(trainingData: tf.data.Dataset<tf.TensorContainer>, validationData?: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<BatchLogs, EpochLogs>;
|
|
16
16
|
predict(input: Sample): Promise<Prediction>;
|
|
17
17
|
static deserialize(raw: tf.io.ModelArtifacts): Promise<Model>;
|
|
18
18
|
serialize(): Promise<tf.io.ModelArtifacts>;
|
package/dist/models/tfjs.js
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import { List, Map } from 'immutable';
|
|
2
2
|
import * as tf from '@tensorflow/tfjs';
|
|
3
3
|
import { WeightsContainer } from '../index.js';
|
|
4
|
-
import { EpochLogs } from './index.js';
|
|
5
4
|
import { Model } from './index.js';
|
|
5
|
+
import { EpochLogs } from './logs.js';
|
|
6
6
|
/** TensorFlow JavaScript model with standard training */
|
|
7
7
|
export class TFJS extends Model {
|
|
8
8
|
model;
|
|
@@ -78,6 +78,7 @@ export class TFJS extends Model {
|
|
|
78
78
|
throw new Error("more than one metric value");
|
|
79
79
|
return values[0];
|
|
80
80
|
});
|
|
81
|
+
tf.dispose(evaluation);
|
|
81
82
|
const [accuracy, loss] = [
|
|
82
83
|
metricToValue.get("acc"),
|
|
83
84
|
metricToValue.get("loss"),
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
/** Dataset shapers, convenient to map with */
|
|
2
|
+
import { PreTrainedTokenizer } from "@xenova/transformers";
|
|
3
|
+
import { List } from "immutable";
|
|
4
|
+
import { Image } from "./dataset/image.js";
|
|
5
|
+
/**
|
|
6
|
+
* Convert a string to a number
|
|
7
|
+
*
|
|
8
|
+
* @throws if it isn't written as a number
|
|
9
|
+
*/
|
|
10
|
+
export declare function convertToNumber(raw: string): number;
|
|
11
|
+
/**
|
|
12
|
+
* Return the named field of an object with string values
|
|
13
|
+
*
|
|
14
|
+
* @throws if the named field isn't there
|
|
15
|
+
*/
|
|
16
|
+
export declare function extractColumn(row: Partial<Record<string, string>>, column: string): string;
|
|
17
|
+
/**
|
|
18
|
+
* Return the index of the element in the given list
|
|
19
|
+
*
|
|
20
|
+
* @throws if not found
|
|
21
|
+
*/
|
|
22
|
+
export declare function indexInList(element: string, elements: List<string>): number;
|
|
23
|
+
/**
|
|
24
|
+
* Tokenize and truncate an input string, then left-pad it
|
|
25
|
+
*
|
|
26
|
+
* @param length number of tokens
|
|
27
|
+
* @returns token ids of the encoded string, left-padded to exactly `length` tokens
|
|
28
|
+
*/
|
|
29
|
+
export declare function tokenizeAndLeftPad(line: string, tokenizer: PreTrainedTokenizer, length: number): number[];
|
|
30
|
+
/** Remove the alpha channel of an image */
|
|
31
|
+
export declare function removeAlpha<W extends number, H extends number>(image: Image<4, W, H>): Image<3, W, H>;
|
|
32
|
+
export declare function removeAlpha<D extends 1 | 3, W extends number, H extends number>(image: Image<D | 4, W, H>): Image<D, W, H>;
|
|
33
|
+
/** Convert monochrome images to multicolor */
|
|
34
|
+
export declare function expandToMulticolor<W extends number, H extends number>(image: Image<1, W, H>): Image<3, W, H>;
|
|
35
|
+
export declare function expandToMulticolor<D extends 3 | 4, W extends number, H extends number>(image: Image<1 | D, W, H>): Image<D, W, H>;
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
/** Dataset shapers, convenient to map with */
|
|
2
|
+
import { Repeat, Seq } from "immutable";
|
|
3
|
+
import { Image } from "./dataset/image.js";
|
|
4
|
+
/**
 * Parse a string into a number.
 *
 * Parsing relies on `Number.parseFloat`: leading whitespace is skipped and
 * only the longest numeric prefix is read (so "3px" yields 3).
 *
 * @param raw textual representation to parse
 * @returns the parsed number
 * @throws if no number can be read from the start of `raw`
 */
export function convertToNumber(raw) {
  const parsed = Number.parseFloat(raw);
  if (!Number.isNaN(parsed)) return parsed;
  throw new Error(`unable to parse "${raw}" as number`);
}
|
|
15
|
+
/**
 * Return the value of the named field of a row whose values are strings.
 *
 * Only the row's own properties are considered. A column name that matches
 * an inherited property (such as "toString" or "constructor") is reported as
 * missing instead of returning that inherited value.
 *
 * @param row object mapping column names to their string value
 * @param column name of the field to read
 * @returns the field's value
 * @throws if the row has no such field, or if the field is undefined
 */
export function extractColumn(row, column) {
  // call through the prototype so rows created without a prototype also work
  const raw = Object.prototype.hasOwnProperty.call(row, column)
    ? row[column]
    : undefined;
  if (raw === undefined)
    throw new Error(`${column} not found in row`);
  return raw;
}
|
|
26
|
+
/**
 * Find the position of an element within a list.
 *
 * @param element value to look for
 * @param elements list to search, compared through its `indexOf`
 * @returns the index of the first occurrence of `element`
 * @throws if `element` does not appear in `elements`
 */
export function indexInList(element, elements) {
  const position = elements.indexOf(element);
  if (position !== -1) return position;
  throw new Error(`${element} not found in list`);
}
|
|
37
|
+
/** Whether `raw` is an array containing only numbers (an empty array qualifies). */
function isArrayOfNumber(raw) {
  if (!Array.isArray(raw)) return false;
  // "no element is a non-number" — same hole-skipping semantics as `every`
  return !raw.some((element) => typeof element !== "number");
}
|
|
40
|
+
/**
 * Tokenize and truncate an input string, then left-pad it.
 *
 * @param line text to tokenize
 * @param tokenizer Transformers.js tokenizer; its `pad_token_id` fills the padding
 * @param length number of tokens to return
 * @returns token ids of `line`, truncated to at most `length` tokens and
 *   left-padded with `tokenizer.pad_token_id` up to exactly `length` tokens
 * @throws if `length` isn't an integer, if the tokenizer output doesn't hold
 *   an array of numeric `input_ids`, or if it holds more than `length` tokens
 */
export function tokenizeAndLeftPad(line, tokenizer, length) {
  if (!Number.isInteger(length))
    throw new Error("length should be an integer");
  // Transformers.js currently only supports right padding while we need left for text generation
  // Left padding should be supported in the future, once it is, we can directly pad while tokenizing
  // https://github.com/xenova/transformers.js/blob/8804c36591d11d8456788d1bb4b16489121b3be2/src/tokenizers.js#L2517
  const tokenized = tokenizer(line, {
    padding: false,
    truncation: true,
    return_tensor: false,
    max_length: length,
  });
  if (
    typeof tokenized !== "object" ||
    tokenized === null ||
    !("input_ids" in tokenized) ||
    !isArrayOfNumber(tokenized.input_ids)
  )
    throw new Error("tokenizer returns unexpected type");
  const tokens = tokenized.input_ids;
  const paddingSize = length - tokens.length;
  // truncation was requested, so this only happens with a misbehaving tokenizer
  if (paddingSize < 0)
    throw new Error("tokenizer returned more tokens than expected");
  const padding = new Array(paddingSize).fill(tokenizer.pad_token_id);
  return padding.concat(tokens);
}
|
|
72
|
+
/**
 * Remove the alpha channel of an image.
 *
 * Images of depth 1 or 3 carry no alpha channel and are rewrapped unchanged.
 * For depth 4, every fourth value is dropped.
 *
 * @param image image to convert, of depth 1, 3 or 4
 * @returns a new image without alpha channel
 * @throws if the image's depth is not 1, 3 or 4
 */
export function removeAlpha(image) {
  switch (image.depth) {
    case 1:
    case 3:
      return new Image(image.data, image.width, image.height, image.depth);
    case 4:
      // values are channel-interleaved per pixel with alpha last, so drop index 3 of each group of 4
      return new Image(image.data.filter((_, i) => i % 4 !== 3), image.width, image.height, 3);
    default:
      // previously fell through and silently returned undefined
      throw new Error(`unsupported image depth: ${image.depth}`);
  }
}
|
|
81
|
+
/**
 * Convert a monochrome image to a multicolor one.
 *
 * Each value of a depth-1 image is repeated three times, producing a depth-3
 * image backed by a `Uint8Array`. Images of depth 3 or 4 are already
 * multicolor and are rewrapped unchanged.
 *
 * @param image image to convert, of depth 1, 3 or 4
 * @returns a new image of depth 3 for monochrome input, otherwise of the input's depth
 * @throws if the image's depth is not 1, 3 or 4
 */
export function expandToMulticolor(image) {
  switch (image.depth) {
    case 1: {
      // fill a typed array directly rather than building a lazy sequence per pixel
      const source = image.data;
      const expanded = new Uint8Array(source.length * 3);
      for (let i = 0; i < source.length; i++) {
        const value = source[i];
        expanded[3 * i] = value;
        expanded[3 * i + 1] = value;
        expanded[3 * i + 2] = value;
      }
      return new Image(expanded, image.width, image.height, 3);
    }
    case 3:
    case 4:
      return new Image(image.data, image.width, image.height, image.depth);
    default:
      // previously fell through and silently returned undefined
      throw new Error(`unsupported image depth: ${image.depth}`);
  }
}
|
package/dist/training/disco.d.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import {
|
|
1
|
+
import { client as clients, BatchLogs, EpochLogs, Logger, Memory, Task, TrainingInformation } from "../index.js";
|
|
2
|
+
import type { TypedLabeledDataset } from "../index.js";
|
|
3
3
|
import type { Aggregator } from "../aggregator/index.js";
|
|
4
4
|
import { RoundLogs, Trainer } from "./trainer.js";
|
|
5
5
|
interface Config {
|
|
@@ -28,20 +28,20 @@ export declare class Disco {
|
|
|
28
28
|
url: URL;
|
|
29
29
|
}, config: Partial<Config>): Promise<Disco>;
|
|
30
30
|
/** Train on dataset, yielding logs of every round. */
|
|
31
|
-
trainByRound(
|
|
31
|
+
trainByRound(dataset: TypedLabeledDataset): AsyncGenerator<RoundLogs>;
|
|
32
32
|
/** Train on dataset, yielding logs of every epoch. */
|
|
33
|
-
trainByEpoch(
|
|
33
|
+
trainByEpoch(dataset: TypedLabeledDataset): AsyncGenerator<EpochLogs>;
|
|
34
34
|
/** Train on dataset, yielding logs of every batch. */
|
|
35
|
-
trainByBatch(dataTuple:
|
|
35
|
+
trainByBatch(dataTuple: TypedLabeledDataset): AsyncGenerator<BatchLogs>;
|
|
36
36
|
/** Run whole train on dataset. */
|
|
37
|
-
trainFully(dataTuple:
|
|
37
|
+
trainFully(dataTuple: TypedLabeledDataset): Promise<void>;
|
|
38
38
|
/**
|
|
39
39
|
* Train on dataset, yield the nested steps.
|
|
40
40
|
*
|
|
41
41
|
* Don't forget to await the yielded generator otherwise nothing will progress.
|
|
42
42
|
* If you don't care about the whole process, use one of the other train methods.
|
|
43
43
|
**/
|
|
44
|
-
train(
|
|
44
|
+
train(dataset: TypedLabeledDataset): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>>;
|
|
45
45
|
/**
|
|
46
46
|
* Stops the ongoing training instance without disconnecting the client.
|
|
47
47
|
*/
|
package/dist/training/disco.js
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
import { async_iterator, } from "../index.js";
|
|
2
|
-
import { client as clients, ConsoleLogger, EmptyMemory } from "../index.js";
|
|
1
|
+
import { async_iterator, client as clients, ConsoleLogger, EmptyMemory, } from "../index.js";
|
|
3
2
|
import { getAggregator } from "../aggregator/index.js";
|
|
4
3
|
import { enumerate, split } from "../utils/async_iterator.js";
|
|
5
4
|
import { Trainer } from "./trainer.js";
|
|
5
|
+
import { labeledDatasetToDataSplit } from "../dataset/data/helpers.js";
|
|
6
6
|
/**
|
|
7
7
|
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
8
8
|
* a convenient object providing a reduced yet complete API that wraps model training,
|
|
@@ -12,18 +12,14 @@ export class Disco {
|
|
|
12
12
|
trainer;
|
|
13
13
|
#client;
|
|
14
14
|
#logger;
|
|
15
|
-
|
|
16
|
-
#
|
|
15
|
+
#memory;
|
|
16
|
+
#task;
|
|
17
17
|
constructor(trainer, task, client, memory, logger) {
|
|
18
18
|
this.trainer = trainer;
|
|
19
19
|
this.#client = client;
|
|
20
20
|
this.#logger = logger;
|
|
21
|
-
this.#
|
|
22
|
-
|
|
23
|
-
taskID: task.id,
|
|
24
|
-
name: task.trainingInformation.modelID,
|
|
25
|
-
tensorBackend: task.trainingInformation.tensorBackend,
|
|
26
|
-
}, this.trainer.model);
|
|
21
|
+
this.#memory = memory;
|
|
22
|
+
this.#task = task;
|
|
27
23
|
}
|
|
28
24
|
/**
|
|
29
25
|
* Connect to the given task and get ready to train.
|
|
@@ -70,8 +66,8 @@ export class Disco {
|
|
|
70
66
|
return new Disco(new Trainer(task, model, client), task, client, memory, logger);
|
|
71
67
|
}
|
|
72
68
|
/** Train on dataset, yielding logs of every round. */
|
|
73
|
-
async *trainByRound(
|
|
74
|
-
for await (const round of this.train(
|
|
69
|
+
async *trainByRound(dataset) {
|
|
70
|
+
for await (const round of this.train(dataset)) {
|
|
75
71
|
const [roundGen, roundLogs] = async_iterator.split(round);
|
|
76
72
|
for await (const epoch of roundGen)
|
|
77
73
|
for await (const _ of epoch)
|
|
@@ -80,8 +76,8 @@ export class Disco {
|
|
|
80
76
|
}
|
|
81
77
|
}
|
|
82
78
|
/** Train on dataset, yielding logs of every epoch. */
|
|
83
|
-
async *trainByEpoch(
|
|
84
|
-
for await (const round of this.train(
|
|
79
|
+
async *trainByEpoch(dataset) {
|
|
80
|
+
for await (const round of this.train(dataset)) {
|
|
85
81
|
for await (const epoch of round) {
|
|
86
82
|
const [epochGen, epochLogs] = async_iterator.split(epoch);
|
|
87
83
|
for await (const _ of epochGen)
|
|
@@ -109,12 +105,13 @@ export class Disco {
|
|
|
109
105
|
* Don't forget to await the yielded generator otherwise nothing will progress.
|
|
110
106
|
* If you don't care about the whole process, use one of the other train methods.
|
|
111
107
|
**/
|
|
112
|
-
async *train(
|
|
108
|
+
async *train(dataset) {
|
|
113
109
|
this.#logger.success("Training started.");
|
|
114
|
-
const
|
|
115
|
-
const
|
|
110
|
+
const data = await labeledDatasetToDataSplit(this.#task, dataset);
|
|
111
|
+
const trainData = data.train.preprocess().batch().dataset;
|
|
112
|
+
const validationData = data.validation?.preprocess().batch().dataset ?? trainData;
|
|
116
113
|
await this.#client.connect();
|
|
117
|
-
for await (const [round, epochs] of enumerate(this.trainer.train(trainData
|
|
114
|
+
for await (const [round, epochs] of enumerate(this.trainer.train(trainData, validationData))) {
|
|
118
115
|
yield async function* () {
|
|
119
116
|
const [gen, returnedRoundLogs] = split(epochs);
|
|
120
117
|
for await (const [epoch, batches] of enumerate(gen)) {
|
|
@@ -136,7 +133,12 @@ export class Disco {
|
|
|
136
133
|
}
|
|
137
134
|
return await returnedRoundLogs;
|
|
138
135
|
}.bind(this)();
|
|
139
|
-
await this.#updateWorkingModel(
|
|
136
|
+
await this.#memory.updateWorkingModel({
|
|
137
|
+
type: "working",
|
|
138
|
+
taskID: this.#task.id,
|
|
139
|
+
name: this.#task.trainingInformation.modelID,
|
|
140
|
+
tensorBackend: this.#task.trainingInformation.tensorBackend,
|
|
141
|
+
}, this.trainer.model);
|
|
140
142
|
}
|
|
141
143
|
this.#logger.success("Training finished.");
|
|
142
144
|
}
|
package/dist/types.d.ts
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
1
|
+
import { Dataset, Image, Tabular, Text } from "./dataset/index.js";
|
|
2
|
+
export type TypedDataset = ["image", Dataset<Image>] | ["tabular", Dataset<Tabular>] | ["text", Dataset<Text>];
|
|
3
|
+
export type TypedLabeledDataset = ["image", Dataset<[Image, label: string]>] | ["tabular", Dataset<Tabular>] | ["text", Dataset<Text>];
|
package/dist/types.js
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export {};
|
|
@@ -1,26 +1,10 @@
|
|
|
1
|
-
import type {
|
|
1
|
+
import type { Model, Task, TypedDataset, TypedLabeledDataset } from "../index.js";
|
|
2
2
|
export declare class Validator {
|
|
3
|
+
#private;
|
|
3
4
|
readonly task: Task;
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
private _confusionMatrix;
|
|
10
|
-
private rollingAccuracy;
|
|
11
|
-
constructor(task: Task, logger: Logger, memory: Memory, source?: ModelSource | undefined, client?: clients.Client | undefined);
|
|
12
|
-
private getLabel;
|
|
13
|
-
test(data: data.Data): AsyncGenerator<Array<{
|
|
14
|
-
groundTruth: number;
|
|
15
|
-
pred: number;
|
|
16
|
-
features: number[];
|
|
17
|
-
}>, void>;
|
|
18
|
-
inference(data: data.Data): AsyncGenerator<Array<{
|
|
19
|
-
features: number[];
|
|
20
|
-
pred: number;
|
|
21
|
-
}>, void>;
|
|
22
|
-
getModel(): Promise<Model>;
|
|
23
|
-
get accuracy(): number;
|
|
24
|
-
get visitedSamples(): number;
|
|
25
|
-
get confusionMatrix(): number[][] | undefined;
|
|
5
|
+
constructor(task: Task, model: Model);
|
|
6
|
+
/** infer every line of the dataset and check that it is as labeled */
|
|
7
|
+
test(dataset: TypedLabeledDataset): AsyncGenerator<boolean>;
|
|
8
|
+
/** use the model to predict every line of the dataset */
|
|
9
|
+
infer(dataset: TypedDataset): AsyncGenerator<number, void>;
|
|
26
10
|
}
|
|
@@ -1,119 +1,113 @@
|
|
|
1
|
-
import
|
|
2
|
-
import
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { datasetToData, labeledDatasetToData, } from "../dataset/data/helpers.js";
|
|
3
|
+
function intoTFDataset(iter) {
|
|
4
|
+
// @ts-expect-error generator
|
|
5
|
+
return tf.data.generator(async function* () {
|
|
6
|
+
yield* iter;
|
|
7
|
+
});
|
|
8
|
+
}
|
|
3
9
|
export class Validator {
|
|
4
10
|
task;
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
source;
|
|
8
|
-
client;
|
|
9
|
-
size = 0;
|
|
10
|
-
_confusionMatrix;
|
|
11
|
-
rollingAccuracy = 0;
|
|
12
|
-
constructor(task, logger, memory, source, client) {
|
|
11
|
+
#model;
|
|
12
|
+
constructor(task, model) {
|
|
13
13
|
this.task = task;
|
|
14
|
-
this
|
|
15
|
-
this.memory = memory;
|
|
16
|
-
this.source = source;
|
|
17
|
-
this.client = client;
|
|
18
|
-
if (source === undefined && client === undefined) {
|
|
19
|
-
throw new Error('To initialize a Validator, either or both a source and client need to be specified');
|
|
20
|
-
}
|
|
14
|
+
this.#model = model;
|
|
21
15
|
}
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
16
|
+
/** infer every line of the dataset and check that it is as labeled */
|
|
17
|
+
async *test(dataset) {
|
|
18
|
+
const preprocessed = (await labeledDatasetToData(this.task, dataset)).preprocess();
|
|
19
|
+
const batched = preprocessed.batch().dataset;
|
|
20
|
+
const iterator = await tf.data
|
|
21
|
+
.zip([
|
|
22
|
+
preprocessed.dataset.map((t) => {
|
|
23
|
+
if (typeof t !== "object" ||
|
|
24
|
+
!("ys" in t) ||
|
|
25
|
+
!(t.ys instanceof tf.Tensor) ||
|
|
26
|
+
!(t.ys.rank === 1 || t.ys.rank === 2))
|
|
27
|
+
throw new Error("unexpected preprocessed dataset");
|
|
28
|
+
if ("xs" in t)
|
|
29
|
+
tf.dispose(t.xs);
|
|
30
|
+
return t.ys;
|
|
31
|
+
}),
|
|
32
|
+
intoTFDataset(this.#inferOnBatchedData(batched)),
|
|
33
|
+
])
|
|
34
|
+
.iterator();
|
|
35
|
+
for (let iter = await iterator.next(); iter.done !== true; iter = await iterator.next()) {
|
|
36
|
+
const zipped = iter.value;
|
|
37
|
+
const label = await getLabel(zipped[0]);
|
|
38
|
+
tf.dispose(zipped[0]);
|
|
39
|
+
const infered = zipped[1];
|
|
40
|
+
yield label === infered;
|
|
37
41
|
}
|
|
38
|
-
// Multi-label classification is not supported
|
|
39
42
|
}
|
|
40
|
-
|
|
41
|
-
async *
|
|
42
|
-
const
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
}
|
|
46
|
-
const model = await this.getModel();
|
|
47
|
-
let hits = 0;
|
|
48
|
-
const iterator = await data.preprocess().dataset.batch(batchSize).iterator();
|
|
49
|
-
let next = await iterator.next();
|
|
50
|
-
while (next.done !== true) {
|
|
51
|
-
const { xs, ys } = next.value;
|
|
52
|
-
const ysLabel = await this.getLabel(ys);
|
|
53
|
-
const yPredTensor = await model.predict(xs);
|
|
54
|
-
const pred = await this.getLabel(yPredTensor);
|
|
55
|
-
const currentFeatures = await xs.array();
|
|
56
|
-
this.size += ysLabel.length;
|
|
57
|
-
hits += List(pred).zip(List(ysLabel)).filter(([p, y]) => p === y).size;
|
|
58
|
-
this.rollingAccuracy = hits / this.size;
|
|
59
|
-
tf.dispose([xs, ys, yPredTensor]);
|
|
60
|
-
yield (List(ysLabel).zip(List(pred), List(currentFeatures)))
|
|
61
|
-
.map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f }))
|
|
62
|
-
.toArray();
|
|
63
|
-
next = await iterator.next();
|
|
64
|
-
}
|
|
65
|
-
this.logger.success(`Obtained validation accuracy of ${this.accuracy}`);
|
|
66
|
-
this.logger.success(`Visited ${this.visitedSamples} samples`);
|
|
43
|
+
/** use the model to predict every line of the dataset */
|
|
44
|
+
async *infer(dataset) {
|
|
45
|
+
const data = await datasetToData(this.task, dataset);
|
|
46
|
+
const batched = data.preprocess().batch().dataset;
|
|
47
|
+
yield* this.#inferOnBatchedData(batched);
|
|
67
48
|
}
|
|
68
|
-
async
|
|
69
|
-
const
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
49
|
+
async *#inferOnBatchedData(batched) {
|
|
50
|
+
const iterator = await batched.iterator();
|
|
51
|
+
for (let iter = await iterator.next(); iter.done !== true; iter = await iterator.next()) {
|
|
52
|
+
const row = iter.value;
|
|
53
|
+
if (typeof row !== "object" ||
|
|
54
|
+
!("xs" in row) ||
|
|
55
|
+
!(row.xs instanceof tf.Tensor))
|
|
56
|
+
throw new Error("unexpected shape of dataset");
|
|
57
|
+
const prediction = await this.#model.predict(row.xs);
|
|
58
|
+
tf.dispose(row);
|
|
59
|
+
let predictions;
|
|
60
|
+
switch (prediction.rank) {
|
|
61
|
+
case 2:
|
|
62
|
+
case 3:
|
|
63
|
+
predictions = await getLabels(
|
|
64
|
+
// cast as rank was just checked
|
|
65
|
+
prediction);
|
|
66
|
+
prediction.dispose();
|
|
67
|
+
break;
|
|
68
|
+
default:
|
|
69
|
+
throw new Error("unexpected batched prediction shape");
|
|
85
70
|
}
|
|
86
|
-
|
|
87
|
-
const
|
|
88
|
-
|
|
89
|
-
this.size += pred.length;
|
|
90
|
-
if (!Array.isArray(currentFeatures)) {
|
|
91
|
-
throw new TypeError('Data format is incorrect');
|
|
92
|
-
}
|
|
93
|
-
tf.dispose([xs, yPredTensor]);
|
|
94
|
-
yield List(currentFeatures).zip(List(pred))
|
|
95
|
-
.map(([f, p]) => ({ features: f, pred: p }))
|
|
96
|
-
.toArray();
|
|
97
|
-
next = await iterator.next();
|
|
71
|
+
prediction.dispose();
|
|
72
|
+
for (const prediction of predictions)
|
|
73
|
+
yield prediction;
|
|
98
74
|
}
|
|
99
|
-
this.logger.success(`Visited ${this.visitedSamples} samples`);
|
|
100
75
|
}
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
76
|
+
}
|
|
77
|
+
async function getLabels(ys) {
|
|
78
|
+
// cast as unstack drop a dimension and tfjs doesn't type correctly
|
|
79
|
+
return Promise.all(tf.unstack(ys).map((y) => {
|
|
80
|
+
const ret = getLabel(y);
|
|
81
|
+
y.dispose();
|
|
82
|
+
return ret;
|
|
83
|
+
}));
|
|
84
|
+
}
|
|
85
|
+
async function getLabel(ys) {
|
|
86
|
+
switch (ys.rank) {
|
|
87
|
+
case 1: {
|
|
88
|
+
if (ys.shape[0] == 1) {
|
|
89
|
+
// Binary classification
|
|
90
|
+
const threshold = tf.scalar(0.5);
|
|
91
|
+
const binaryTensor = ys.greaterEqual(threshold);
|
|
92
|
+
const binaryArray = await binaryTensor.data();
|
|
93
|
+
tf.dispose([binaryTensor, threshold]);
|
|
94
|
+
return binaryArray[0];
|
|
95
|
+
}
|
|
96
|
+
// Multi-class classification
|
|
97
|
+
const indexTensor = ys.argMax();
|
|
98
|
+
const indexArray = await indexTensor.data();
|
|
99
|
+
tf.dispose([indexTensor]);
|
|
100
|
+
return indexArray[0];
|
|
101
|
+
// Multi-label classification is not supported
|
|
104
102
|
}
|
|
105
|
-
|
|
106
|
-
|
|
103
|
+
case 2: {
|
|
104
|
+
// it's LLM, we only extract the next token
|
|
105
|
+
const firstToken = tf.tidy(() => ys.gather([0]).squeeze().argMax());
|
|
106
|
+
const raw = await firstToken.data();
|
|
107
|
+
firstToken.dispose();
|
|
108
|
+
return raw[0];
|
|
107
109
|
}
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
get accuracy() {
|
|
111
|
-
return this.rollingAccuracy;
|
|
112
|
-
}
|
|
113
|
-
get visitedSamples() {
|
|
114
|
-
return this.size;
|
|
115
|
-
}
|
|
116
|
-
get confusionMatrix() {
|
|
117
|
-
return this._confusionMatrix;
|
|
110
|
+
default:
|
|
111
|
+
throw new Error("unexpected tensor rank");
|
|
118
112
|
}
|
|
119
113
|
}
|
package/package.json
CHANGED
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
import type { DataSplit, Dataset } from '../index.js';
|
|
2
|
-
export interface DataConfig {
|
|
3
|
-
features?: string[];
|
|
4
|
-
labels?: string[];
|
|
5
|
-
shuffle?: boolean;
|
|
6
|
-
validationSplit?: number;
|
|
7
|
-
inference?: boolean;
|
|
8
|
-
channels?: number;
|
|
9
|
-
}
|
|
10
|
-
export declare abstract class DataLoader<Source> {
|
|
11
|
-
abstract load(source: Source, config: DataConfig): Promise<Dataset>;
|
|
12
|
-
abstract loadAll(sources: Source[], config: DataConfig): Promise<DataSplit>;
|
|
13
|
-
}
|
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import type { Task } from '../../index.js';
|
|
3
|
-
import type { Dataset, DataSplit } from '../index.js';
|
|
4
|
-
import type { DataConfig } from '../data_loader/index.js';
|
|
5
|
-
import { DataLoader } from '../data_loader/index.js';
|
|
6
|
-
/**
|
|
7
|
-
* Image data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
|
|
8
|
-
* @epfml/discojs-web and @epfml/discojs-node.
|
|
9
|
-
* Load labels and correctly match them with their respective images, with the following constraints:
|
|
10
|
-
* 1. Images are given as 1 image/1 file;
|
|
11
|
-
* 2. Labels are given as multiple labels/1 file, each label file can contain a different amount of labels.
|
|
12
|
-
*/
|
|
13
|
-
export declare abstract class ImageLoader<Source> extends DataLoader<Source> {
|
|
14
|
-
private readonly task;
|
|
15
|
-
abstract readImageFrom(source: Source, channels?: number): Promise<tf.Tensor3D>;
|
|
16
|
-
constructor(task: Task);
|
|
17
|
-
load(image: Source, config?: DataConfig): Promise<Dataset>;
|
|
18
|
-
private buildDataset;
|
|
19
|
-
loadAll(images: Source[], config?: DataConfig): Promise<DataSplit>;
|
|
20
|
-
shuffle(array: number[]): void;
|
|
21
|
-
}
|