@epfml/discojs 3.0.1-p20241024094708.0 → 3.0.1-p20241028120035.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. package/dist/aggregator/get.d.ts +3 -3
  2. package/dist/client/client.d.ts +5 -5
  3. package/dist/client/decentralized/decentralized_client.d.ts +2 -2
  4. package/dist/client/federated/federated_client.d.ts +2 -2
  5. package/dist/client/utils.d.ts +2 -2
  6. package/dist/dataset/dataset.d.ts +9 -2
  7. package/dist/dataset/dataset.js +83 -36
  8. package/dist/dataset/image.d.ts +5 -0
  9. package/dist/dataset/image.js +6 -1
  10. package/dist/dataset/index.d.ts +0 -1
  11. package/dist/dataset/index.js +0 -1
  12. package/dist/dataset/types.d.ts +2 -0
  13. package/dist/default_tasks/cifar10.d.ts +1 -1
  14. package/dist/default_tasks/cifar10.js +2 -3
  15. package/dist/default_tasks/lus_covid.d.ts +1 -1
  16. package/dist/default_tasks/lus_covid.js +2 -3
  17. package/dist/default_tasks/mnist.d.ts +1 -1
  18. package/dist/default_tasks/mnist.js +2 -4
  19. package/dist/default_tasks/simple_face.d.ts +1 -1
  20. package/dist/default_tasks/simple_face.js +2 -3
  21. package/dist/default_tasks/titanic.d.ts +1 -1
  22. package/dist/default_tasks/titanic.js +3 -6
  23. package/dist/default_tasks/wikitext.d.ts +1 -1
  24. package/dist/default_tasks/wikitext.js +1 -2
  25. package/dist/index.d.ts +4 -5
  26. package/dist/index.js +4 -5
  27. package/dist/models/gpt/index.d.ts +13 -16
  28. package/dist/models/gpt/index.js +62 -43
  29. package/dist/models/gpt/model.d.ts +1 -15
  30. package/dist/models/gpt/model.js +1 -75
  31. package/dist/models/model.d.ts +7 -12
  32. package/dist/models/tfjs.d.ts +10 -8
  33. package/dist/models/tfjs.js +106 -44
  34. package/dist/models/tokenizer.d.ts +1 -1
  35. package/dist/privacy.js +1 -1
  36. package/dist/processing/image.d.ts +18 -0
  37. package/dist/processing/image.js +75 -0
  38. package/dist/processing/index.d.ts +8 -0
  39. package/dist/processing/index.js +106 -0
  40. package/dist/processing/tabular.d.ts +19 -0
  41. package/dist/processing/tabular.js +33 -0
  42. package/dist/processing/text.d.ts +11 -0
  43. package/dist/processing/text.js +33 -0
  44. package/dist/serialization/model.d.ts +3 -3
  45. package/dist/serialization/model.js +19 -6
  46. package/dist/task/task.d.ts +4 -3
  47. package/dist/task/task.js +5 -3
  48. package/dist/task/task_handler.d.ts +3 -3
  49. package/dist/task/task_provider.d.ts +4 -4
  50. package/dist/task/training_information.d.ts +25 -16
  51. package/dist/task/training_information.js +76 -72
  52. package/dist/training/disco.d.ts +20 -12
  53. package/dist/training/disco.js +32 -13
  54. package/dist/training/trainer.d.ts +6 -7
  55. package/dist/training/trainer.js +6 -6
  56. package/dist/types/data_format.d.ts +40 -0
  57. package/dist/types/index.d.ts +2 -0
  58. package/dist/types/index.js +1 -0
  59. package/dist/validator.d.ts +10 -0
  60. package/dist/validator.js +30 -0
  61. package/package.json +4 -2
  62. package/dist/dataset/data/data.d.ts +0 -47
  63. package/dist/dataset/data/data.js +0 -88
  64. package/dist/dataset/data/data_split.d.ts +0 -8
  65. package/dist/dataset/data/helpers.d.ts +0 -10
  66. package/dist/dataset/data/helpers.js +0 -97
  67. package/dist/dataset/data/image_data.d.ts +0 -11
  68. package/dist/dataset/data/image_data.js +0 -43
  69. package/dist/dataset/data/index.d.ts +0 -5
  70. package/dist/dataset/data/index.js +0 -5
  71. package/dist/dataset/data/preprocessing/base.d.ts +0 -16
  72. package/dist/dataset/data/preprocessing/base.js +0 -1
  73. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +0 -13
  74. package/dist/dataset/data/preprocessing/image_preprocessing.js +0 -42
  75. package/dist/dataset/data/preprocessing/index.d.ts +0 -4
  76. package/dist/dataset/data/preprocessing/index.js +0 -3
  77. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +0 -13
  78. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +0 -45
  79. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +0 -13
  80. package/dist/dataset/data/preprocessing/text_preprocessing.js +0 -100
  81. package/dist/dataset/data/tabular_data.d.ts +0 -11
  82. package/dist/dataset/data/tabular_data.js +0 -24
  83. package/dist/dataset/data/text_data.d.ts +0 -11
  84. package/dist/dataset/data/text_data.js +0 -14
  85. package/dist/processing.d.ts +0 -35
  86. package/dist/processing.js +0 -89
  87. package/dist/types.d.ts +0 -3
  88. package/dist/types.js +0 -1
  89. package/dist/validation/index.d.ts +0 -1
  90. package/dist/validation/index.js +0 -1
  91. package/dist/validation/validator.d.ts +0 -10
  92. package/dist/validation/validator.js +0 -113
  93. /package/dist/{dataset/data/data_split.js → types/data_format.js} +0 -0
package/dist/dataset/data/image_data.js DELETED
@@ -1,43 +0,0 @@
- import * as tf from '@tensorflow/tfjs';
- import { Data } from './data.js';
- import { ImagePreprocessing, IMAGE_PREPROCESSING } from './preprocessing/index.js';
- /**
- * Disco data made of image samples (.jpg, .png, etc.).
- */
- export class ImageData extends Data {
- availablePreprocessing = IMAGE_PREPROCESSING;
- static async init(dataset, task, size) {
- // Here we do our best to check data format before proceeding to training, for
- // better error handling. An incorrectly formatted image in the dataset might still
- // cause an error during training, because of the lazy aspect of the dataset; we only
- // verify the first sample.
- if (task.trainingInformation.preprocessingFunctions?.includes(ImagePreprocessing.Resize) !== true) {
- const iteration = await dataset.iterator().then((iter) => iter.next());
- if (iteration.done === true)
- throw new Error("empty dataset");
- const sample = iteration.value;
- // TODO: We suppose the presence of labels
- // TODO: Typing (discojs-node/src/dataset/data_loader/image_loader.spec.ts)
- if (typeof sample !== 'object' || sample === null || sample === undefined) {
- throw new Error("Image is undefined or is not an object");
- }
- let shape;
- if ('xs' in sample) {
- shape = sample.xs.shape;
- }
- else {
- shape = sample.shape;
- }
- const { IMAGE_H, IMAGE_W } = task.trainingInformation;
- if (IMAGE_W !== undefined && IMAGE_H !== undefined &&
- (shape[0] !== IMAGE_W || shape[1] !== IMAGE_H)) {
- throw new Error(`Image doesn't have the dimensions specified in the task's training information. Expected ${IMAGE_H}x${IMAGE_W} but got ${shape[0]}x${shape[1]}.`);
- }
- tf.dispose(sample);
- }
- return new ImageData(dataset, task, size);
- }
- create(dataset, task, size) {
- return new ImageData(dataset, task, size);
- }
- }
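
The comments in the removed init() above describe a general pattern: tf.data pipelines are lazy, so a malformed sample would otherwise only surface mid-training. A minimal standalone sketch of that fail-fast check (hypothetical helper, not part of this package):

async function checkFirstSample<T>(
  samples: AsyncIterable<T>,
  check: (sample: T) => void, // throws on malformed input
): Promise<void> {
  for await (const sample of samples) {
    check(sample); // fail fast here instead of mid-training
    return; // only the first sample is verified, as in the code above
  }
  throw new Error("empty dataset");
}
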
package/dist/dataset/data/index.d.ts DELETED
@@ -1,5 +0,0 @@
- export { Data } from './data.js';
- export { ImageData } from './image_data.js';
- export { TabularData } from './tabular_data.js';
- export { TextData } from './text_data.js';
- export { ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING } from './preprocessing/index.js';
package/dist/dataset/data/index.js DELETED
@@ -1,5 +0,0 @@
- export { Data } from './data.js';
- export { ImageData } from './image_data.js';
- export { TabularData } from './tabular_data.js';
- export { TextData } from './text_data.js';
- export { ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING } from './preprocessing/index.js';
package/dist/dataset/data/preprocessing/base.d.ts DELETED
@@ -1,16 +0,0 @@
- import type tf from '@tensorflow/tfjs';
- import type { Task } from '../../../index.js';
- import type { ImagePreprocessing } from './image_preprocessing.js';
- import type { TabularPreprocessing } from './tabular_preprocessing.js';
- import type { TextPreprocessing } from './text_preprocessing.js';
- /**
- * All available preprocessing type enums.
- */
- export type Preprocessing = ImagePreprocessing | TextPreprocessing | TabularPreprocessing;
- /**
- * Preprocessing function associating a preprocessing type enum to a sample transformation.
- */
- export interface PreprocessingFunction {
- type: Preprocessing;
- apply: (x: Promise<tf.TensorContainer>, task: Task) => Promise<tf.TensorContainer>;
- }
package/dist/dataset/data/preprocessing/base.js DELETED
@@ -1 +0,0 @@
- export {};
package/dist/dataset/data/preprocessing/image_preprocessing.d.ts DELETED
@@ -1,13 +0,0 @@
- import { List } from 'immutable';
- import type { PreprocessingFunction } from './base.js';
- /**
- * Available image preprocessing types.
- */
- export declare enum ImagePreprocessing {
- Resize = 0,
- Normalize = 1
- }
- /**
- * Available image preprocessing functions.
- */
- export declare const AVAILABLE_PREPROCESSING: List<PreprocessingFunction>;
package/dist/dataset/data/preprocessing/image_preprocessing.js DELETED
@@ -1,42 +0,0 @@
- import { List } from 'immutable';
- import * as tf from '@tensorflow/tfjs';
- /**
- * Available image preprocessing types.
- */
- export var ImagePreprocessing;
- (function (ImagePreprocessing) {
- ImagePreprocessing[ImagePreprocessing["Resize"] = 0] = "Resize";
- ImagePreprocessing[ImagePreprocessing["Normalize"] = 1] = "Normalize";
- })(ImagePreprocessing || (ImagePreprocessing = {}));
- const resize = {
- type: ImagePreprocessing.Resize,
- apply: async (entry, task) => {
- const { xs, ys } = await entry;
- const params = task.trainingInformation;
- return {
- xs: params.IMAGE_W !== undefined && params.IMAGE_H !== undefined
- ? xs.resizeBilinear([params.IMAGE_H, params.IMAGE_W])
- : xs,
- ys
- };
- }
- };
- const normalize = {
- type: ImagePreprocessing.Normalize,
- apply: async (entry) => {
- const { xs, ys } = await entry;
- return tf.tidy(() => {
- return {
- xs: xs.div(tf.scalar(255)),
- ys
- };
- });
- }
- };
- /**
- * Available image preprocessing functions.
- */
- export const AVAILABLE_PREPROCESSING = List([
- resize,
- normalize
- ]).sortBy((e) => e.type);
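
For reference, the two steps this removed file implemented, bilinear resizing and scaling pixels from [0, 255] into [0, 1], can be reproduced directly with @tensorflow/tfjs. A minimal sketch on an invented 2x2 single-channel input:

import * as tf from '@tensorflow/tfjs';

const image = tf.tensor3d([0, 64, 128, 255], [2, 2, 1]); // values in [0, 255]
const resized = tf.image.resizeBilinear(image, [32, 32]); // the Resize step
const normalized = resized.div(tf.scalar(255)); // the Normalize step: values now in [0, 1]
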
package/dist/dataset/data/preprocessing/index.d.ts DELETED
@@ -1,4 +0,0 @@
- export type { Preprocessing, PreprocessingFunction } from './base.js';
- export { AVAILABLE_PREPROCESSING as IMAGE_PREPROCESSING, ImagePreprocessing } from './image_preprocessing.js';
- export { AVAILABLE_PREPROCESSING as TABULAR_PREPROCESSING, TabularPreprocessing } from './tabular_preprocessing.js';
- export { AVAILABLE_PREPROCESSING as TEXT_PREPROCESSING, TextPreprocessing } from './text_preprocessing.js';
package/dist/dataset/data/preprocessing/index.js DELETED
@@ -1,3 +0,0 @@
- export { AVAILABLE_PREPROCESSING as IMAGE_PREPROCESSING, ImagePreprocessing } from './image_preprocessing.js';
- export { AVAILABLE_PREPROCESSING as TABULAR_PREPROCESSING, TabularPreprocessing } from './tabular_preprocessing.js';
- export { AVAILABLE_PREPROCESSING as TEXT_PREPROCESSING, TextPreprocessing } from './text_preprocessing.js';
package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts DELETED
@@ -1,13 +0,0 @@
- import { List } from 'immutable';
- import type { PreprocessingFunction } from './base.js';
- /**
- * Available tabular preprocessing types.
- */
- export declare enum TabularPreprocessing {
- Sanitize = 0,
- Normalize = 1
- }
- /**
- * Available tabular preprocessing functions.
- */
- export declare const AVAILABLE_PREPROCESSING: List<PreprocessingFunction>;
package/dist/dataset/data/preprocessing/tabular_preprocessing.js DELETED
@@ -1,45 +0,0 @@
- import { List } from 'immutable';
- /**
- * Available tabular preprocessing types.
- */
- export var TabularPreprocessing;
- (function (TabularPreprocessing) {
- TabularPreprocessing[TabularPreprocessing["Sanitize"] = 0] = "Sanitize";
- TabularPreprocessing[TabularPreprocessing["Normalize"] = 1] = "Normalize";
- })(TabularPreprocessing || (TabularPreprocessing = {}));
- const sanitize = {
- type: TabularPreprocessing.Sanitize,
- apply: async (entry) => {
- const entryContainer = await entry;
- // if preprocessing a dataset without labels, then the entry is an array of numbers
- if (Array.isArray(entryContainer)) {
- const entry = entryContainer;
- return entry.map((i) => i ?? 0);
- // if it is an object
- }
- else if (typeof entryContainer === 'object' && entry !== null) {
- // if the object is a tensor container with features xs and labels ys
- if (Object.hasOwn(entryContainer, 'xs')) {
- const { xs, ys } = entryContainer;
- return {
- xs: xs.map(i => i ?? 0),
- ys
- };
- // if the object contains features as a dict of feature names-values
- }
- else {
- const entry = Object.values(entryContainer);
- return entry.map((i) => i ?? 0);
- }
- }
- else {
- throw new Error('Unrecognized format during tabular preprocessing');
- }
- }
- };
- /**
- * Available tabular preprocessing functions.
- */
- export const AVAILABLE_PREPROCESSING = List([
- sanitize
- ]).sortBy((e) => e.type);
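
Whatever the entry shape, the Sanitize step above reduces to one rule: replace missing feature values with 0. The same rule on a plain array (hypothetical helper name):

function sanitizeRow(row: (number | null | undefined)[]): number[] {
  return row.map((value) => value ?? 0); // nullish values become 0
}

sanitizeRow([1.5, null, undefined, 3]); // [1.5, 0, 0, 3]
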
package/dist/dataset/data/preprocessing/text_preprocessing.d.ts DELETED
@@ -1,13 +0,0 @@
- import { List } from 'immutable';
- import type { PreprocessingFunction } from './base.js';
- /**
- * Available text preprocessing types.
- */
- export declare enum TextPreprocessing {
- Tokenize = 0,
- LeftPadding = 1
- }
- /**
- * Available text preprocessing functions.
- */
- export declare const AVAILABLE_PREPROCESSING: List<PreprocessingFunction>;
package/dist/dataset/data/preprocessing/text_preprocessing.js DELETED
@@ -1,100 +0,0 @@
- import { List } from 'immutable';
- import * as tf from '@tensorflow/tfjs';
- import { models } from '../../../index.js';
- /**
- * Available text preprocessing types.
- */
- export var TextPreprocessing;
- (function (TextPreprocessing) {
- TextPreprocessing[TextPreprocessing["Tokenize"] = 0] = "Tokenize";
- TextPreprocessing[TextPreprocessing["LeftPadding"] = 1] = "LeftPadding";
- })(TextPreprocessing || (TextPreprocessing = {}));
- function isNumberArray(raw) {
- if (!Array.isArray(raw))
- return false;
- const arr = raw; // isArray is unsafely guarding with any[]
- return arr.every((e) => typeof e === "number");
- }
- function isTokenizedEntry(raw) {
- if (typeof raw !== "object" || raw === null)
- return false;
- const { tokens } = raw;
- if (!isNumberArray(tokens))
- return false;
- const _ = { tokens };
- return true;
- }
- /**
- * LeftPadding pads all incoming inputs to be a fixed length, which should be specified
- * in `task.trainingInformation.maxSequenceLength`.
- *
- * We are currently only implementing left padding for text generation
- * https://huggingface.co/docs/transformers/en/llm_tutorial#wrong-padding-side
- * The function can easily be extended to support right padding if needed
- *
- * Once Transformers.js supports left padding, it will be possible to pad inputs
- * directly when tokenizing
- * https://github.com/xenova/transformers.js/blob/8804c36591d11d8456788d1bb4b16489121b3be2/src/tokenizers.js#L2517
- */
- const leftPadding = {
- type: TextPreprocessing.LeftPadding,
- apply: async (input, task) => {
- const x = await input;
- if (!isTokenizedEntry(x))
- throw new Error("The leftPadding preprocessing expects a non empty 1D array of number");
- const { tokens } = x;
- const tokenizer = await models.getTaskTokenizer(task);
- return tf.tidy(() => {
- // maxLength is the final length of xs
- // Because ys the contains the tokens in xs shifted by one (to predict the next token), we need
- // to include one more token than maxSequenceLength in order to have the next token's label of the maxSequenceLength'th token
- const maxLength = task.trainingInformation.maxSequenceLength ?? tokenizer.model_max_length;
- const maxLengthPlusLabel = maxLength + 1;
- let fixedLengthTokens = tf.tensor1d(tokens, 'int32'); // cast tokens from float to int for gpt-tfjs
- if (fixedLengthTokens.size > maxLengthPlusLabel) { // Should never happen because tokenization truncates inputs
- throw Error("There are more tokens than expected after tokenization and truncation");
- }
- else if (fixedLengthTokens.size < maxLengthPlusLabel) { // Pad inputs to fixed length
- const paddingToken = tokenizer.pad_token_id;
- fixedLengthTokens = fixedLengthTokens.pad([[Math.max(0, maxLengthPlusLabel - fixedLengthTokens.size), 0]], paddingToken);
- }
- // if tokens.size == maxLengthPlusLabel we can leave it as it is
- // ys is a one-hot encoding of the next token (i.e. xs shifted by one)
- // cast because oneHot isn't size-typing its return value
- const ys = tf.oneHot(fixedLengthTokens.slice([1]), tokenizer.model.vocab.length + 1);
- // remove the extra token now that ys is created
- const xs = fixedLengthTokens.slice([0], maxLength);
- return { xs, ys };
- });
- }
- };
- /**
- * Tokenize and truncates input strings
- */
- const tokenize = {
- type: TextPreprocessing.Tokenize,
- apply: async (x, task) => {
- const xs = await x;
- if (typeof xs !== 'string')
- throw new Error("The tokenize preprocessing expects a string as input");
- const tokenizer = await models.getTaskTokenizer(task);
- // Add plus one to include the next token label of the last token in the input sequence
- // The inputs are truncated down to exactly maxSequenceLength in leftPadding
- const maxLength = task.trainingInformation.maxSequenceLength ?? tokenizer.model_max_length;
- const maxLengthPlusLabel = maxLength + 1;
- const { input_ids: tokens } = tokenizer(xs, {
- // Transformers.js currently only supports right padding while we need left for text generation
- // Right padding should be supported in the future, once it is, we can directly pad while tokenizing
- // https://github.com/xenova/transformers.js/blob/8804c36591d11d8456788d1bb4b16489121b3be2/src/tokenizers.js#L2517
- padding: false,
- truncation: true,
- return_tensor: false,
- max_length: maxLengthPlusLabel,
- });
- return { tokens };
- }
- };
- /**
- * Available text preprocessing functions.
- */
- export const AVAILABLE_PREPROCESSING = List.of(tokenize, leftPadding).sortBy((e) => e.type);
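
The padding arithmetic above keeps maxSequenceLength + 1 tokens so the last input token still has a next-token label, then splits the sequence into xs and a ys shifted by one. A plain-array sketch of that scheme (hypothetical helper; ys stays as token ids here rather than the one-hot encoding the removed code builds):

function leftPadWithLabels(
  tokens: number[],
  maxLength: number,
  padTokenId: number,
): { xs: number[]; ys: number[] } {
  const withLabel = maxLength + 1; // one extra token to label the last input
  if (tokens.length > withLabel)
    throw new Error("more tokens than expected after truncation");
  const padded = new Array(withLabel - tokens.length)
    .fill(padTokenId)
    .concat(tokens); // pad on the left, as text generation requires
  return { xs: padded.slice(0, maxLength), ys: padded.slice(1) };
}

leftPadWithLabels([7, 8, 9], 4, 0); // { xs: [0, 0, 7, 8], ys: [0, 7, 8, 9] }
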
package/dist/dataset/data/tabular_data.d.ts DELETED
@@ -1,11 +0,0 @@
- import * as tf from '@tensorflow/tfjs';
- import type { Task } from '../../index.js';
- import { Data } from './data.js';
- /**
- * Disco data made of tabular (.csv, .tsv, etc.) files.
- */
- export declare class TabularData extends Data {
- readonly availablePreprocessing: import("immutable").List<import("./preprocessing/base.js").PreprocessingFunction>;
- static init(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number): Promise<TabularData>;
- protected create(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size: number): TabularData;
- }
package/dist/dataset/data/tabular_data.js DELETED
@@ -1,24 +0,0 @@
- import { Data } from './data.js';
- import { TABULAR_PREPROCESSING } from './preprocessing/index.js';
- /**
- * Disco data made of tabular (.csv, .tsv, etc.) files.
- */
- export class TabularData extends Data {
- availablePreprocessing = TABULAR_PREPROCESSING;
- static async init(dataset, task, size) {
- // Force the check of the data column format (among other things) before proceeding
- // to training, for better error handling. An incorrectly formatted line might still
- // cause an error during training, because of the lazy aspect of the dataset; we only
- // load/read the tabular file's lines on training.
- try {
- await dataset.iterator();
- }
- catch (cause) {
- throw new Error('data input format not compatible with chosen task', { cause });
- }
- return new TabularData(dataset, task, size);
- }
- create(dataset, task, size) {
- return new TabularData(dataset, task, size);
- }
- }
package/dist/dataset/data/text_data.d.ts DELETED
@@ -1,11 +0,0 @@
- import * as tf from '@tensorflow/tfjs';
- import type { Task } from '../../index.js';
- import { Data } from './data.js';
- /**
- * Disco data made of textual samples.
- */
- export declare class TextData extends Data {
- readonly availablePreprocessing: import("immutable").List<import("./preprocessing/base.js").PreprocessingFunction>;
- static init(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number): Promise<TextData>;
- protected create(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number): TextData;
- }
package/dist/dataset/data/text_data.js DELETED
@@ -1,14 +0,0 @@
- import { Data } from './data.js';
- import { TEXT_PREPROCESSING } from './preprocessing/index.js';
- /**
- * Disco data made of textual samples.
- */
- export class TextData extends Data {
- availablePreprocessing = TEXT_PREPROCESSING;
- static init(dataset, task, size) {
- return Promise.resolve(new TextData(dataset, task, size));
- }
- create(dataset, task, size) {
- return new TextData(dataset, task, size);
- }
- }
package/dist/processing.d.ts DELETED
@@ -1,35 +0,0 @@
- /** Dataset shapers, convenient to map with */
- import { PreTrainedTokenizer } from "@xenova/transformers";
- import { List } from "immutable";
- import { Image } from "./dataset/image.js";
- /**
- * Convert a string to a number
- *
- * @throws if it isn't written as a number
- */
- export declare function convertToNumber(raw: string): number;
- /**
- * Return the named field of an object with string values
- *
- * @throws if the named field isn't there
- */
- export declare function extractColumn(row: Partial<Record<string, string>>, column: string): string;
- /**
- * Return the index of the element in the given list
- *
- * @throws if not found
- */
- export declare function indexInList(element: string, elements: List<string>): number;
- /**
- * Tokenize and truncates input strings
- *
- * @param length number of tokens
- * @returns encoded string in an array of token, size of max_length
- */
- export declare function tokenizeAndLeftPad(line: string, tokenizer: PreTrainedTokenizer, length: number): number[];
- /** Remove the alpha channel of an image */
- export declare function removeAlpha<W extends number, H extends number>(image: Image<4, W, H>): Image<3, W, H>;
- export declare function removeAlpha<D extends 1 | 3, W extends number, H extends number>(image: Image<D | 4, W, H>): Image<D, W, H>;
- /** Convert monochrome images to multicolor */
- export declare function expandToMulticolor<W extends number, H extends number>(image: Image<1, W, H>): Image<3, W, H>;
- export declare function expandToMulticolor<D extends 3 | 4, W extends number, H extends number>(image: Image<1 | D, W, H>): Image<D, W, H>;
package/dist/processing.js DELETED
@@ -1,89 +0,0 @@
- /** Dataset shapers, convenient to map with */
- import { Repeat, Seq } from "immutable";
- import { Image } from "./dataset/image.js";
- /**
- * Convert a string to a number
- *
- * @throws if it isn't written as a number
- */
- export function convertToNumber(raw) {
- const num = Number.parseFloat(raw);
- if (Number.isNaN(num))
- throw new Error(`unable to parse "${raw}" as number`);
- return num;
- }
- /**
- * Return the named field of an object with string values
- *
- * @throws if the named field isn't there
- */
- export function extractColumn(row, column) {
- const raw = row[column];
- if (raw === undefined)
- throw new Error(`${column} not found in row`);
- return raw;
- }
- /**
- * Return the index of the element in the given list
- *
- * @throws if not found
- */
- export function indexInList(element, elements) {
- const ret = elements.indexOf(element);
- if (ret === -1)
- throw new Error(`${element} not found in list`);
- return ret;
- }
- function isArrayOfNumber(raw) {
- return Array.isArray(raw) && raw.every((e) => typeof e === "number");
- }
- /**
- * Tokenize and truncates input strings
- *
- * @param length number of tokens
- * @returns encoded string in an array of token, size of max_length
- */
- export function tokenizeAndLeftPad(line, tokenizer, length) {
- if (!Number.isInteger(length))
- throw new Error("length should be an integer");
- // Transformers.js currently only supports right padding while we need left for text generation
- // Right padding should be supported in the future, once it is, we can directly pad while tokenizing
- // https://github.com/xenova/transformers.js/blob/8804c36591d11d8456788d1bb4b16489121b3be2/src/tokenizers.js#L2517
- const tokenized = tokenizer(line, {
- padding: false,
- truncation: true,
- return_tensor: false,
- max_length: length,
- });
- if (typeof tokenized !== "object" ||
- tokenized === null ||
- !("input_ids" in tokenized) ||
- !isArrayOfNumber(tokenized.input_ids))
- throw new Error("tokenizer returns unexcepted type");
- const tokens = tokenized.input_ids;
- const paddingSize = length - tokens.length;
- if (paddingSize < 0)
- throw new Error("tokenized returned more token than excepted");
- const padding = new Array(paddingSize);
- padding.fill(tokenizer.pad_token_id);
- const padded = padding.concat(tokens);
- return padded;
- }
- export function removeAlpha(image) {
- switch (image.depth) {
- case 1:
- case 3:
- return new Image(image.data, image.width, image.height, image.depth);
- case 4:
- return new Image(image.data.filter((_, i) => i % 4 !== 3), image.width, image.height, 3);
- }
- }
- export function expandToMulticolor(image) {
- switch (image.depth) {
- case 1:
- return new Image(Uint8Array.from(Seq(image.data).flatMap((v) => Repeat(v, 3))), image.width, image.height, 3);
- case 3:
- case 4:
- return new Image(image.data, image.width, image.height, image.depth);
- }
- }
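
These helpers presumably live on under the new processing/ directory added in this release (files 36-43 above). A hypothetical usage of the column helpers, assuming their signatures are unchanged (the import path is a guess):

// import path is a guess; these helpers moved during this release
import { convertToNumber, extractColumn } from "@epfml/discojs";

const row: Partial<Record<string, string>> = { PassengerId: "42", Age: "29.5" };
const age = convertToNumber(extractColumn(row, "Age")); // 29.5
extractColumn(row, "Fare"); // throws: "Fare not found in row"
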
package/dist/types.d.ts DELETED
@@ -1,3 +0,0 @@
- import { Dataset, Image, Tabular, Text } from "./dataset/index.js";
- export type TypedDataset = ["image", Dataset<Image>] | ["tabular", Dataset<Tabular>] | ["text", Dataset<Text>];
- export type TypedLabeledDataset = ["image", Dataset<[Image, label: string]>] | ["tabular", Dataset<Tabular>] | ["text", Dataset<Text>];
package/dist/types.js DELETED
@@ -1 +0,0 @@
- export {};
package/dist/validation/index.d.ts DELETED
@@ -1 +0,0 @@
- export { Validator } from './validator.js';
package/dist/validation/index.js DELETED
@@ -1 +0,0 @@
- export { Validator } from './validator.js';
package/dist/validation/validator.d.ts DELETED
@@ -1,10 +0,0 @@
- import type { Model, Task, TypedDataset, TypedLabeledDataset } from "../index.js";
- export declare class Validator {
- #private;
- readonly task: Task;
- constructor(task: Task, model: Model);
- /** infer every line of the dataset and check that it is as labeled */
- test(dataset: TypedLabeledDataset): AsyncGenerator<boolean>;
- /** use the model to predict every line of the dataset */
- infer(dataset: TypedDataset): AsyncGenerator<number, void>;
- }
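
Since test() yields one boolean per sample, accuracy is the ratio of true values. A hypothetical usage sketch of this removed API, assuming a task, a model, and a labeled tabular dataset are already in hand:

const validator = new Validator(task, model); // task and model assumed given
let correct = 0;
let total = 0;
for await (const hit of validator.test(["tabular", labeledDataset])) {
  if (hit) correct++;
  total++;
}
console.log(`accuracy: ${correct / total}`);
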
package/dist/validation/validator.js DELETED
@@ -1,113 +0,0 @@
- import * as tf from "@tensorflow/tfjs";
- import { datasetToData, labeledDatasetToData, } from "../dataset/data/helpers.js";
- function intoTFDataset(iter) {
- // @ts-expect-error generator
- return tf.data.generator(async function* () {
- yield* iter;
- });
- }
- export class Validator {
- task;
- #model;
- constructor(task, model) {
- this.task = task;
- this.#model = model;
- }
- /** infer every line of the dataset and check that it is as labeled */
- async *test(dataset) {
- const preprocessed = (await labeledDatasetToData(this.task, dataset)).preprocess();
- const batched = preprocessed.batch().dataset;
- const iterator = await tf.data
- .zip([
- preprocessed.dataset.map((t) => {
- if (typeof t !== "object" ||
- !("ys" in t) ||
- !(t.ys instanceof tf.Tensor) ||
- !(t.ys.rank === 1 || t.ys.rank === 2))
- throw new Error("unexpected preprocessed dataset");
- if ("xs" in t)
- tf.dispose(t.xs);
- return t.ys;
- }),
- intoTFDataset(this.#inferOnBatchedData(batched)),
- ])
- .iterator();
- for (let iter = await iterator.next(); iter.done !== true; iter = await iterator.next()) {
- const zipped = iter.value;
- const label = await getLabel(zipped[0]);
- tf.dispose(zipped[0]);
- const infered = zipped[1];
- yield label === infered;
- }
- }
- /** use the model to predict every line of the dataset */
- async *infer(dataset) {
- const data = await datasetToData(this.task, dataset);
- const batched = data.preprocess().batch().dataset;
- yield* this.#inferOnBatchedData(batched);
- }
- async *#inferOnBatchedData(batched) {
- const iterator = await batched.iterator();
- for (let iter = await iterator.next(); iter.done !== true; iter = await iterator.next()) {
- const row = iter.value;
- if (typeof row !== "object" ||
- !("xs" in row) ||
- !(row.xs instanceof tf.Tensor))
- throw new Error("unexpected shape of dataset");
- const prediction = await this.#model.predict(row.xs);
- tf.dispose(row);
- let predictions;
- switch (prediction.rank) {
- case 2:
- case 3:
- predictions = await getLabels(
- // cast as rank was just checked
- prediction);
- prediction.dispose();
- break;
- default:
- throw new Error("unexpected batched prediction shape");
- }
- prediction.dispose();
- for (const prediction of predictions)
- yield prediction;
- }
- }
- }
- async function getLabels(ys) {
- // cast as unstack drop a dimension and tfjs doesn't type correctly
- return Promise.all(tf.unstack(ys).map((y) => {
- const ret = getLabel(y);
- y.dispose();
- return ret;
- }));
- }
- async function getLabel(ys) {
- switch (ys.rank) {
- case 1: {
- if (ys.shape[0] == 1) {
- // Binary classification
- const threshold = tf.scalar(0.5);
- const binaryTensor = ys.greaterEqual(threshold);
- const binaryArray = await binaryTensor.data();
- tf.dispose([binaryTensor, threshold]);
- return binaryArray[0];
- }
- // Multi-class classification
- const indexTensor = ys.argMax();
- const indexArray = await indexTensor.data();
- tf.dispose([indexTensor]);
- return indexArray[0];
- // Multi-label classification is not supported
- }
- case 2: {
- // it's LLM, we only extract the next token
- const firstToken = tf.tidy(() => ys.gather([0]).squeeze().argMax());
- const raw = await firstToken.data();
- firstToken.dispose();
- return raw[0];
- }
- default:
- throw new Error("unexpected tensor rank");
- }
- }
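
getLabel above encodes a simple rule: a length-1 output is a sigmoid score thresholded at 0.5, while longer outputs are argmax'd into a class index. Restated on a plain probability array (a sketch, not the removed API):

function labelOf(output: number[]): number {
  if (output.length === 1)
    return output[0] >= 0.5 ? 1 : 0; // binary: threshold the single sigmoid score
  return output.indexOf(Math.max(...output)); // multi-class: index of highest probability
}

labelOf([0.8]); // 1
labelOf([0.1, 0.7, 0.2]); // 1
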