@epfml/discojs 3.0.1-p20241025115642.0 → 3.0.1-p20241028120035.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. package/dist/aggregator/get.d.ts +3 -3
  2. package/dist/client/client.d.ts +5 -5
  3. package/dist/client/decentralized/decentralized_client.d.ts +2 -2
  4. package/dist/client/federated/federated_client.d.ts +2 -2
  5. package/dist/client/utils.d.ts +2 -2
  6. package/dist/dataset/dataset.d.ts +9 -2
  7. package/dist/dataset/dataset.js +83 -36
  8. package/dist/dataset/image.d.ts +5 -0
  9. package/dist/dataset/image.js +6 -1
  10. package/dist/dataset/index.d.ts +0 -1
  11. package/dist/dataset/index.js +0 -1
  12. package/dist/dataset/types.d.ts +2 -0
  13. package/dist/default_tasks/cifar10.d.ts +1 -1
  14. package/dist/default_tasks/cifar10.js +2 -3
  15. package/dist/default_tasks/lus_covid.d.ts +1 -1
  16. package/dist/default_tasks/lus_covid.js +2 -3
  17. package/dist/default_tasks/mnist.d.ts +1 -1
  18. package/dist/default_tasks/mnist.js +2 -4
  19. package/dist/default_tasks/simple_face.d.ts +1 -1
  20. package/dist/default_tasks/simple_face.js +2 -3
  21. package/dist/default_tasks/titanic.d.ts +1 -1
  22. package/dist/default_tasks/titanic.js +3 -6
  23. package/dist/default_tasks/wikitext.d.ts +1 -1
  24. package/dist/default_tasks/wikitext.js +1 -2
  25. package/dist/index.d.ts +4 -5
  26. package/dist/index.js +4 -5
  27. package/dist/models/gpt/index.d.ts +13 -16
  28. package/dist/models/gpt/index.js +62 -43
  29. package/dist/models/gpt/model.d.ts +1 -15
  30. package/dist/models/gpt/model.js +1 -75
  31. package/dist/models/model.d.ts +7 -12
  32. package/dist/models/tfjs.d.ts +10 -8
  33. package/dist/models/tfjs.js +106 -44
  34. package/dist/models/tokenizer.d.ts +1 -1
  35. package/dist/privacy.js +1 -1
  36. package/dist/processing/image.d.ts +18 -0
  37. package/dist/processing/image.js +75 -0
  38. package/dist/processing/index.d.ts +8 -0
  39. package/dist/processing/index.js +106 -0
  40. package/dist/processing/tabular.d.ts +19 -0
  41. package/dist/processing/tabular.js +33 -0
  42. package/dist/processing/text.d.ts +11 -0
  43. package/dist/processing/text.js +33 -0
  44. package/dist/serialization/model.d.ts +3 -3
  45. package/dist/serialization/model.js +19 -6
  46. package/dist/task/task.d.ts +4 -3
  47. package/dist/task/task.js +5 -3
  48. package/dist/task/task_handler.d.ts +3 -3
  49. package/dist/task/task_provider.d.ts +4 -4
  50. package/dist/task/training_information.d.ts +25 -16
  51. package/dist/task/training_information.js +76 -72
  52. package/dist/training/disco.d.ts +20 -12
  53. package/dist/training/disco.js +32 -13
  54. package/dist/training/trainer.d.ts +6 -7
  55. package/dist/training/trainer.js +6 -6
  56. package/dist/types/data_format.d.ts +40 -0
  57. package/dist/types/index.d.ts +2 -0
  58. package/dist/types/index.js +1 -0
  59. package/dist/validator.d.ts +10 -0
  60. package/dist/validator.js +30 -0
  61. package/package.json +4 -2
  62. package/dist/dataset/data/data.d.ts +0 -47
  63. package/dist/dataset/data/data.js +0 -88
  64. package/dist/dataset/data/data_split.d.ts +0 -8
  65. package/dist/dataset/data/helpers.d.ts +0 -10
  66. package/dist/dataset/data/helpers.js +0 -97
  67. package/dist/dataset/data/image_data.d.ts +0 -11
  68. package/dist/dataset/data/image_data.js +0 -43
  69. package/dist/dataset/data/index.d.ts +0 -5
  70. package/dist/dataset/data/index.js +0 -5
  71. package/dist/dataset/data/preprocessing/base.d.ts +0 -16
  72. package/dist/dataset/data/preprocessing/base.js +0 -1
  73. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +0 -13
  74. package/dist/dataset/data/preprocessing/image_preprocessing.js +0 -42
  75. package/dist/dataset/data/preprocessing/index.d.ts +0 -4
  76. package/dist/dataset/data/preprocessing/index.js +0 -3
  77. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +0 -13
  78. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +0 -45
  79. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +0 -13
  80. package/dist/dataset/data/preprocessing/text_preprocessing.js +0 -100
  81. package/dist/dataset/data/tabular_data.d.ts +0 -11
  82. package/dist/dataset/data/tabular_data.js +0 -24
  83. package/dist/dataset/data/text_data.d.ts +0 -11
  84. package/dist/dataset/data/text_data.js +0 -14
  85. package/dist/processing.d.ts +0 -35
  86. package/dist/processing.js +0 -89
  87. package/dist/types.d.ts +0 -3
  88. package/dist/types.js +0 -1
  89. package/dist/validation/index.d.ts +0 -1
  90. package/dist/validation/index.js +0 -1
  91. package/dist/validation/validator.d.ts +0 -10
  92. package/dist/validation/validator.js +0 -113
  93. /package/dist/{dataset/data/data_split.js → types/data_format.js} +0 -0
@@ -2,21 +2,27 @@
2
2
  * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
3
3
  **/
4
4
  import createDebug from "debug";
5
- import { List } from 'immutable';
6
- import * as tf from '@tensorflow/tfjs';
7
- import { WeightsContainer } from '../../index.js';
5
+ import { List, Range } from "immutable";
6
+ import * as tf from "@tensorflow/tfjs";
7
+ import { WeightsContainer } from "../../index.js";
8
8
  import { Model, EpochLogs } from "../index.js";
9
- import { GPTForCausalLM } from './model.js';
10
- import { DEFAULT_CONFIG } from './config.js';
11
- import evaluate from './evaluate.js';
9
+ import { GPTModel } from "./model.js";
10
+ import { DEFAULT_CONFIG } from "./config.js";
11
+ import evaluate from "./evaluate.js";
12
12
  const debug = createDebug("discojs:models:gpt");
13
13
  export class GPT extends Model {
14
14
  model;
15
+ #blockSize;
15
16
  #maxBatchCount;
17
+ #vocabSize;
16
18
  constructor(partialConfig, layersModel) {
17
19
  super();
18
- this.model = new GPTForCausalLM(partialConfig, layersModel);
20
+ const model = new GPTModel(partialConfig, layersModel);
21
+ model.compile();
22
+ this.model = model;
23
+ this.#blockSize = partialConfig?.blockSize ?? DEFAULT_CONFIG.blockSize;
19
24
  this.#maxBatchCount = partialConfig?.maxIter ?? DEFAULT_CONFIG.maxIter;
25
+ this.#vocabSize = partialConfig?.vocabSize ?? DEFAULT_CONFIG.vocabSize;
20
26
  }
21
27
  /**
22
28
  * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
@@ -27,26 +33,20 @@ export class GPT extends Model {
27
33
  * @param epochs the number of passes of the training dataset
28
34
  * @param tracker
29
35
  */
30
- async *train(trainingData, validationData) {
31
- this.model.compile();
32
- const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
36
+ async *train(trainingDataset, validationDataset) {
33
37
  let batchesLogs = List();
34
- for (let batchNumber = 0; batchNumber < this.#maxBatchCount; batchNumber++) {
35
- const iteration = await batches.next();
36
- if (iteration.done)
37
- break;
38
- const batch = iteration.value;
38
+ for await (const [batch, _] of trainingDataset.zip(Range(0, this.#maxBatchCount))) {
39
39
  const batchLogs = await this.#runBatch(batch);
40
- tf.dispose(batch);
41
40
  yield batchLogs;
42
41
  batchesLogs = batchesLogs.push(batchLogs);
43
42
  }
44
- const validation = validationData && (await this.#evaluate(validationData));
43
+ const validation = validationDataset && (await this.#evaluate(validationDataset));
45
44
  return new EpochLogs(batchesLogs, validation);
46
45
  }
47
46
  async #runBatch(batch) {
47
+ const tfBatch = this.#batchToTF(batch);
48
48
  let logs;
49
- await this.model.fitDataset(tf.data.array([batch]), {
49
+ await this.model.fitDataset(tf.data.array([tfBatch]), {
50
50
  epochs: 1,
51
51
  verbose: 0, // don't pollute
52
52
  callbacks: {
@@ -55,6 +55,7 @@ export class GPT extends Model {
55
55
  },
56
56
  },
57
57
  });
58
+ tf.dispose(tfBatch);
58
59
  if (logs === undefined)
59
60
  throw new Error("batch didn't gave any logs");
60
61
  const { loss, acc: accuracy } = logs;
@@ -67,38 +68,56 @@ export class GPT extends Model {
67
68
  };
68
69
  }
69
70
  async #evaluate(dataset) {
70
- const evaluation = await evaluate(this.model, dataset.map((t) => {
71
- switch (t) {
72
- case null:
73
- case undefined:
74
- throw new Error("nullish value in dataset");
75
- default:
76
- // TODO unsafe cast
77
- return t;
78
- }
79
- }), this.config.maxEvalBatches);
71
+ const evaluation = await evaluate(this.model, tf.data.generator(async function* () {
72
+ yield* dataset.map((batch) => this.#batchToTF(batch));
73
+ }.bind(this)), this.config.maxEvalBatches);
80
74
  return {
81
75
  accuracy: evaluation.val_acc,
82
76
  loss: evaluation.val_loss,
83
77
  };
84
78
  }
85
- predict(input) {
86
- const ret = this.model.predict(input);
87
- if (Array.isArray(ret)) {
88
- throw new Error("prediction yield many Tensors but should have only returned one");
89
- }
90
- return Promise.resolve(ret);
79
+ #batchToTF(batch) {
80
+ return tf.tidy(() => ({
81
+ xs: tf.stack(batch.map(([line]) => tf.tensor1d(line.toArray(), "int32")).toArray()), // cast as stack doesn't type
82
+ ys: tf.stack(batch
83
+ .map(([line, next]) => tf.oneHot(line.shift().push(next).toArray(), this.#vocabSize))
84
+ .toArray()), // cast as oneHot/stack doesn't type
85
+ }));
91
86
  }
92
- async generate(input, tokenizer, newTokens = 10) {
93
- const { input_ids: tokens } = await tokenizer(input, { return_tensor: false });
94
- const generationConfig = {
95
- maxNewTokens: newTokens,
87
+ async predict(batch, options) {
88
+ const config = {
96
89
  temperature: 1.0,
97
- doSample: false
90
+ doSample: false,
91
+ ...options,
98
92
  };
99
- const predictedTokens = await this.model.generate(tokens, generationConfig);
100
- const generatedWords = tokenizer.decode(predictedTokens[0]);
101
- return generatedWords;
93
+ return List(await Promise.all(batch.map((tokens) => this.#predictSingle(tokens, config))));
94
+ }
95
+ async #predictSingle(tokens, config) {
96
+ // slice input tokens if longer than context length
97
+ tokens = tokens.slice(-this.#blockSize);
98
+ const input = tf.tidy(() => tf.tensor1d(tokens.toArray(), "int32").expandDims(0));
99
+ const logits = tf.tidy(() => {
100
+ const output = this.model.predict(input);
101
+ if (Array.isArray(output))
102
+ throw new Error("The model outputs too multiple values");
103
+ if (output.rank !== 3)
104
+ throw new Error("The model outputs wrong shape");
105
+ return output.squeeze([0]);
106
+ });
107
+ input.dispose();
108
+ const probs = tf.tidy(() => logits
109
+ .slice([logits.shape[0] - 1])
110
+ .squeeze([0])
111
+ .div(config.temperature)
112
+ .softmax());
113
+ logits.dispose();
114
+ const next = tf.tidy(() => config.doSample
115
+ ? tf.multinomial(probs, 1).squeeze([0])
116
+ : probs.argMax());
117
+ probs.dispose();
118
+ const ret = await next.array();
119
+ next.dispose();
120
+ return ret;
102
121
  }
103
122
  get config() {
104
123
  return this.model.getGPTConfig;
@@ -117,7 +136,7 @@ export class GPT extends Model {
117
136
  serialize() {
118
137
  return {
119
138
  weights: this.weights,
120
- config: this.config
139
+ config: this.config,
121
140
  };
122
141
  }
123
142
  extract() {
@@ -14,25 +14,11 @@ export declare abstract class Dataset<T> {
14
14
  * GPTModel extends tf.LayersModel and overrides tfjs' default training loop
15
15
  *
16
16
  */
17
- declare class GPTModel extends tf.LayersModel {
17
+ export declare class GPTModel extends tf.LayersModel {
18
18
  protected readonly config: Required<GPTConfig>;
19
19
  constructor(partialConfig?: GPTConfig, layersModel?: tf.LayersModel);
20
20
  get getGPTConfig(): Required<GPTConfig>;
21
21
  compile(): void;
22
22
  fitDataset<T>(dataset: Dataset<T>, trainingArgs: tf.ModelFitDatasetArgs<T>): Promise<tf.History>;
23
23
  }
24
- interface GenerateConfig {
25
- maxNewTokens: number;
26
- temperature: number;
27
- doSample: boolean;
28
- }
29
- /**
30
- * GPTForCausalLM stands for GPT model for Causal Language Modeling. Causal because it only looks at past tokens and not future ones
31
- * This class extends GPTModel and adds supports for text generation
32
- *
33
- */
34
- export declare class GPTForCausalLM extends GPTModel {
35
- generate(idxRaw: tf.TensorLike, conf: GenerateConfig): Promise<number[][]>;
36
- private generateOnce;
37
- }
38
24
  export {};
@@ -9,7 +9,7 @@ const debug = createDebug("discojs:models:gpt");
9
9
  * GPTModel extends tf.LayersModel and overrides tfjs' default training loop
10
10
  *
11
11
  */
12
- class GPTModel extends tf.LayersModel {
12
+ export class GPTModel extends tf.LayersModel {
13
13
  config;
14
14
  constructor(partialConfig, layersModel) {
15
15
  // Fill missing config parameters with default values
@@ -124,77 +124,3 @@ class GPTModel extends tf.LayersModel {
124
124
  return new tf.History();
125
125
  }
126
126
  }
127
- const defaultGenerateConfig = {
128
- maxNewTokens: 20,
129
- temperature: 1.0,
130
- doSample: false
131
- };
132
- function prepareIdx(idx) {
133
- return tf.tidy(() => {
134
- let ret;
135
- if (idx instanceof tf.Tensor) {
136
- ret = idx.clone();
137
- }
138
- else {
139
- ret = tf.tensor(idx);
140
- }
141
- if (ret.dtype !== 'int32') {
142
- ret = ret.toInt();
143
- }
144
- switch (ret.shape.length) {
145
- case 1:
146
- return ret.expandDims(0);
147
- case 2:
148
- return ret;
149
- default:
150
- throw new Error('unexpected shape');
151
- }
152
- });
153
- }
154
- /**
155
- * GPTForCausalLM stands for GPT model for Causal Language Modeling. Causal because it only looks at past tokens and not future ones
156
- * This class extends GPTModel and adds supports for text generation
157
- *
158
- */
159
- export class GPTForCausalLM extends GPTModel {
160
- async generate(idxRaw, conf) {
161
- const config = Object.assign({}, defaultGenerateConfig, conf);
162
- let idx = prepareIdx(idxRaw);
163
- for (let step = 0; step < config.maxNewTokens; step++) {
164
- const idxNext = this.generateOnce(this, idx, config);
165
- const idxNew = idx.concat(idxNext, 1);
166
- tf.dispose(idx);
167
- idx = idxNew;
168
- tf.dispose(idxNext);
169
- }
170
- const idxArr = await idx.array();
171
- tf.dispose(idx);
172
- return idxArr;
173
- }
174
- generateOnce(model, idx, config) {
175
- const idxNext = tf.tidy(() => {
176
- // slice input tokens if longer than context length
177
- const blockSize = this.config.blockSize;
178
- idx = idx.shape[1] <= blockSize
179
- ? idx : idx.slice([0, idx.shape[1] - blockSize]);
180
- const output = model.predict(idx);
181
- if (Array.isArray(output))
182
- throw new Error('The model outputs too multiple values');
183
- if (output.shape.length !== 3)
184
- throw new Error('The model outputs wrong shape');
185
- const logits = output;
186
- const logitsScaled = logits
187
- .slice([0, idx.shape[1] - 1, 0])
188
- .reshape([logits.shape[0], logits.shape[2]])
189
- .div(tf.scalar(config.temperature));
190
- const probs = logitsScaled.softmax(-1);
191
- if (config.doSample) {
192
- return tf.multinomial(probs, 1);
193
- }
194
- else {
195
- return probs.argMax(-1).expandDims(1);
196
- }
197
- });
198
- return idxNext;
199
- }
200
- }
@@ -1,14 +1,11 @@
1
- import type tf from "@tensorflow/tfjs";
2
- import type { WeightsContainer } from "../index.js";
1
+ import type { Batched, Dataset, DataFormat, DataType, WeightsContainer } from "../index.js";
3
2
  import type { BatchLogs, EpochLogs } from "./logs.js";
4
- export type Prediction = tf.Tensor;
5
- export type Sample = tf.Tensor;
6
3
  /**
7
4
  * Trainable predictor
8
5
  *
9
6
  * Allow for various implementation of models (various train function, tensor-library, ...)
10
7
  **/
11
- export declare abstract class Model implements Disposable {
8
+ export declare abstract class Model<D extends DataType> implements Disposable {
12
9
  /** Return training state */
13
10
  abstract get weights(): WeightsContainer;
14
11
  /** Set training state */
@@ -16,15 +13,13 @@ export declare abstract class Model implements Disposable {
16
13
  /**
17
14
  * Improve predictor
18
15
  *
19
- * @param trainingData dataset to optimize for
20
- * @param validationData dataset to measure how well it is training
21
- * @param epochs number of pass over the training dataset
22
- * @param tracker watch the various steps
23
- * @yields on every epoch, training can be stop by `return`ing it
16
+ * @param trainingDataset dataset to optimize for
17
+ * @param validationDataset dataset to measure how well it is training
18
+ * @yields on every epoch, training can be stop by `return`ing or `throw`ing it
24
19
  */
25
- abstract train(trainingData: tf.data.Dataset<tf.TensorContainer>, validationData?: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<BatchLogs, EpochLogs>;
20
+ abstract train(trainingDataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded[D]>>): AsyncGenerator<BatchLogs, EpochLogs>;
26
21
  /** Predict likely values */
27
- abstract predict(input: Sample): Promise<Prediction>;
22
+ abstract predict(batch: Batched<DataFormat.ModelEncoded[D][0]>): Promise<Batched<DataFormat.ModelEncoded[D][1]>>;
28
23
  /**
29
24
  * This method is automatically called to cleanup the memory occupied by the model
30
25
  * when leaving the definition scope if the instance has been defined with the `using` keyword.
@@ -1,21 +1,22 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
- import { WeightsContainer } from '../index.js';
2
+ import { Batched, Dataset, DataFormat, DataType, WeightsContainer } from "../index.js";
3
3
  import { BatchLogs } from './index.js';
4
4
  import { Model } from './index.js';
5
- import { Prediction, Sample } from './model.js';
6
5
  import { EpochLogs } from './logs.js';
6
+ type Serialized<D extends DataType> = [D, tf.io.ModelArtifacts];
7
7
  /** TensorFlow JavaScript model with standard training */
8
- export declare class TFJS extends Model {
8
+ export declare class TFJS<D extends "image" | "tabular"> extends Model<D> {
9
9
  #private;
10
+ readonly datatype: D;
10
11
  private readonly model;
11
12
  /** Wrap the given trainable model */
12
- constructor(model: tf.LayersModel);
13
+ constructor(datatype: D, model: tf.LayersModel);
13
14
  get weights(): WeightsContainer;
14
15
  set weights(ws: WeightsContainer);
15
- train(trainingData: tf.data.Dataset<tf.TensorContainer>, validationData?: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<BatchLogs, EpochLogs>;
16
- predict(input: Sample): Promise<Prediction>;
17
- static deserialize(raw: tf.io.ModelArtifacts): Promise<Model>;
18
- serialize(): Promise<tf.io.ModelArtifacts>;
16
+ train(trainingDataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded[D]>>): AsyncGenerator<BatchLogs, EpochLogs>;
17
+ predict(batch: Batched<DataFormat.ModelEncoded[D][0]>): Promise<Batched<DataFormat.ModelEncoded[D][1]>>;
18
+ static deserialize<D extends "image" | "tabular">([datatype, artifacts,]: Serialized<D>): Promise<TFJS<D>>;
19
+ serialize(): Promise<Serialized<D>>;
19
20
  [Symbol.dispose](): void;
20
21
  /**
21
22
  * extract wrapped model
@@ -24,3 +25,4 @@ export declare class TFJS extends Model {
24
25
  */
25
26
  extract(): tf.LayersModel;
26
27
  }
28
+ export {};
@@ -1,18 +1,22 @@
1
- import { List, Map } from 'immutable';
1
+ import { List, Map, Range } from "immutable";
2
2
  import * as tf from '@tensorflow/tfjs';
3
- import { WeightsContainer } from '../index.js';
3
+ import { WeightsContainer, } from "../index.js";
4
4
  import { Model } from './index.js';
5
5
  import { EpochLogs } from './logs.js';
6
6
  /** TensorFlow JavaScript model with standard training */
7
7
  export class TFJS extends Model {
8
+ datatype;
8
9
  model;
9
10
  /** Wrap the given trainable model */
10
- constructor(model) {
11
+ constructor(datatype, model) {
11
12
  super();
13
+ this.datatype = datatype;
12
14
  this.model = model;
13
15
  if (model.loss === undefined) {
14
16
  throw new Error('TFJS models need to be compiled to be used');
15
17
  }
18
+ if (model.outputs.length !== 1)
19
+ throw new Error("only support single output model");
16
20
  }
17
21
  get weights() {
18
22
  return new WeightsContainer(this.model.weights.map((w) => w.read()));
@@ -20,57 +24,43 @@ export class TFJS extends Model {
20
24
  set weights(ws) {
21
25
  this.model.setWeights(ws.weights);
22
26
  }
23
- async *train(trainingData, validationData) {
24
- const batches = await trainingData.iterator(); // tf.LazyIterator isn't an AsyncGenerator
27
+ async *train(trainingDataset, validationDataset) {
25
28
  let batchesLogs = List();
26
- for (let batchNumber = 0; true; batchNumber++) {
27
- const iteration = await batches.next();
28
- if (iteration.done)
29
- break;
30
- const batch = iteration.value;
29
+ for await (const [batch, batchNumber] of trainingDataset.zip(Range())) {
31
30
  const batchLogs = {
32
31
  batch: batchNumber,
33
32
  ...(await this.#runBatch(batch)),
34
33
  };
35
- tf.dispose(batch);
36
34
  yield batchLogs;
37
35
  batchesLogs = batchesLogs.push(batchLogs);
38
36
  }
39
- const validation = validationData && (await this.#evaluate(validationData));
37
+ const validation = validationDataset && (await this.#evaluate(validationDataset));
40
38
  return new EpochLogs(batchesLogs, validation);
41
39
  }
42
40
  async #runBatch(batch) {
43
- let logs;
44
- await this.model.fitDataset(tf.data.array([batch]), {
41
+ const { xs, ys } = this.#batchToTF(batch);
42
+ const { history } = await this.model.fit(xs, ys, {
45
43
  epochs: 1,
46
44
  verbose: 0, // don't pollute
47
- callbacks: {
48
- onEpochEnd: (_, cur) => {
49
- logs = cur;
50
- },
51
- },
52
45
  });
53
- if (logs === undefined)
54
- throw new Error("batch didn't gave any logs");
55
- const { loss, acc: accuracy } = logs;
56
- if (loss === undefined || isNaN(loss))
57
- throw new Error("training loss is undefined or NaN");
46
+ const { loss: losses, acc: accuracies } = history;
47
+ if (losses === undefined ||
48
+ accuracies === undefined ||
49
+ typeof losses[0] !== "number" ||
50
+ typeof accuracies[0] !== "number" ||
51
+ isNaN(losses[0]) ||
52
+ isNaN(accuracies[0]))
53
+ throw new Error("training loss or accuracy is undefined or NaN");
58
54
  return {
59
- accuracy,
60
- loss,
55
+ accuracy: accuracies[0],
56
+ loss: losses[0],
61
57
  memoryUsage: tf.memory().numBytes / 1024 / 1024 / 1024,
62
58
  };
63
59
  }
64
60
  async #evaluate(dataset) {
65
- const evaluation = await this.model.evaluateDataset(dataset.map((t) => {
66
- switch (t) {
67
- case null:
68
- case undefined:
69
- throw new Error("nullish value in dataset");
70
- default:
71
- return t;
72
- }
73
- }));
61
+ const evaluation = await this.model.evaluateDataset(tf.data.generator(async function* () {
62
+ yield* dataset.map((batch) => this.#batchToTF(batch));
63
+ }.bind(this)));
74
64
  const metricToValue = Map(List(this.model.metricsNames).zip(Array.isArray(evaluation)
75
65
  ? List(await Promise.all(evaluation.map((t) => t.data())))
76
66
  : List.of(await evaluation.data()))).map((values) => {
@@ -87,16 +77,39 @@ export class TFJS extends Model {
87
77
  throw new Error("some needed metrics are missing");
88
78
  return { accuracy, loss };
89
79
  }
90
- predict(input) {
91
- const ret = this.model.predict(input);
92
- if (Array.isArray(ret)) {
93
- throw new Error('prediction yield many Tensors but should have only returned one');
80
+ async predict(batch) {
81
+ async function cleanupPredicted(y) {
82
+ if (y.shape[0] === 1) {
83
+ // Binary classification
84
+ const threshold = tf.scalar(0.5);
85
+ const binaryTensor = y.greaterEqual(threshold);
86
+ const binaryArray = await binaryTensor.data();
87
+ tf.dispose([y, binaryTensor, threshold]);
88
+ return binaryArray[0];
89
+ }
90
+ // Multi-class classification
91
+ const indexTensor = y.argMax();
92
+ const indexArray = await indexTensor.data();
93
+ tf.dispose([y, indexTensor]);
94
+ return indexArray[0];
95
+ // Multi-label classification is not supported
94
96
  }
95
- return Promise.resolve(ret);
97
+ const xs = this.#batchWithoutLabelToTF(batch);
98
+ const prediction = this.model.predict(xs);
99
+ if (Array.isArray(prediction))
100
+ throw new Error("prediction yield many Tensors but should have only returned one");
101
+ tf.dispose(xs);
102
+ if (prediction.rank !== 2)
103
+ throw new Error("unexpected batched prediction shape");
104
+ const ret = List(await Promise.all(tf.unstack(prediction).map((y) => cleanupPredicted(
105
+ // cast as unstack reduce by one the rank
106
+ y))));
107
+ prediction.dispose();
108
+ return ret;
96
109
  }
97
- static async deserialize(raw) {
98
- return new this(await tf.loadLayersModel({
99
- load: () => Promise.resolve(raw)
110
+ static async deserialize([datatype, artifacts,]) {
111
+ return new this(datatype, await tf.loadLayersModel({
112
+ load: () => Promise.resolve(artifacts),
100
113
  }));
101
114
  }
102
115
  async serialize() {
@@ -115,7 +128,7 @@ export class TFJS extends Model {
115
128
  }, {
116
129
  includeOptimizer: true // keep model compiled
117
130
  });
118
- return await ret;
131
+ return [this.datatype, await ret];
119
132
  }
120
133
  [Symbol.dispose]() {
121
134
  this.model.dispose();
@@ -128,4 +141,53 @@ export class TFJS extends Model {
128
141
  extract() {
129
142
  return this.model;
130
143
  }
144
+ #batchToTF(batch) {
145
+ const outputSize = tf.util.sizeFromShape(this.model.outputShape.map((dim) => {
146
+ if (Array.isArray(dim))
147
+ throw new Error("TODO support multiple outputs");
148
+ return dim ?? 1;
149
+ }));
150
+ switch (this.datatype) {
151
+ case "image": {
152
+ // cast as typescript doesn't reduce generic type
153
+ const b = batch;
154
+ return tf.tidy(() => ({
155
+ xs: tf.stack(b
156
+ .map(([image]) => tf.tensor3d(image.data, [image.width, image.height, 3], "float32"))
157
+ .toArray()),
158
+ ys: tf.stack(b
159
+ .map(([_, label]) => tf.oneHot(label, outputSize, 1, 0, "int32"))
160
+ .toArray()),
161
+ }));
162
+ }
163
+ case "tabular": {
164
+ // cast as typescript doesn't reduce generic type
165
+ const b = batch;
166
+ return tf.tidy(() => ({
167
+ xs: tf.stack(b.map(([inputs, _]) => tf.tensor1d(inputs.toArray())).toArray()),
168
+ ys: tf.stack(b.map(([_, output]) => tf.tensor1d([output])).toArray()),
169
+ }));
170
+ }
171
+ }
172
+ const _ = this.datatype;
173
+ throw new Error("should never happen");
174
+ }
175
+ #batchWithoutLabelToTF(batch) {
176
+ switch (this.datatype) {
177
+ case "image": {
178
+ // cast as typescript doesn't reduce generic type
179
+ const b = batch;
180
+ return tf.tidy(() => tf.stack(b
181
+ .map((image) => tf.tensor3d(image.data, [image.width, image.height, 3], "float32"))
182
+ .toArray()));
183
+ }
184
+ case "tabular": {
185
+ // cast as typescript doesn't reduce generic type
186
+ const b = batch;
187
+ return tf.tidy(() => tf.stack(b.map((inputs) => tf.tensor1d(inputs.toArray())).toArray()));
188
+ }
189
+ }
190
+ const _ = this.datatype;
191
+ throw new Error("should never happen");
192
+ }
131
193
  }
@@ -11,4 +11,4 @@ import { PreTrainedTokenizer } from '@xenova/transformers';
11
11
  * @param task the task object specifying which tokenizer to use
12
12
  * @returns an initialized tokenizer object
13
13
  */
14
- export declare function getTaskTokenizer(task: Task): Promise<PreTrainedTokenizer>;
14
+ export declare function getTaskTokenizer(task: Task<'text'>): Promise<PreTrainedTokenizer>;
package/dist/privacy.js CHANGED
@@ -5,7 +5,7 @@ async function frobeniusNorm(weights) {
5
5
  .reduce((a, b) => a.add(b))
6
6
  .data();
7
7
  if (squared.length !== 1)
8
- throw new Error("unexcepted weights shape");
8
+ throw new Error("unexpected weights shape");
9
9
  return Math.sqrt(squared[0]);
10
10
  }
11
11
  /** Scramble weights */
@@ -0,0 +1,18 @@
1
+ import { Image } from "../index.js";
2
+ /** Image where intensity is represented in the range 0..1 */
3
+ export declare class NormalizedImage<D extends 1 | 3 | 4 = 1 | 3 | 4, W extends number = number, H extends number = number> {
4
+ readonly data: Readonly<Float32Array>;
5
+ readonly width: W;
6
+ readonly height: H;
7
+ readonly depth: D;
8
+ private constructor();
9
+ static from<D extends 1 | 3 | 4 = 1 | 3 | 4, W extends number = number, H extends number = number>(image: Image<D, W, H>): NormalizedImage<D, W, H>;
10
+ }
11
+ /** Remove the alpha channel of an image */
12
+ export declare function removeAlpha<W extends number, H extends number>(image: Image<4, W, H>): Image<3, W, H>;
13
+ export declare function removeAlpha<D extends 1 | 3, W extends number, H extends number>(image: Image<D | 4, W, H>): Image<D, W, H>;
14
+ /** Convert monochrome images to multicolor */
15
+ export declare function expandToMulticolor<W extends number, H extends number>(image: Image<1, W, H>): Image<3, W, H>;
16
+ export declare function expandToMulticolor<D extends 3 | 4, W extends number, H extends number>(image: Image<1 | D, W, H>): Image<D, W, H>;
17
+ export declare function resize<D extends 1 | 3 | 4, W extends number, H extends number>(width: W, height: H, image: Image<D, number, number>): Image<4, W, H>;
18
+ export declare function normalize<D extends 1 | 3 | 4, W extends number, H extends number>(image: Image<D, W, H>): NormalizedImage<D, W, H>;