@epfml/discojs 3.0.1-p20241025115642.0 → 3.0.1-p20241107104659.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. package/dist/aggregator/get.d.ts +3 -3
  2. package/dist/client/client.d.ts +5 -5
  3. package/dist/client/decentralized/decentralized_client.d.ts +2 -2
  4. package/dist/client/federated/federated_client.d.ts +2 -2
  5. package/dist/client/utils.d.ts +2 -2
  6. package/dist/dataset/dataset.d.ts +9 -2
  7. package/dist/dataset/dataset.js +83 -36
  8. package/dist/dataset/image.d.ts +5 -0
  9. package/dist/dataset/image.js +6 -1
  10. package/dist/dataset/index.d.ts +0 -1
  11. package/dist/dataset/index.js +0 -1
  12. package/dist/dataset/types.d.ts +2 -0
  13. package/dist/default_tasks/cifar10.d.ts +1 -1
  14. package/dist/default_tasks/cifar10.js +2 -3
  15. package/dist/default_tasks/lus_covid.d.ts +1 -1
  16. package/dist/default_tasks/lus_covid.js +2 -3
  17. package/dist/default_tasks/mnist.d.ts +1 -1
  18. package/dist/default_tasks/mnist.js +3 -5
  19. package/dist/default_tasks/simple_face.d.ts +1 -1
  20. package/dist/default_tasks/simple_face.js +2 -3
  21. package/dist/default_tasks/titanic.d.ts +1 -1
  22. package/dist/default_tasks/titanic.js +3 -6
  23. package/dist/default_tasks/wikitext.d.ts +1 -1
  24. package/dist/default_tasks/wikitext.js +1 -2
  25. package/dist/index.d.ts +4 -5
  26. package/dist/index.js +4 -5
  27. package/dist/models/gpt/index.d.ts +13 -16
  28. package/dist/models/gpt/index.js +62 -43
  29. package/dist/models/gpt/model.d.ts +1 -15
  30. package/dist/models/gpt/model.js +1 -75
  31. package/dist/models/model.d.ts +7 -12
  32. package/dist/models/tfjs.d.ts +10 -8
  33. package/dist/models/tfjs.js +106 -44
  34. package/dist/models/tokenizer.d.ts +1 -1
  35. package/dist/privacy.js +1 -1
  36. package/dist/processing/image.d.ts +18 -0
  37. package/dist/processing/image.js +75 -0
  38. package/dist/processing/index.d.ts +8 -0
  39. package/dist/processing/index.js +106 -0
  40. package/dist/processing/tabular.d.ts +19 -0
  41. package/dist/processing/tabular.js +33 -0
  42. package/dist/processing/text.d.ts +11 -0
  43. package/dist/processing/text.js +33 -0
  44. package/dist/serialization/model.d.ts +3 -3
  45. package/dist/serialization/model.js +19 -6
  46. package/dist/task/task.d.ts +4 -3
  47. package/dist/task/task.js +5 -3
  48. package/dist/task/task_handler.d.ts +3 -3
  49. package/dist/task/task_provider.d.ts +4 -4
  50. package/dist/task/training_information.d.ts +25 -16
  51. package/dist/task/training_information.js +76 -72
  52. package/dist/training/disco.d.ts +20 -12
  53. package/dist/training/disco.js +32 -13
  54. package/dist/training/trainer.d.ts +6 -7
  55. package/dist/training/trainer.js +6 -6
  56. package/dist/types/data_format.d.ts +40 -0
  57. package/dist/types/index.d.ts +2 -0
  58. package/dist/types/index.js +1 -0
  59. package/dist/validator.d.ts +10 -0
  60. package/dist/validator.js +30 -0
  61. package/package.json +4 -2
  62. package/dist/dataset/data/data.d.ts +0 -47
  63. package/dist/dataset/data/data.js +0 -88
  64. package/dist/dataset/data/data_split.d.ts +0 -8
  65. package/dist/dataset/data/helpers.d.ts +0 -10
  66. package/dist/dataset/data/helpers.js +0 -97
  67. package/dist/dataset/data/image_data.d.ts +0 -11
  68. package/dist/dataset/data/image_data.js +0 -43
  69. package/dist/dataset/data/index.d.ts +0 -5
  70. package/dist/dataset/data/index.js +0 -5
  71. package/dist/dataset/data/preprocessing/base.d.ts +0 -16
  72. package/dist/dataset/data/preprocessing/base.js +0 -1
  73. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +0 -13
  74. package/dist/dataset/data/preprocessing/image_preprocessing.js +0 -42
  75. package/dist/dataset/data/preprocessing/index.d.ts +0 -4
  76. package/dist/dataset/data/preprocessing/index.js +0 -3
  77. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +0 -13
  78. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +0 -45
  79. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +0 -13
  80. package/dist/dataset/data/preprocessing/text_preprocessing.js +0 -100
  81. package/dist/dataset/data/tabular_data.d.ts +0 -11
  82. package/dist/dataset/data/tabular_data.js +0 -24
  83. package/dist/dataset/data/text_data.d.ts +0 -11
  84. package/dist/dataset/data/text_data.js +0 -14
  85. package/dist/processing.d.ts +0 -35
  86. package/dist/processing.js +0 -89
  87. package/dist/types.d.ts +0 -3
  88. package/dist/types.js +0 -1
  89. package/dist/validation/index.d.ts +0 -1
  90. package/dist/validation/index.js +0 -1
  91. package/dist/validation/validator.d.ts +0 -10
  92. package/dist/validation/validator.js +0 -113
  93. /package/dist/{dataset/data/data_split.js → types/data_format.js} +0 -0
package/dist/training/disco.d.ts CHANGED
@@ -1,11 +1,19 @@
- import { client as clients, BatchLogs, EpochLogs, Logger, Task, TrainingInformation } from "../index.js";
- import type { TypedLabeledDataset } from "../index.js";
+ import { client as clients, BatchLogs, EpochLogs, Logger, TrainingInformation, Dataset } from "../index.js";
+ import type { DataFormat, DataType, Task } from "../index.js";
  import type { Aggregator } from "../aggregator/index.js";
  import { EventEmitter } from "../utils/event_emitter.js";
  import { RoundLogs, Trainer } from "./trainer.js";
  interface DiscoConfig {
-     scheme: TrainingInformation["scheme"];
+     scheme: TrainingInformation<DataType>["scheme"];
      logger: Logger;
+     /**
+      * keep preprocessed dataset in memory while training
+      *
+      * `Dataset` is cached anyway but this cache can get evicted.
+      * if your system has enough memory to keep the whole preprocessed `Dataset` around,
+      * you can switch this on to only do it once, trading memory for speed.
+      */
+     preprocessOnce: boolean;
  }
  export type RoundStatus = 'not enough participants' | // Server notification to wait for more participants
      'updating model' | // fetching/aggregating local updates into a global model
@@ -16,11 +24,11 @@ export type RoundStatus = 'not enough participants' | // Server notification to
   * a convenient object providing a reduced yet complete API that wraps model training and
   * communication with nodes.
   */
- export declare class Disco extends EventEmitter<{
-     'status': RoundStatus;
+ export declare class Disco<D extends DataType> extends EventEmitter<{
+     status: RoundStatus;
  }> {
      #private;
-     readonly trainer: Trainer;
+     readonly trainer: Trainer<D>;
      /**
       * Connect to the given task and get ready to train.
       *
@@ -28,25 +36,25 @@ export declare class Disco extends EventEmitter<{
       * @param clientConfig client to connect with or parameters on how to create one.
       * @param config the DiscoConfig
       */
-     constructor(task: Task, clientConfig: clients.Client | URL | {
+     constructor(task: Task<D>, clientConfig: clients.Client | URL | {
          aggregator: Aggregator;
          url: URL;
      }, config: Partial<DiscoConfig>);
      /** Train on dataset, yielding logs of every round. */
-     trainByRound(dataset: TypedLabeledDataset): AsyncGenerator<RoundLogs>;
+     trainByRound(dataset: Dataset<DataFormat.Raw[D]>): AsyncGenerator<RoundLogs>;
      /** Train on dataset, yielding logs of every epoch. */
-     trainByEpoch(dataset: TypedLabeledDataset): AsyncGenerator<EpochLogs>;
+     trainByEpoch(dataset: Dataset<DataFormat.Raw[D]>): AsyncGenerator<EpochLogs>;
      /** Train on dataset, yielding logs of every batch. */
-     trainByBatch(dataTuple: TypedLabeledDataset): AsyncGenerator<BatchLogs>;
+     trainByBatch(dataset: Dataset<DataFormat.Raw[D]>): AsyncGenerator<BatchLogs>;
      /** Run whole train on dataset. */
-     trainFully(dataTuple: TypedLabeledDataset): Promise<void>;
+     trainFully(dataset: Dataset<DataFormat.Raw[D]>): Promise<void>;
      /**
       * Train on dataset, yield the nested steps.
       *
       * Don't forget to await the yielded generator otherwise nothing will progress.
       * If you don't care about the whole process, use one of the other train methods.
       **/
-     train(dataset: TypedLabeledDataset): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>>;
+     train(dataset: Dataset<DataFormat.Raw[D]>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>>;
      /**
       * Completely stops the ongoing training instance.
       */
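
With these declarations, `Disco` is now generic over the task's `DataType`, and every train method takes a `Dataset` of raw samples instead of the old `TypedLabeledDataset` tuple. A minimal usage sketch against this API; the server URL is a placeholder and the `task`/`dataset` values are assumed to be obtained elsewhere:

    import { Disco } from "@epfml/discojs";
    import type { Dataset, DataFormat, Task } from "@epfml/discojs";

    async function runTabular(
      task: Task<"tabular">,
      dataset: Dataset<DataFormat.Raw["tabular"]>,
    ): Promise<void> {
      // config is Partial<DiscoConfig>: omitted fields fall back to defaults
      const disco = new Disco(task, new URL("http://localhost:8080"), {
        preprocessOnce: true, // trade memory for speed, as documented above
      });
      for await (const round of disco.trainByRound(dataset))
        console.log(`round done with ${round.participants} participant(s)`);
      await disco.close();
    }

Since the dataset type is indexed by `D`, feeding an `"image"` dataset to a `"tabular"` task is now a compile-time error rather than a runtime failure.
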
package/dist/training/disco.js CHANGED
@@ -1,9 +1,8 @@
- import { async_iterator, client as clients, ConsoleLogger, } from "../index.js";
+ import { async_iterator, client as clients, ConsoleLogger, processing, Dataset, } from "../index.js";
  import { getAggregator } from "../aggregator/index.js";
  import { enumerate, split } from "../utils/async_iterator.js";
  import { EventEmitter } from "../utils/event_emitter.js";
  import { Trainer } from "./trainer.js";
- import { labeledDatasetToDataSplit } from "../dataset/data/helpers.js";
  /**
   * Top-level class handling distributed training from a client's perspective. It is meant to be
   * a convenient object providing a reduced yet complete API that wraps model training and
@@ -14,6 +13,7 @@ export class Disco extends EventEmitter {
      #client;
      #logger;
      #task;
+     #preprocessOnce;
      /**
       * Connect to the given task and get ready to train.
       *
@@ -23,9 +23,10 @@
       */
      constructor(task, clientConfig, config) {
          super();
-         const { scheme, logger } = {
+         const { scheme, logger, preprocessOnce } = {
              scheme: task.trainingInformation.scheme,
              logger: new ConsoleLogger(),
+             preprocessOnce: false,
              ...config,
          };
          let client;
@@ -46,11 +47,12 @@ export class Disco extends EventEmitter {
          if (client.task !== task)
              throw new Error("client not setup for given task");
          this.#logger = logger;
+         this.#preprocessOnce = preprocessOnce;
          this.#client = client;
          this.#task = task;
          this.trainer = new Trainer(task, client);
          // Simply propagate the training status events emitted by the client
-         this.#client.on('status', status => this.emit('status', status));
+         this.#client.on("status", (status) => this.emit("status", status));
      }
      /** Train on dataset, yielding logs of every round. */
      async *trainByRound(dataset) {
@@ -74,14 +76,14 @@
          }
      }
      /** Train on dataset, yielding logs of every batch. */
-     async *trainByBatch(dataTuple) {
-         for await (const round of this.train(dataTuple))
+     async *trainByBatch(dataset) {
+         for await (const round of this.train(dataset))
              for await (const epoch of round)
                  yield* epoch;
      }
      /** Run whole train on dataset. */
-     async trainFully(dataTuple) {
-         for await (const round of this.train(dataTuple))
+     async trainFully(dataset) {
+         for await (const round of this.train(dataset))
              for await (const epoch of round)
                  for await (const _ of epoch)
                      ;
@@ -94,12 +96,11 @@
       **/
      async *train(dataset) {
          this.#logger.success("Training started");
-         const data = await labeledDatasetToDataSplit(this.#task, dataset);
-         const trainData = data.train.preprocess().batch().dataset;
-         const validationData = data.validation?.preprocess().batch().dataset ?? trainData;
+         const [trainingDataset, validationDataset] = await this.#preprocessSplitAndBatch(dataset);
          // the client fetches the latest weights upon connection
-         this.trainer.model = await this.#client.connect();
-         for await (const [round, epochs] of enumerate(this.trainer.train(trainData, validationData))) {
+         // TODO unsafe cast
+         this.trainer.model = (await this.#client.connect());
+         for await (const [round, epochs] of enumerate(this.trainer.train(trainingDataset, validationDataset))) {
              yield async function* () {
                  const [gen, returnedRoundLogs] = split(epochs);
                  for await (const [epoch, batches] of enumerate(gen)) {
@@ -131,4 +132,22 @@
      async close() {
          await this.#client.disconnect();
      }
+     async #preprocessSplitAndBatch(dataset) {
+         const { batchSize, validationSplit } = this.#task.trainingInformation;
+         const preprocessed = await processing.preprocess(this.#task, dataset);
+         const [training, validation] = (this.#preprocessOnce
+             ? new Dataset(await arrayFromAsync(preprocessed))
+             : preprocessed).split(validationSplit);
+         return [
+             training.batch(batchSize).cached(),
+             validation.batch(batchSize).cached(),
+         ];
+     }
+ }
+ // Array.fromAsync not yet widely used (2024)
+ async function arrayFromAsync(iter) {
+     const ret = [];
+     for await (const e of iter)
+         ret.push(e);
+     return ret;
  }
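
`#preprocessSplitAndBatch` is where the new `preprocessOnce` flag takes effect: the preprocessed `Dataset` is either left lazy or materialized once through `arrayFromAsync`, a stand-in for the not-yet-ubiquitous `Array.fromAsync`. A standalone sketch of the helper's behavior (illustrative, not package code):

    // drain any AsyncIterable into a plain array, like Array.fromAsync (ES2024)
    async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
      const ret: T[] = [];
      for await (const e of iter) ret.push(e);
      return ret;
    }

    async function* numbers() { yield 1; yield 2; yield 3; }
    // in an async context (or a top-level-await module):
    const all = await arrayFromAsync(numbers()); // [1, 2, 3]
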
package/dist/training/trainer.d.ts CHANGED
@@ -1,17 +1,16 @@
- import * as tf from "@tensorflow/tfjs";
  import { List } from "immutable";
- import type { BatchLogs, EpochLogs, Model, Task } from "../index.js";
+ import type { Batched, BatchLogs, Dataset, DataFormat, DataType, EpochLogs, Model, Task } from "../index.js";
  import { Client } from "../client/index.js";
  export interface RoundLogs {
      epochs: List<EpochLogs>;
      participants: number;
  }
  /** Train a model and exchange with others **/
- export declare class Trainer {
+ export declare class Trainer<D extends DataType> {
      #private;
-     get model(): Model;
-     set model(model: Model);
-     constructor(task: Task, client: Client);
+     get model(): Model<D>;
+     set model(model: Model<D>);
+     constructor(task: Task<D>, client: Client);
      stopTraining(): Promise<void>;
-     train(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
+     train(dataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded[D]>>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
  }
package/dist/training/trainer.js CHANGED
@@ -29,23 +29,23 @@ export class Trainer {
      async stopTraining() {
          await this.#training?.return();
      }
-     async *train(dataset, valDataset) {
+     async *train(dataset, validationDataset) {
          if (this.#training !== undefined)
              throw new Error("training already running, stop it before launching a new one");
          try {
-             this.#training = this.#runRounds(dataset, valDataset);
+             this.#training = this.#runRounds(dataset, validationDataset);
              yield* this.#training;
          }
          finally {
              this.#training = undefined;
          }
      }
-     async *#runRounds(dataset, valDataset) {
+     async *#runRounds(dataset, validationDataset) {
          const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
          let previousRoundWeights;
          for (let round = 0; round < totalRound; round++) {
              await this.#client.onRoundBeginCommunication();
-             yield this.#runRound(dataset, valDataset);
+             yield this.#runRound(dataset, validationDataset);
              let localWeights = this.model.weights;
              if (this.#privacy !== undefined)
                  localWeights = await applyPrivacy(previousRoundWeights, localWeights, this.#privacy);
@@ -53,10 +53,10 @@
              this.model.weights = previousRoundWeights = networkWeights;
          }
      }
-     async *#runRound(dataset, valDataset) {
+     async *#runRound(dataset, validationDataset) {
          let epochsLogs = List();
          for (let epoch = 0; epoch < this.#roundDuration; epoch++) {
-             const [gen, epochLogs] = async_iterator.split(this.model.train(dataset, valDataset));
+             const [gen, epochLogs] = async_iterator.split(this.model.train(dataset, validationDataset));
              yield gen;
              epochsLogs = epochsLogs.push(await epochLogs);
          }
package/dist/types/data_format.d.ts ADDED
@@ -0,0 +1,40 @@
+ import { List } from "immutable";
+ import type { Image, processing, Tabular, Text } from "../index.js";
+ /**
+  * The data & label format goes through various stages.
+  * Raw* is preprocessed into ModelEncoded.
+  * ModelEncoded's labels are postprocessed into Inferred.
+  *
+  * Raw* -> ModelEncoded -> Inferred
+  */
+ /** what gets ingested by Disco */
+ export interface Raw {
+     image: [Image, label: string];
+     tabular: Tabular;
+     text: Text;
+ }
+ /** what gets ingested by the Validator */
+ export interface RawWithoutLabel {
+     image: Image;
+     tabular: Tabular;
+     text: Text;
+ }
+ type Token = number;
+ /**
+  * what the model can understand
+  *
+  * training needs data & label input
+  * prediction needs data input and outputs label
+  **/
+ export interface ModelEncoded {
+     image: [image: processing.NormalizedImage<3>, label: number];
+     tabular: [row: List<number>, number];
+     text: [line: List<Token>, next: Token];
+ }
+ /** what gets outputted by the Validator, for humans */
+ export interface Inferred {
+     image: string;
+     tabular: number;
+     text: string;
+ }
+ export {};
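
Each map in `DataFormat` is indexed by `DataType`, so the Raw -> ModelEncoded -> Inferred stages can be looked up generically. A type-level sketch of how the lookups resolve (nothing here executes; the import path assumes the package root re-exports these types, as the diffs above show):

    import type { DataFormat, DataType } from "@epfml/discojs";

    // for D = "image", the three stages resolve to:
    type RawImage = DataFormat.Raw["image"];              // [Image, label: string]
    type EncodedImage = DataFormat.ModelEncoded["image"]; // [NormalizedImage<3>, number]
    type InferredImage = DataFormat.Inferred["image"];    // string, a human-readable label

    // generic code indexes the maps with the task's DataType
    declare function preprocess<D extends DataType>(
      raw: DataFormat.Raw[D],
    ): Promise<DataFormat.ModelEncoded[D]>;
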
package/dist/types/index.d.ts ADDED
@@ -0,0 +1,2 @@
+ export * as DataFormat from "./data_format.js";
+ export type DataType = "image" | "tabular" | "text";
package/dist/types/index.js ADDED
@@ -0,0 +1 @@
+ export * as DataFormat from "./data_format.js";
package/dist/validator.d.ts ADDED
@@ -0,0 +1,10 @@
+ import type { Dataset, DataFormat, DataType, Model, Task } from "./index.js";
+ export declare class Validator<D extends DataType> {
+     #private;
+     readonly task: Task<D>;
+     constructor(task: Task<D>, model: Model<D>);
+     /** infer every line of the dataset and check that it is as labelled */
+     test(dataset: Dataset<DataFormat.Raw[D]>): AsyncGenerator<boolean, void>;
+     /** use the model to predict every line of the dataset */
+     infer(dataset: Dataset<DataFormat.RawWithoutLabel[D]>): AsyncGenerator<DataFormat.Inferred[D], void>;
+ }
package/dist/validator.js ADDED
@@ -0,0 +1,30 @@
+ import { processing } from "./index.js";
+ export class Validator {
+     task;
+     #model;
+     constructor(task, model) {
+         this.task = task;
+         this.#model = model;
+     }
+     /** infer every line of the dataset and check that it is as labelled */
+     async *test(dataset) {
+         const results = (await processing.preprocess(this.task, dataset))
+             .batch(this.task.trainingInformation.batchSize)
+             .map(async (batch) => (await this.#model.predict(batch.map(([inputs, _]) => inputs)))
+                 .zip(batch.map(([_, outputs]) => outputs))
+                 .map(([inferred, truth]) => inferred === truth))
+             .unbatch();
+         for await (const e of results)
+             yield e;
+     }
+     /** use the model to predict every line of the dataset */
+     async *infer(dataset) {
+         const modelPredictions = (await processing.preprocessWithoutLabel(this.task, dataset))
+             .batch(this.task.trainingInformation.batchSize)
+             .map((batch) => this.#model.predict(batch))
+             .unbatch();
+         const predictions = await processing.postprocess(this.task, modelPredictions);
+         for await (const e of predictions)
+             yield e;
+     }
+ }
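
The new top-level `Validator` replaces the removed `dist/validation/validator.js` (see the file list above). A sketch of driving it to compute accuracy; the `task`, `model`, and `dataset` values are placeholders to be obtained elsewhere:

    import { Validator } from "@epfml/discojs";
    import type { Dataset, DataFormat, DataType, Model, Task } from "@epfml/discojs";

    async function accuracy<D extends DataType>(
      task: Task<D>,
      model: Model<D>,
      dataset: Dataset<DataFormat.Raw[D]>,
    ): Promise<number> {
      const validator = new Validator(task, model);
      let hits = 0;
      let total = 0;
      // test() yields one boolean per sample: whether the prediction matched the label
      for await (const correct of validator.test(dataset)) {
        if (correct) hits += 1;
        total += 1;
      }
      return total > 0 ? hits / total : 0;
    }
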
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
    "name": "@epfml/discojs",
-   "version": "3.0.1-p20241025115642.0",
+   "version": "3.0.1-p20241107104659.0",
    "type": "module",
    "main": "dist/index.js",
    "types": "dist/index.d.ts",
@@ -19,6 +19,8 @@
    },
    "homepage": "https://github.com/epfml/disco#readme",
    "dependencies": {
+     "@jimp/core": "1",
+     "@jimp/plugin-resize": "1",
      "@msgpack/msgpack": "^3.0.0-beta2",
      "@tensorflow/tfjs": "4",
      "@xenova/transformers": "2",
@@ -31,7 +33,7 @@
    },
    "devDependencies": {
      "@tensorflow/tfjs-node": "4",
-     "@types/chai": "4",
+     "@types/chai": "5",
      "@types/mocha": "10",
      "@types/simple-peer": "9",
      "chai": "5",
package/dist/dataset/data/data.d.ts DELETED
@@ -1,47 +0,0 @@
- import * as tf from '@tensorflow/tfjs';
- import type { List } from 'immutable';
- import type { Task } from '../../index.js';
- import type { PreprocessingFunction } from './preprocessing/base.js';
- /**
-  * Abstract class representing an immutable Disco dataset, including a TF.js dataset,
-  * Disco task and set of preprocessing functions.
-  */
- export declare abstract class Data {
-     readonly dataset: tf.data.Dataset<tf.TensorContainer>;
-     readonly task: Task;
-     readonly size?: number | undefined;
-     abstract readonly availablePreprocessing: List<PreprocessingFunction>;
-     protected constructor(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number | undefined);
-     static init(_dataset: tf.data.Dataset<tf.TensorContainer>, _task: Task, _size?: number): Promise<Data>;
-     /**
-      * Callable abstract method instead of constructor.
-      */
-     protected abstract create(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number): Data;
-     /**
-      * Creates a new Disco data object containing the batched TF.js dataset, according to the
-      * task's parameters.
-      * @returns The batched Disco data
-      */
-     batch(): Data;
-     /**
-      * The TF.js dataset batched according to the task's parameters.
-      */
-     get batchedDataset(): tf.data.Dataset<tf.TensorContainer>;
-     /**
-      * Creates a new Disco data object containing the preprocessed TF.js dataset,
-      * according to the defined set of preprocessing functions and the task's parameters.
-      * @returns The preprocessed Disco data
-      */
-     preprocess(): Data;
-     /**
-      * Creates a higher level preprocessing function applying the specified set of preprocessing
-      * functions in a series. The preprocessing functions are chained according to their defined
-      * priority.
-      */
-     get preprocessing(): (entry: tf.TensorContainer) => Promise<tf.TensorContainer>;
-     /**
-      * The TF.js dataset preprocessing according to the set of preprocessing functions and the task's
-      * parameters.
-      */
-     get preprocessedDataset(): tf.data.Dataset<tf.TensorContainer>;
- }
package/dist/dataset/data/data.js DELETED
@@ -1,88 +0,0 @@
- import * as tf from '@tensorflow/tfjs';
- /**
-  * Abstract class representing an immutable Disco dataset, including a TF.js dataset,
-  * Disco task and set of preprocessing functions.
-  */
- export class Data {
-     dataset;
-     task;
-     size;
-     constructor(dataset, task, size) {
-         this.dataset = dataset;
-         this.task = task;
-         this.size = size;
-     }
-     static init(_dataset, _task, _size) {
-         return Promise.reject(new Error('abstract'));
-     }
-     /**
-      * Creates a new Disco data object containing the batched TF.js dataset, according to the
-      * task's parameters.
-      * @returns The batched Disco data
-      */
-     batch() {
-         return this.create(this.batchedDataset, this.task, this.size);
-     }
-     /**
-      * The TF.js dataset batched according to the task's parameters.
-      */
-     get batchedDataset() {
-         const batchSize = this.task.trainingInformation.batchSize;
-         return batchSize === undefined
-             ? this.dataset
-             : this.dataset.batch(batchSize);
-     }
-     /**
-      * Creates a new Disco data object containing the preprocessed TF.js dataset,
-      * according to the defined set of preprocessing functions and the task's parameters.
-      * @returns The preprocessed Disco data
-      */
-     preprocess() {
-         return this.create(this.preprocessedDataset, this.task, this.size);
-     }
-     /**
-      * Creates a higher level preprocessing function applying the specified set of preprocessing
-      * functions in a series. The preprocessing functions are chained according to their defined
-      * priority.
-      */
-     get preprocessing() {
-         const params = this.task.trainingInformation;
-         const taskPreprocessing = params.preprocessingFunctions;
-         if (taskPreprocessing === undefined ||
-             taskPreprocessing.length === 0 ||
-             this.availablePreprocessing === undefined ||
-             this.availablePreprocessing.size === 0) {
-             return x => Promise.resolve(x);
-         }
-         const applyPreprocessing = this.availablePreprocessing
-             .filter((e) => e.type in taskPreprocessing)
-             .map((e) => e.apply);
-         const preprocessingChain = async (input) => {
-             let currentContainer = await input; // Start with the initial tensor container
-             for (const fn of applyPreprocessing) {
-                 const next = await fn(Promise.resolve(currentContainer), this.task);
-                 // dirty but kinda working way to dispose of converted tensors
-                 if (typeof currentContainer === "object" && typeof next === "object") {
-                     if ("xs" in currentContainer &&
-                         "xs" in next &&
-                         currentContainer.xs !== next.xs)
-                         tf.dispose(currentContainer.xs);
-                     if ("ys" in currentContainer &&
-                         "ys" in next &&
-                         currentContainer.ys !== next.ys)
-                         tf.dispose(currentContainer.ys);
-                 }
-                 currentContainer = next;
-             }
-             return currentContainer; // Return the final tensor container
-         };
-         return async (entry) => await preprocessingChain(Promise.resolve(entry));
-     }
-     /**
-      * The TF.js dataset preprocessing according to the set of preprocessing functions and the task's
-      * parameters.
-      */
-     get preprocessedDataset() {
-         return this.dataset.mapAsync(this.preprocessing);
-     }
- }
package/dist/dataset/data/data_split.d.ts DELETED
@@ -1,8 +0,0 @@
- import type { Data } from './data.js';
- /**
-  * Train-validation split of Disco data.
-  */
- export interface DataSplit {
-     train: Data;
-     validation?: Data;
- }
package/dist/dataset/data/helpers.d.ts DELETED
@@ -1,10 +0,0 @@
- /** Internal functions to help with Dataset to Data/DataSplit conversion
-  *
-  * @todo rm when fully using Dataset
-  */
- import type { Task, TypedDataset, TypedLabeledDataset } from "../../index.js";
- import { Data } from "./index.js";
- import { DataSplit } from "./data_split.js";
- export declare function datasetToData(task: Task, [t, dataset]: TypedDataset): Promise<Data>;
- export declare function labeledDatasetToData(task: Task, [t, dataset]: TypedLabeledDataset): Promise<Data>;
- export declare function labeledDatasetToDataSplit(task: Task, [t, dataset]: TypedLabeledDataset): Promise<DataSplit>;
package/dist/dataset/data/helpers.js DELETED
@@ -1,97 +0,0 @@
- /** Internal functions to help with Dataset to Data/DataSplit conversion
-  *
-  * @todo rm when fully using Dataset
-  */
- import { List } from "immutable";
- import * as tf from "@tensorflow/tfjs";
- import { processing } from "../../index.js";
- import { ImageData, TabularData, TextData } from "./index.js";
- function intoTFDataset(iter) {
-     // @ts-expect-error generator
-     return tf.data.generator(async function* () {
-         yield* iter;
-     });
- }
- function imageToTensor(image) {
-     return tf.tensor3d(image.data, [image.width, image.height, 3], "int32");
- }
- function tabularToNumbers(columns, row) {
-     return List(columns)
-         .map((column) => processing.extractColumn(row, column))
-         .map((v) => (v !== "" ? v : "0")) // TODO how to specify defaults?
-         .map(processing.convertToNumber)
-         .toArray();
- }
- export async function datasetToData(task, [t, dataset]) {
-     switch (t) {
-         case "image": {
-             const converted = dataset
-                 .map(processing.removeAlpha)
-                 .map((image) => processing.expandToMulticolor(image))
-                 .map((image) => ({
-                     xs: imageToTensor(image),
-                 }));
-             return await ImageData.init(intoTFDataset(converted), task);
-         }
-         case "tabular": {
-             const inputColumns = task.trainingInformation.inputColumns;
-             if (inputColumns === undefined)
-                 throw new Error("tabular task without input columns");
-             const converted = dataset.map((row) => ({
-                 xs: tabularToNumbers(inputColumns, row),
-             }));
-             return await TabularData.init(intoTFDataset(converted), task);
-         }
-         case "text":
-             return await TextData.init(intoTFDataset(dataset), task);
-     }
- }
- export async function labeledDatasetToData(task, [t, dataset]) {
-     switch (t) {
-         case "image": {
-             const labels = List(task.trainingInformation.LABEL_LIST);
-             const converted = dataset
-                 .map(([image, label]) => [
-                     processing.expandToMulticolor(processing.removeAlpha(image)),
-                     processing.indexInList(label, labels),
-                 ])
-                 .map(([image, label]) => ({
-                     xs: imageToTensor(image),
-                     ys: tf.oneHot(label, labels.size, 1, 0, "int32"),
-                 }));
-             return await ImageData.init(intoTFDataset(converted), task);
-         }
-         case "tabular": {
-             const { inputColumns, outputColumns } = task.trainingInformation;
-             if (inputColumns === undefined || outputColumns === undefined)
-                 throw new Error("tabular task without input and output columns");
-             const converted = dataset.map((row) => ({
-                 xs: tabularToNumbers(inputColumns, row),
-                 ys: tf.tensor1d(tabularToNumbers(outputColumns, row)),
-             }));
-             return await TabularData.init(intoTFDataset(converted), task);
-         }
-         case "text":
-             return await TextData.init(intoTFDataset(dataset), task);
-     }
- }
- export async function labeledDatasetToDataSplit(task, [t, dataset]) {
-     const split = task.trainingInformation.validationSplit;
-     let train;
-     let validation;
-     switch (t) {
-         case "image": {
-             [train, validation] = await Promise.all(dataset.split(split).map((d) => labeledDatasetToData(task, [t, d])));
-             break;
-         }
-         case "tabular": {
-             [train, validation] = await Promise.all(dataset.split(split).map((d) => labeledDatasetToData(task, [t, d])));
-             break;
-         }
-         case "text": {
-             [train, validation] = await Promise.all(dataset.split(split).map((d) => labeledDatasetToData(task, [t, d])));
-             break;
-         }
-     }
-     return { train, validation };
- }
package/dist/dataset/data/image_data.d.ts DELETED
@@ -1,11 +0,0 @@
- import * as tf from '@tensorflow/tfjs';
- import type { Task } from '../../index.js';
- import { Data } from './data.js';
- /**
-  * Disco data made of image samples (.jpg, .png, etc.).
-  */
- export declare class ImageData extends Data {
-     readonly availablePreprocessing: import("immutable").List<import("./preprocessing/base.js").PreprocessingFunction>;
-     static init(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size?: number): Promise<Data>;
-     protected create(dataset: tf.data.Dataset<tf.TensorContainer>, task: Task, size: number): ImageData;
- }