@epfml/discojs 3.0.1-p20240821133014.0 → 3.0.1-p20240826092658.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. package/dist/dataset/data/data.d.ts +6 -7
  2. package/dist/dataset/data/data.js +12 -7
  3. package/dist/dataset/data/helpers.d.ts +10 -0
  4. package/dist/dataset/data/helpers.js +97 -0
  5. package/dist/dataset/data/image_data.d.ts +3 -3
  6. package/dist/dataset/data/image_data.js +7 -2
  7. package/dist/dataset/data/index.d.ts +0 -1
  8. package/dist/dataset/data/preprocessing/text_preprocessing.js +23 -9
  9. package/dist/dataset/data/tabular_data.d.ts +3 -3
  10. package/dist/dataset/data/text_data.d.ts +3 -3
  11. package/dist/dataset/dataset.d.ts +48 -5
  12. package/dist/dataset/dataset.js +155 -1
  13. package/dist/dataset/image.d.ts +14 -0
  14. package/dist/dataset/image.js +21 -0
  15. package/dist/dataset/index.d.ts +3 -5
  16. package/dist/dataset/index.js +3 -3
  17. package/dist/dataset/types.d.ts +4 -0
  18. package/dist/dataset/types.js +2 -0
  19. package/dist/index.d.ts +4 -0
  20. package/dist/index.js +4 -0
  21. package/dist/models/gpt/model.js +2 -0
  22. package/dist/models/model.d.ts +1 -2
  23. package/dist/models/tfjs.d.ts +4 -4
  24. package/dist/models/tfjs.js +2 -1
  25. package/dist/processing.d.ts +35 -0
  26. package/dist/processing.js +89 -0
  27. package/dist/training/disco.d.ts +7 -7
  28. package/dist/training/disco.js +21 -19
  29. package/dist/types.d.ts +3 -0
  30. package/dist/types.js +1 -0
  31. package/dist/validation/validator.d.ts +7 -23
  32. package/dist/validation/validator.js +99 -105
  33. package/package.json +1 -1
  34. package/dist/dataset/data_loader/data_loader.d.ts +0 -13
  35. package/dist/dataset/data_loader/data_loader.js +0 -2
  36. package/dist/dataset/data_loader/image_loader.d.ts +0 -21
  37. package/dist/dataset/data_loader/image_loader.js +0 -101
  38. package/dist/dataset/data_loader/index.d.ts +0 -5
  39. package/dist/dataset/data_loader/index.js +0 -4
  40. package/dist/dataset/data_loader/tabular_loader.d.ts +0 -35
  41. package/dist/dataset/data_loader/tabular_loader.js +0 -76
  42. package/dist/dataset/data_loader/text_loader.d.ts +0 -14
  43. package/dist/dataset/data_loader/text_loader.js +0 -25
  44. package/dist/dataset/dataset_builder.d.ts +0 -51
  45. package/dist/dataset/dataset_builder.js +0 -118
package/dist/dataset/data/data.d.ts CHANGED
@@ -1,23 +1,22 @@
 import * as tf from '@tensorflow/tfjs';
 import type { List } from 'immutable';
 import type { Task } from '../../index.js';
-import type { Dataset } from '../index.js';
 import type { PreprocessingFunction } from './preprocessing/base.js';
 /**
  * Abstract class representing an immutable Disco dataset, including a TF.js dataset,
  * Disco task and set of preprocessing functions.
  */
 export declare abstract class Data {
-    readonly dataset: Dataset;
+    readonly dataset: tf.data.Dataset<tf.TensorContainer>;
     readonly task: Task;
     readonly size?: number | undefined;
     abstract readonly availablePreprocessing: List<PreprocessingFunction>;
-    protected constructor(dataset: Dataset, task: Task, size?: number | undefined);
-    static init(_dataset: Dataset, _task: Task, _size?: number): Promise<Data>;
+    protected constructor(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number | undefined);
+    static init(_dataset: tf.data.Dataset<tf.TensorContainer>, _task: Task, _size?: number): Promise<Data>;
    /**
     * Callable abstract method instead of constructor.
     */
-    protected abstract create(dataset: Dataset, task: Task, size?: number): Data;
+    protected abstract create(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number): Data;
    /**
     * Creates a new Disco data object containing the batched TF.js dataset, according to the
     * task's parameters.
@@ -27,7 +26,7 @@ export declare abstract class Data {
    /**
     * The TF.js dataset batched according to the task's parameters.
     */
-    get batchedDataset(): Dataset;
+    get batchedDataset(): tf.data.Dataset<tf.TensorContainer>;
    /**
     * Creates a new Disco data object containing the preprocessed TF.js dataset,
     * according to the defined set of preprocessing functions and the task's parameters.
@@ -44,5 +43,5 @@ export declare abstract class Data {
     * The TF.js dataset preprocessing according to the set of preprocessing functions and the task's
     * parameters.
     */
-    get preprocessedDataset(): Dataset;
+    get preprocessedDataset(): tf.data.Dataset<tf.TensorContainer>;
 }
package/dist/dataset/data/data.js CHANGED
@@ -57,17 +57,22 @@ export class Data {
        const applyPreprocessing = this.availablePreprocessing
            .filter((e) => e.type in taskPreprocessing)
            .map((e) => e.apply);
-        if (applyPreprocessing.size === 0) {
-            return x => Promise.resolve(x);
-        }
        const preprocessingChain = async (input) => {
            let currentContainer = await input; // Start with the initial tensor container
            for (const fn of applyPreprocessing) {
-                const newContainer = await fn(Promise.resolve(currentContainer), this.task);
-                if (currentContainer !== newContainer) {
-                    tf.dispose(currentContainer); // Dispose of the old container
+                const next = await fn(Promise.resolve(currentContainer), this.task);
+                // dirty but kinda working way to dispose of converted tensors
+                if (typeof currentContainer === "object" && typeof next === "object") {
+                    if ("xs" in currentContainer &&
+                        "xs" in next &&
+                        currentContainer.xs !== next.xs)
+                        tf.dispose(currentContainer.xs);
+                    if ("ys" in currentContainer &&
+                        "ys" in next &&
+                        currentContainer.ys !== next.ys)
+                        tf.dispose(currentContainer.ys);
                }
-                currentContainer = newContainer;
+                currentContainer = next;
            }
            return currentContainer; // Return the final tensor container
        };
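Note: the per-key disposal above replaces a whole-container tf.dispose. A minimal TypeScript sketch of the hazard it avoids, assuming a preprocessing step that forwards ys by reference (names here are illustrative, not from the package):

    import * as tf from "@tensorflow/tfjs";

    const before = { xs: tf.zeros([2, 2]), ys: tf.zeros([2]) };
    // a step that transforms xs but passes ys through unchanged:
    const after = { xs: before.xs.add(1), ys: before.ys };
    // tf.dispose(before) would also free before.ys === after.ys, leaving
    // `after` holding a disposed tensor; comparing key by key is safe:
    if (before.xs !== after.xs) tf.dispose(before.xs);
    if (before.ys !== after.ys) tf.dispose(before.ys);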
@@ -0,0 +1,10 @@
1
+ /** Internal functions to help with Dataset to Data/DataSplit conversion
2
+ *
3
+ * @todo rm when fully using Dataset
4
+ */
5
+ import type { Task, TypedDataset, TypedLabeledDataset } from "../../index.js";
6
+ import { Data } from "./index.js";
7
+ import { DataSplit } from "./data_split.js";
8
+ export declare function datasetToData(task: Task, [t, dataset]: TypedDataset): Promise<Data>;
9
+ export declare function labeledDatasetToData(task: Task, [t, dataset]: TypedLabeledDataset): Promise<Data>;
10
+ export declare function labeledDatasetToDataSplit(task: Task, [t, dataset]: TypedLabeledDataset): Promise<DataSplit>;
@@ -0,0 +1,97 @@
1
+ /** Internal functions to help with Dataset to Data/DataSplit conversion
2
+ *
3
+ * @todo rm when fully using Dataset
4
+ */
5
+ import { List } from "immutable";
6
+ import * as tf from "@tensorflow/tfjs";
7
+ import { processing } from "../../index.js";
8
+ import { ImageData, TabularData, TextData } from "./index.js";
9
+ function intoTFDataset(iter) {
10
+ // @ts-expect-error generator
11
+ return tf.data.generator(async function* () {
12
+ yield* iter;
13
+ });
14
+ }
15
+ function imageToTensor(image) {
16
+ return tf.tensor3d(image.data, [image.width, image.height, 3], "int32");
17
+ }
18
+ function tabularToNumbers(columns, row) {
19
+ return List(columns)
20
+ .map((column) => processing.extractColumn(row, column))
21
+ .map((v) => (v !== "" ? v : "0")) // TODO how to specify defaults?
22
+ .map(processing.convertToNumber)
23
+ .toArray();
24
+ }
25
+ export async function datasetToData(task, [t, dataset]) {
26
+ switch (t) {
27
+ case "image": {
28
+ const converted = dataset
29
+ .map(processing.removeAlpha)
30
+ .map((image) => processing.expandToMulticolor(image))
31
+ .map((image) => ({
32
+ xs: imageToTensor(image),
33
+ }));
34
+ return await ImageData.init(intoTFDataset(converted), task);
35
+ }
36
+ case "tabular": {
37
+ const inputColumns = task.trainingInformation.inputColumns;
38
+ if (inputColumns === undefined)
39
+ throw new Error("tabular task without input columns");
40
+ const converted = dataset.map((row) => ({
41
+ xs: tabularToNumbers(inputColumns, row),
42
+ }));
43
+ return await TabularData.init(intoTFDataset(converted), task);
44
+ }
45
+ case "text":
46
+ return await TextData.init(intoTFDataset(dataset), task);
47
+ }
48
+ }
49
+ export async function labeledDatasetToData(task, [t, dataset]) {
50
+ switch (t) {
51
+ case "image": {
52
+ const labels = List(task.trainingInformation.LABEL_LIST);
53
+ const converted = dataset
54
+ .map(([image, label]) => [
55
+ processing.expandToMulticolor(processing.removeAlpha(image)),
56
+ processing.indexInList(label, labels),
57
+ ])
58
+ .map(([image, label]) => ({
59
+ xs: imageToTensor(image),
60
+ ys: tf.oneHot(label, labels.size, 1, 0, "int32"),
61
+ }));
62
+ return await ImageData.init(intoTFDataset(converted), task);
63
+ }
64
+ case "tabular": {
65
+ const { inputColumns, outputColumns } = task.trainingInformation;
66
+ if (inputColumns === undefined || outputColumns === undefined)
67
+ throw new Error("tabular task without input and output columns");
68
+ const converted = dataset.map((row) => ({
69
+ xs: tabularToNumbers(inputColumns, row),
70
+ ys: tf.tensor1d(tabularToNumbers(outputColumns, row)),
71
+ }));
72
+ return await TabularData.init(intoTFDataset(converted), task);
73
+ }
74
+ case "text":
75
+ return await TextData.init(intoTFDataset(dataset), task);
76
+ }
77
+ }
78
+ export async function labeledDatasetToDataSplit(task, [t, dataset]) {
79
+ const split = task.trainingInformation.validationSplit;
80
+ let train;
81
+ let validation;
82
+ switch (t) {
83
+ case "image": {
84
+ [train, validation] = await Promise.all(dataset.split(split).map((d) => labeledDatasetToData(task, [t, d])));
85
+ break;
86
+ }
87
+ case "tabular": {
88
+ [train, validation] = await Promise.all(dataset.split(split).map((d) => labeledDatasetToData(task, [t, d])));
89
+ break;
90
+ }
91
+ case "text": {
92
+ [train, validation] = await Promise.all(dataset.split(split).map((d) => labeledDatasetToData(task, [t, d])));
93
+ break;
94
+ }
95
+ }
96
+ return { train, validation };
97
+ }
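For reference, tabularToNumbers above maps the selected columns of a row to numbers, defaulting empty cells to "0" before conversion. A self-contained sketch of the same logic (plain TypeScript, without the package's processing helpers):

    const row: Partial<Record<string, string>> = { age: "42", income: "" };
    const numbers = ["age", "income"].map((column) => {
        const v = row[column] ?? "";
        return Number(v !== "" ? v : "0"); // empty values default to 0
    });
    // numbers is [42, 0]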
package/dist/dataset/data/image_data.d.ts CHANGED
@@ -1,11 +1,11 @@
+import * as tf from '@tensorflow/tfjs';
 import type { Task } from '../../index.js';
-import type { Dataset } from '../dataset.js';
 import { Data } from './data.js';
 /**
  * Disco data made of image samples (.jpg, .png, etc.).
  */
 export declare class ImageData extends Data {
     readonly availablePreprocessing: import("immutable").List<import("./preprocessing/base.js").PreprocessingFunction>;
-    static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
-    protected create(dataset: Dataset, task: Task, size: number): ImageData;
+    static init(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number): Promise<Data>;
+    protected create(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size: number): ImageData;
 }
@@ -1,3 +1,4 @@
1
+ import * as tf from '@tensorflow/tfjs';
1
2
  import { Data } from './data.js';
2
3
  import { ImagePreprocessing, IMAGE_PREPROCESSING } from './preprocessing/index.js';
3
4
  /**
@@ -11,14 +12,17 @@ export class ImageData extends Data {
        // cause an error during training, because of the lazy aspect of the dataset; we only
        // verify the first sample.
        if (task.trainingInformation.preprocessingFunctions?.includes(ImagePreprocessing.Resize) !== true) {
-            const sample = (await dataset.take(1).toArray())[0];
+            const iteration = await dataset.iterator().then((iter) => iter.next());
+            if (iteration.done === true)
+                throw new Error("empty dataset");
+            const sample = iteration.value;
            // TODO: We suppose the presence of labels
            // TODO: Typing (discojs-node/src/dataset/data_loader/image_loader.spec.ts)
            if (typeof sample !== 'object' || sample === null || sample === undefined) {
                throw new Error("Image is undefined or is not an object");
            }
            let shape;
-            if ('xs' in sample && 'ys' in sample) {
+            if ('xs' in sample) {
                shape = sample.xs.shape;
            }
            else {
@@ -29,6 +33,7 @@ export class ImageData extends Data {
                (shape[0] !== IMAGE_W || shape[1] !== IMAGE_H)) {
                throw new Error(`Image doesn't have the dimensions specified in the task's training information. Expected ${IMAGE_H}x${IMAGE_W} but got ${shape[0]}x${shape[1]}.`);
            }
+            tf.dispose(sample);
        }
        return new ImageData(dataset, task, size);
    }
@@ -1,4 +1,3 @@
1
- export type { DataSplit } from './data_split.js';
2
1
  export { Data } from './data.js';
3
2
  export { ImageData } from './image_data.js';
4
3
  export { TabularData } from './tabular_data.js';
package/dist/dataset/data/preprocessing/text_preprocessing.js CHANGED
@@ -9,6 +9,21 @@ export var TextPreprocessing;
     TextPreprocessing[TextPreprocessing["Tokenize"] = 0] = "Tokenize";
     TextPreprocessing[TextPreprocessing["LeftPadding"] = 1] = "LeftPadding";
 })(TextPreprocessing || (TextPreprocessing = {}));
+function isNumberArray(raw) {
+    if (!Array.isArray(raw))
+        return false;
+    const arr = raw; // isArray is unsafely guarding with any[]
+    return arr.every((e) => typeof e === "number");
+}
+function isTokenizedEntry(raw) {
+    if (typeof raw !== "object" || raw === null)
+        return false;
+    const { tokens } = raw;
+    if (!isNumberArray(tokens))
+        return false;
+    const _ = { tokens };
+    return true;
+}
 /**
  * LeftPadding pads all incoming inputs to be a fixed length, which should be specified
  * in `task.trainingInformation.maxSequenceLength`.
@@ -23,11 +38,11 @@ export var TextPreprocessing;
  */
 const leftPadding = {
     type: TextPreprocessing.LeftPadding,
-    apply: async (x, task) => {
-        if (x === undefined || !Array.isArray(x) || x.length == 0 || typeof (x[0] !== 'number')) {
-            new Error("The leftPadding preprocessing expects a non empty 1D array of number");
-        }
-        const { tokens } = await x;
+    apply: async (input, task) => {
+        const x = await input;
+        if (!isTokenizedEntry(x))
+            throw new Error("The leftPadding preprocessing expects a non empty 1D array of number");
+        const { tokens } = x;
         const tokenizer = await models.getTaskTokenizer(task);
         return tf.tidy(() => {
             // maxLength is the final length of xs
@@ -59,10 +74,9 @@ const leftPadding = {
 const tokenize = {
     type: TextPreprocessing.Tokenize,
     apply: async (x, task) => {
-        if (typeof x !== 'string') {
-            new Error("The tokenize preprocessing expects a string as input");
-        }
-        const xs = await x; // tf.TextLineDataset yields strings
+        const xs = await x;
+        if (typeof xs !== 'string')
+            throw new Error("The tokenize preprocessing expects a string as input");
         const tokenizer = await models.getTaskTokenizer(task);
         // Add plus one to include the next token label of the last token in the input sequence
         // The inputs are truncated down to exactly maxSequenceLength in leftPadding
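The LeftPadding body itself is mostly outside this hunk; for context, left-padding token ids in TF.js can be done along these lines (a sketch under assumed maxLength and pad-token parameters, not the package's exact code):

    import * as tf from "@tensorflow/tfjs";

    function leftPad(tokens: number[], maxLength: number, padTokenId: number): tf.Tensor1D {
        return tf.tidy(() => {
            const trimmed = tokens.slice(-maxLength); // truncate down to maxLength
            const t = tf.tensor1d(trimmed, "int32");
            // pad on the left so every sequence ends at the same position
            return t.pad([[maxLength - trimmed.length, 0]], padTokenId);
        });
    }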
package/dist/dataset/data/tabular_data.d.ts CHANGED
@@ -1,11 +1,11 @@
+import * as tf from '@tensorflow/tfjs';
 import type { Task } from '../../index.js';
-import type { Dataset } from '../dataset.js';
 import { Data } from './data.js';
 /**
  * Disco data made of tabular (.csv, .tsv, etc.) files.
  */
 export declare class TabularData extends Data {
     readonly availablePreprocessing: import("immutable").List<import("./preprocessing/base.js").PreprocessingFunction>;
-    static init(dataset: Dataset, task: Task, size?: number): Promise<TabularData>;
-    protected create(dataset: Dataset, task: Task, size: number): TabularData;
+    static init(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number): Promise<TabularData>;
+    protected create(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size: number): TabularData;
 }
package/dist/dataset/data/text_data.d.ts CHANGED
@@ -1,11 +1,11 @@
+import * as tf from '@tensorflow/tfjs';
 import type { Task } from '../../index.js';
-import type { Dataset } from '../dataset.js';
 import { Data } from './data.js';
 /**
  * Disco data made of textual samples.
  */
 export declare class TextData extends Data {
     readonly availablePreprocessing: import("immutable").List<import("./preprocessing/base.js").PreprocessingFunction>;
-    static init(dataset: Dataset, task: Task, size?: number): Promise<TextData>;
-    protected create(dataset: Dataset, task: Task, size?: number): TextData;
+    static init(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number): Promise<TextData>;
+    protected create(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number): TextData;
 }
package/dist/dataset/dataset.d.ts CHANGED
@@ -1,5 +1,48 @@
-import type tf from '@tensorflow/tfjs';
-/**
- * Convenient type for the common dataset type used in TF.js.
- */
-export type Dataset = tf.data.Dataset<tf.TensorContainer>;
+import { List } from "immutable";
+type DatasetLike<T> = AsyncIterable<T> | Iterable<T> | (() => AsyncIterator<T, void>) | (() => Iterator<T, void>);
+/** Immutable series of data */
+export declare class Dataset<T> implements AsyncIterable<T> {
+    #private;
+    /** Wrap given data generator
+     *
+     * To avoid loading everything in memory, it is a function that upon calling
+     * should return a new AsyncGenerator with the same data as before.
+     */
+    constructor(content: DatasetLike<T>);
+    [Symbol.asyncIterator](): AsyncIterator<T>;
+    /** Apply function to each element
+     *
+     * @param mapper how to change each element
+     */
+    map<U>(mapper: (_: T) => U | Promise<U>): Dataset<U>;
+    /** Combine with another Dataset.
+     *
+     * @param other what to yield after us
+     */
+    chain(other: Dataset<T> | DatasetLike<T>): Dataset<T>;
+    /** Divide into two based on given ratio
+     *
+     * @param ratio between 0 (all on left) and 1 (all on right)
+     */
+    split(ratio: number): [Dataset<T>, Dataset<T>];
+    /** Slice into chunks
+     *
+     * Last slice is smaller if dataset isn't perfectly divisible
+     *
+     * @param size count of element per chunk
+     */
+    batch(size: number): Dataset<List<T>>;
+    /** Join side-by-side
+     *
+     * Stops as soon as one runs out
+     *
+     * @param other right side
+     **/
+    zip<U>(other: Dataset<U> | DatasetLike<U>): Dataset<[T, U]>;
+    /** Compute size
+     *
+     * This is a costly operation as we need to go through the whole Dataset.
+     */
+    size(): Promise<number>;
+}
+export {};
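A quick tour of this new Dataset API, assuming the root re-export added later in this diff:

    import { Dataset } from "@epfml/discojs";

    const numbers = new Dataset([1, 2, 3, 4, 5]);
    const doubled = numbers.map((n) => n * 2); // lazy, re-iterable
    const [train, valid] = doubled.split(0.2); // right side gets ~20%
    const batches = train.batch(2); // Dataset<List<number>>
    const pairs = numbers.zip(doubled); // Dataset<[number, number]>
    console.log(await numbers.size()); // 5 (iterates the whole dataset)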
@@ -1 +1,155 @@
1
- export {};
1
+ import { List } from "immutable";
2
+ /** Immutable series of data */
3
+ export class Dataset {
4
+ #content;
5
+ /** Wrap given data generator
6
+ *
7
+ * To avoid loading everything in memory, it is a function that upon calling
8
+ * should return a new AsyncGenerator with the same data as before.
9
+ */
10
+ constructor(content) {
11
+ this.#content = async function* () {
12
+ let iter;
13
+ if (typeof content === "function")
14
+ iter = content();
15
+ else if (Symbol.asyncIterator in content)
16
+ iter = content[Symbol.asyncIterator]();
17
+ else
18
+ iter = content[Symbol.iterator]();
19
+ while (true) {
20
+ const result = await iter.next();
21
+ if (result.done === true)
22
+ break;
23
+ yield result.value;
24
+ }
25
+ };
26
+ }
27
+ [Symbol.asyncIterator]() {
28
+ return this.#content();
29
+ }
30
+ /** Apply function to each element
31
+ *
32
+ * @param mapper how to change each element
33
+ */
34
+ map(mapper) {
35
+ const content = {
36
+ [Symbol.asyncIterator]: () => this.#content(),
37
+ };
38
+ return new Dataset(async function* () {
39
+ for await (const e of content)
40
+ yield await mapper(e);
41
+ });
42
+ }
43
+ /** Combine with another Dataset.
44
+ *
45
+ * @param other what to yield after us
46
+ */
47
+ chain(other) {
48
+ if (!(other instanceof Dataset))
49
+ other = new Dataset(other);
50
+ const self = {
51
+ [Symbol.asyncIterator]: () => this.#content(),
52
+ };
53
+ return new Dataset(async function* () {
54
+ yield* self;
55
+ yield* other;
56
+ });
57
+ }
58
+ /** Divide into two based on given ratio
59
+ *
60
+ * @param ratio between 0 (all on left) and 1 (all on right)
61
+ */
62
+ split(ratio) {
63
+ if (ratio < 0 || ratio > 1)
64
+ throw new Error("ratio out of range");
65
+ const content = {
66
+ [Symbol.asyncIterator]: () => this.#content(),
67
+ };
68
+ // to avoid using random sampling or knowing the size beforehand,
69
+ // we compute the actual ratio and make it converge towards the wanted one
70
+ return [
71
+ new Dataset(async function* () {
72
+ let yielded_by_other = 0;
73
+ let total_size = 0;
74
+ for await (const e of content) {
75
+ total_size++;
76
+ if (yielded_by_other / total_size >= ratio) {
77
+ yield e;
78
+ }
79
+ else {
80
+ yielded_by_other++;
81
+ }
82
+ }
83
+ }),
84
+ new Dataset(async function* () {
85
+ let yielded = 0;
86
+ let total_size = 0;
87
+ for await (const e of content) {
88
+ total_size++;
89
+ if (yielded / total_size < ratio) {
90
+ yielded++;
91
+ yield e;
92
+ }
93
+ }
94
+ }),
95
+ ];
96
+ }
97
+ /** Slice into chunks
98
+ *
99
+ * Last slice is smaller if dataset isn't perfectly divisible
100
+ *
101
+ * @param size count of element per chunk
102
+ */
103
+ batch(size) {
104
+ if (size <= 0 || !Number.isInteger(size))
105
+ throw new Error("invalid size");
106
+ const content = {
107
+ [Symbol.asyncIterator]: () => this.#content(),
108
+ };
109
+ return new Dataset(async function* () {
110
+ let batch = List();
111
+ for await (const e of content) {
112
+ batch = batch.push(e);
113
+ if (batch.size === size) {
114
+ yield batch;
115
+ batch = List();
116
+ }
117
+ }
118
+ if (!batch.isEmpty())
119
+ yield batch;
120
+ });
121
+ }
122
+ /** Join side-by-side
123
+ *
124
+ * Stops as soon as one runs out
125
+ *
126
+ * @param other right side
127
+ **/
128
+ zip(other) {
129
+ if (!(other instanceof Dataset))
130
+ other = new Dataset(other);
131
+ const content = {
132
+ [Symbol.asyncIterator]: () => this.#content(),
133
+ };
134
+ return new Dataset(async function* () {
135
+ const left = content[Symbol.asyncIterator]();
136
+ const right = other[Symbol.asyncIterator]();
137
+ while (true) {
138
+ const [l, r] = await Promise.all([left.next(), right.next()]);
139
+ if (l.done || r.done)
140
+ return;
141
+ yield [l.value, r.value];
142
+ }
143
+ });
144
+ }
145
+ /** Compute size
146
+ *
147
+ * This is a costly operation as we need to go through the whole Dataset.
148
+ */
149
+ async size() {
150
+ let ret = 0;
151
+ for await (const _ of this)
152
+ ret++;
153
+ return ret;
154
+ }
155
+ }
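The split implementation never buffers elements: each half re-walks the source and simulates the counter of the other half, so the realized ratio converges to the requested one. A small check of that behavior, assuming the same root import as above:

    import { Dataset } from "@epfml/discojs";

    const [left, right] = new Dataset([1, 2, 3, 4, 5, 6]).split(0.5);
    // left yields [2, 4, 6] and right yields [1, 3, 5]: together they
    // cover the source exactly once, in a deterministic interleaving
    console.log(await left.size(), await right.size()); // 3 3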
@@ -0,0 +1,14 @@
1
+ /**
2
+ * Raw image with type level dimensions.
3
+ *
4
+ * @typeParam D depth of the image
5
+ * @typeParam W width, positive and integral
6
+ * @typeParam H height, positive and integral
7
+ */
8
+ export declare class Image<D extends 1 | 3 | 4 = 1 | 3 | 4, W extends number = number, H extends number = number> {
9
+ readonly data: Readonly<Uint8Array>;
10
+ readonly width: W;
11
+ readonly height: H;
12
+ readonly depth: D;
13
+ constructor(data: Readonly<Uint8Array>, width: W, height: H, depth: D);
14
+ }
@@ -0,0 +1,21 @@
1
+ /**
2
+ * Raw image with type level dimensions.
3
+ *
4
+ * @typeParam D depth of the image
5
+ * @typeParam W width, positive and integral
6
+ * @typeParam H height, positive and integral
7
+ */
8
+ export class Image {
9
+ data;
10
+ width;
11
+ height;
12
+ depth;
13
+ constructor(data, width, height, depth) {
14
+ this.data = data;
15
+ this.width = width;
16
+ this.height = height;
17
+ this.depth = depth;
18
+ if (data.length != width * height * depth)
19
+ throw new Error("data isn't of excepted size");
20
+ }
21
+ }
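Constructing an Image requires data.length === width * height * depth; for instance a 2x2 RGB image (depth 3) needs 12 bytes. A sketch, assuming the root re-export shown below:

    import { Image } from "@epfml/discojs";

    const pixels = new Uint8Array(2 * 2 * 3); // 12 zeroed bytes, all black
    const img = new Image(pixels, 2, 2, 3);
    // new Image(new Uint8Array(5), 2, 2, 3) would throw: 5 !== 12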
@@ -1,5 +1,3 @@
1
- export type { Dataset } from './dataset.js';
2
- export { DatasetBuilder } from './dataset_builder.js';
3
- export { ImageLoader, TabularLoader, DataLoader, TextLoader } from './data_loader/index.js';
4
- export type { DataSplit } from './data/index.js';
5
- export { Data, TabularData, ImageData, TextData, ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING } from './data/index.js';
1
+ export { Dataset } from "./dataset.js";
2
+ export * from "./types.js";
3
+ export { Data, TabularData, ImageData, TextData, ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING, } from "./data/index.js";
@@ -1,3 +1,3 @@
1
- export { DatasetBuilder } from './dataset_builder.js';
2
- export { ImageLoader, TabularLoader, DataLoader, TextLoader } from './data_loader/index.js';
3
- export { Data, TabularData, ImageData, TextData, ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING } from './data/index.js';
1
+ export { Dataset } from "./dataset.js";
2
+ export * from "./types.js";
3
+ export { Data, TabularData, ImageData, TextData, ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING, } from "./data/index.js";
@@ -0,0 +1,4 @@
1
+ import { Image } from "./image.js";
2
+ export { Image };
3
+ export type Tabular = Partial<Record<string, string>>;
4
+ export type Text = string;
@@ -0,0 +1,2 @@
1
+ import { Image } from "./image.js";
2
+ export { Image };
package/dist/index.d.ts CHANGED
@@ -14,3 +14,7 @@ export * as models from './models/index.js';
 export * from './task/index.js';
 export * as defaultTasks from './default_tasks/index.js';
 export * as async_iterator from "./utils/async_iterator.js";
+export { Dataset } from "./dataset/index.js";
+export * from "./dataset/types.js";
+export * from "./types.js";
+export * as processing from "./processing.js";
package/dist/index.js CHANGED
@@ -14,3 +14,7 @@ export * as models from './models/index.js';
 export * from './task/index.js';
 export * as defaultTasks from './default_tasks/index.js';
 export * as async_iterator from "./utils/async_iterator.js";
+export { Dataset } from "./dataset/index.js";
+export * from "./dataset/types.js"; // TODO merge with above
+export * from "./types.js";
+export * as processing from "./processing.js";
package/dist/models/gpt/model.js CHANGED
@@ -30,6 +30,8 @@ class GPTModel extends tf.LayersModel {
        return this.config;
    }
    compile() {
+        if (this.optimizer !== undefined)
+            return;
        this.optimizer = this.config.weightDecay !== 0
            ? getCustomAdam(this, this.config.lr, this.config.weightDecay)
            : tf.train.adam(this.config.lr);
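The added guard makes compile() idempotent: repeated calls keep the existing optimizer (and its accumulated state) instead of allocating a new one. The pattern in isolation, as a hedged sketch rather than the package's code:

    class Compilable {
        private optimizer?: object;
        compile() {
            if (this.optimizer !== undefined) return; // already compiled: keep state
            this.optimizer = {}; // stand-in for the real optimizer construction
        }
    }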
package/dist/models/model.d.ts CHANGED
@@ -1,6 +1,5 @@
 import type tf from "@tensorflow/tfjs";
 import type { WeightsContainer } from "../index.js";
-import type { Dataset } from "../dataset/index.js";
 import type { BatchLogs, EpochLogs } from "./logs.js";
 export type Prediction = tf.Tensor;
 export type Sample = tf.Tensor;
@@ -23,7 +22,7 @@ export declare abstract class Model implements Disposable {
     * @param tracker watch the various steps
     * @yields on every epoch, training can be stop by `return`ing it
     */
-    abstract train(trainingData: Dataset, validationData?: Dataset): AsyncGenerator<BatchLogs, EpochLogs>;
+    abstract train(trainingData: tf.data.Dataset<tf.TensorContainer>, validationData?: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<BatchLogs, EpochLogs>;
    /** Predict likely values */
    abstract predict(input: Sample): Promise<Prediction>;
    /**