@epfml/discojs 3.0.1-p20241206154707.0 → 3.0.1-p20250204101254.0

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,5 @@
1
1
  import { List, Map, Range, Seq } from 'immutable';
2
- import wrtc from 'isomorphic-wrtc';
2
+ import wrtc from "@epfml/isomorphic-wrtc";
3
3
  import SimplePeer from 'simple-peer';
4
4
  // message id + (chunk counter == 0) + chunk count
5
5
  const FIRST_HEADER_SIZE = 2 + 1 + 1;
@@ -5,4 +5,4 @@ export * from "./tabular.js";
5
5
  export * from "./text.js";
6
6
  export declare function preprocess<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.Raw[D]>): Promise<Dataset<DataFormat.ModelEncoded[D]>>;
7
7
  export declare function preprocessWithoutLabel<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.RawWithoutLabel[D]>): Promise<Dataset<DataFormat.ModelEncoded[D][0]>>;
8
- export declare function postprocess<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.ModelEncoded[D][1]>): Promise<Dataset<DataFormat.Inferred[D]>>;
8
+ export declare function postprocess<D extends DataType>(task: Task<D>, encoded: DataFormat.ModelEncoded[D][1]): Promise<DataFormat.Inferred[D]>;
@@ -68,31 +68,29 @@ export async function preprocessWithoutLabel(task, dataset) {
68
68
  }
69
69
  }
70
70
  }
71
- export async function postprocess(task, dataset) {
71
+ export async function postprocess(task, encoded) {
72
72
  switch (task.trainingInformation.dataType) {
73
73
  case "image": {
74
74
  // cast as typescript doesn't reduce generic type
75
- const d = dataset;
75
+ const index = encoded;
76
76
  const { LABEL_LIST } = task.trainingInformation;
77
77
  const labels = List(LABEL_LIST);
78
- return d.map((index) => {
79
- const v = labels.get(index);
80
- if (v === undefined)
81
- throw new Error("index not found in labels");
82
- return v;
83
- });
78
+ const v = labels.get(index);
79
+ if (v === undefined)
80
+ throw new Error("index not found in labels");
81
+ return v;
84
82
  }
85
83
  case "tabular": {
86
84
  // cast as typescript doesn't reduce generic type
87
- const d = dataset;
88
- return d;
85
+ const v = encoded;
86
+ return v;
89
87
  }
90
88
  case "text": {
91
89
  // cast as typescript doesn't reduce generic type
92
- const d = dataset;
90
+ const token = encoded;
93
91
  const t = task;
94
92
  const tokenizer = await models.getTaskTokenizer(t);
95
- return d.map((token) => tokenizer.decode([token]));
93
+ return tokenizer.decode([token]);
96
94
  }
97
95
  }
98
96
  }
@@ -1,10 +1,10 @@
1
- import type { Dataset, DataFormat, DataType, Model, Task } from "./index.js";
1
+ import { Dataset, DataFormat, DataType, Model, Task } from "./index.js";
2
2
  export declare class Validator<D extends DataType> {
3
3
  #private;
4
4
  readonly task: Task<D>;
5
5
  constructor(task: Task<D>, model: Model<D>);
6
6
  /** infer every line of the dataset and check that it is as labelled */
7
- test(dataset: Dataset<DataFormat.Raw[D]>): AsyncGenerator<boolean, void>;
7
+ test(dataset: Dataset<DataFormat.Raw[D]>): Promise<Dataset<Record<"predicted" | "truth", DataFormat.Inferred[D]>>>;
8
8
  /** use the model to predict every line of the dataset */
9
9
  infer(dataset: Dataset<DataFormat.RawWithoutLabel[D]>): AsyncGenerator<DataFormat.Inferred[D], void>;
10
10
  }
package/dist/validator.js CHANGED
@@ -7,15 +7,16 @@ export class Validator {
7
7
  this.#model = model;
8
8
  }
9
9
  /** infer every line of the dataset and check that it is as labelled */
10
- async *test(dataset) {
11
- const results = (await processing.preprocess(this.task, dataset))
12
- .batch(this.task.trainingInformation.batchSize)
13
- .map(async (batch) => (await this.#model.predict(batch.map(([inputs, _]) => inputs)))
14
- .zip(batch.map(([_, outputs]) => outputs))
15
- .map(([inferred, truth]) => inferred === truth))
10
+ async test(dataset) {
11
+ const preprocessed = await processing.preprocess(this.task, dataset);
12
+ const batched = preprocessed.batch(this.task.trainingInformation.batchSize);
13
+ const predictionWithTruth = batched
14
+ .map(async (batch) => (await this.#model.predict(batch.map(([inputs, _]) => inputs))).zip(batch.map(([_, outputs]) => outputs)))
16
15
  .flatten();
17
- for await (const e of results)
18
- yield e;
16
+ return predictionWithTruth.map(async ([predicted, truth]) => ({
17
+ predicted: await processing.postprocess(this.task, predicted),
18
+ truth: await processing.postprocess(this.task, truth),
19
+ }));
19
20
  }
20
21
  /** use the model to predict every line of the dataset */
21
22
  async *infer(dataset) {
@@ -23,7 +24,7 @@ export class Validator {
23
24
  .batch(this.task.trainingInformation.batchSize)
24
25
  .map((batch) => this.#model.predict(batch))
25
26
  .flatten();
26
- const predictions = await processing.postprocess(this.task, modelPredictions);
27
+ const predictions = modelPredictions.map((prediction) => processing.postprocess(this.task, prediction));
27
28
  for await (const e of predictions)
28
29
  yield e;
29
30
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "3.0.1-p20241206154707.0",
3
+ "version": "3.0.1-p20250204101254.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",
@@ -19,13 +19,13 @@
19
19
  },
20
20
  "homepage": "https://github.com/epfml/disco#readme",
21
21
  "dependencies": {
22
+ "@epfml/isomorphic-wrtc": "1",
22
23
  "@jimp/core": "1",
23
24
  "@jimp/plugin-resize": "1",
24
25
  "@msgpack/msgpack": "^3.0.0-beta2",
25
26
  "@tensorflow/tfjs": "4",
26
27
  "@xenova/transformers": "2",
27
28
  "immutable": "4",
28
- "isomorphic-wrtc": "1",
29
29
  "isomorphic-ws": "5",
30
30
  "simple-peer": "9",
31
31
  "tslib": "2",