@epfml/discojs 3.0.1-p20241206154707.0 → 3.0.1-p20250204101254.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -5,4 +5,4 @@ export * from "./tabular.js";
|
|
|
5
5
|
export * from "./text.js";
|
|
6
6
|
export declare function preprocess<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.Raw[D]>): Promise<Dataset<DataFormat.ModelEncoded[D]>>;
|
|
7
7
|
export declare function preprocessWithoutLabel<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.RawWithoutLabel[D]>): Promise<Dataset<DataFormat.ModelEncoded[D][0]>>;
|
|
8
|
-
export declare function postprocess<D extends DataType>(task: Task<D>,
|
|
8
|
+
export declare function postprocess<D extends DataType>(task: Task<D>, encoded: DataFormat.ModelEncoded[D][1]): Promise<DataFormat.Inferred[D]>;
|
package/dist/processing/index.js
CHANGED
|
@@ -68,31 +68,29 @@ export async function preprocessWithoutLabel(task, dataset) {
|
|
|
68
68
|
}
|
|
69
69
|
}
|
|
70
70
|
}
|
|
71
|
-
export async function postprocess(task,
|
|
71
|
+
export async function postprocess(task, encoded) {
|
|
72
72
|
switch (task.trainingInformation.dataType) {
|
|
73
73
|
case "image": {
|
|
74
74
|
// cast as typescript doesn't reduce generic type
|
|
75
|
-
const
|
|
75
|
+
const index = encoded;
|
|
76
76
|
const { LABEL_LIST } = task.trainingInformation;
|
|
77
77
|
const labels = List(LABEL_LIST);
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
return v;
|
|
83
|
-
});
|
|
78
|
+
const v = labels.get(index);
|
|
79
|
+
if (v === undefined)
|
|
80
|
+
throw new Error("index not found in labels");
|
|
81
|
+
return v;
|
|
84
82
|
}
|
|
85
83
|
case "tabular": {
|
|
86
84
|
// cast as typescript doesn't reduce generic type
|
|
87
|
-
const
|
|
88
|
-
return
|
|
85
|
+
const v = encoded;
|
|
86
|
+
return v;
|
|
89
87
|
}
|
|
90
88
|
case "text": {
|
|
91
89
|
// cast as typescript doesn't reduce generic type
|
|
92
|
-
const
|
|
90
|
+
const token = encoded;
|
|
93
91
|
const t = task;
|
|
94
92
|
const tokenizer = await models.getTaskTokenizer(t);
|
|
95
|
-
return
|
|
93
|
+
return tokenizer.decode([token]);
|
|
96
94
|
}
|
|
97
95
|
}
|
|
98
96
|
}
|
package/dist/validator.d.ts
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
import
|
|
1
|
+
import { Dataset, DataFormat, DataType, Model, Task } from "./index.js";
|
|
2
2
|
export declare class Validator<D extends DataType> {
|
|
3
3
|
#private;
|
|
4
4
|
readonly task: Task<D>;
|
|
5
5
|
constructor(task: Task<D>, model: Model<D>);
|
|
6
6
|
/** infer every line of the dataset and check that it is as labelled */
|
|
7
|
-
test(dataset: Dataset<DataFormat.Raw[D]>):
|
|
7
|
+
test(dataset: Dataset<DataFormat.Raw[D]>): Promise<Dataset<Record<"predicted" | "truth", DataFormat.Inferred[D]>>>;
|
|
8
8
|
/** use the model to predict every line of the dataset */
|
|
9
9
|
infer(dataset: Dataset<DataFormat.RawWithoutLabel[D]>): AsyncGenerator<DataFormat.Inferred[D], void>;
|
|
10
10
|
}
|
package/dist/validator.js
CHANGED
|
@@ -7,15 +7,16 @@ export class Validator {
|
|
|
7
7
|
this.#model = model;
|
|
8
8
|
}
|
|
9
9
|
/** infer every line of the dataset and check that it is as labelled */
|
|
10
|
-
async
|
|
11
|
-
const
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
.zip(batch.map(([_, outputs]) => outputs))
|
|
15
|
-
.map(([inferred, truth]) => inferred === truth))
|
|
10
|
+
async test(dataset) {
|
|
11
|
+
const preprocessed = await processing.preprocess(this.task, dataset);
|
|
12
|
+
const batched = preprocessed.batch(this.task.trainingInformation.batchSize);
|
|
13
|
+
const predictionWithTruth = batched
|
|
14
|
+
.map(async (batch) => (await this.#model.predict(batch.map(([inputs, _]) => inputs))).zip(batch.map(([_, outputs]) => outputs)))
|
|
16
15
|
.flatten();
|
|
17
|
-
|
|
18
|
-
|
|
16
|
+
return predictionWithTruth.map(async ([predicted, truth]) => ({
|
|
17
|
+
predicted: await processing.postprocess(this.task, predicted),
|
|
18
|
+
truth: await processing.postprocess(this.task, truth),
|
|
19
|
+
}));
|
|
19
20
|
}
|
|
20
21
|
/** use the model to predict every line of the dataset */
|
|
21
22
|
async *infer(dataset) {
|
|
@@ -23,7 +24,7 @@ export class Validator {
|
|
|
23
24
|
.batch(this.task.trainingInformation.batchSize)
|
|
24
25
|
.map((batch) => this.#model.predict(batch))
|
|
25
26
|
.flatten();
|
|
26
|
-
const predictions =
|
|
27
|
+
const predictions = modelPredictions.map((prediction) => processing.postprocess(this.task, prediction));
|
|
27
28
|
for await (const e of predictions)
|
|
28
29
|
yield e;
|
|
29
30
|
}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@epfml/discojs",
|
|
3
|
-
"version": "3.0.1-
|
|
3
|
+
"version": "3.0.1-p20250204101254.0",
|
|
4
4
|
"type": "module",
|
|
5
5
|
"main": "dist/index.js",
|
|
6
6
|
"types": "dist/index.d.ts",
|
|
@@ -19,13 +19,13 @@
|
|
|
19
19
|
},
|
|
20
20
|
"homepage": "https://github.com/epfml/disco#readme",
|
|
21
21
|
"dependencies": {
|
|
22
|
+
"@epfml/isomorphic-wrtc": "1",
|
|
22
23
|
"@jimp/core": "1",
|
|
23
24
|
"@jimp/plugin-resize": "1",
|
|
24
25
|
"@msgpack/msgpack": "^3.0.0-beta2",
|
|
25
26
|
"@tensorflow/tfjs": "4",
|
|
26
27
|
"@xenova/transformers": "2",
|
|
27
28
|
"immutable": "4",
|
|
28
|
-
"isomorphic-wrtc": "1",
|
|
29
29
|
"isomorphic-ws": "5",
|
|
30
30
|
"simple-peer": "9",
|
|
31
31
|
"tslib": "2",
|