@epfml/discojs 2.1.2-p20240515133413.0 → 2.1.2-p20240531085945.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/dist/aggregator/base.js +1 -0
  2. package/dist/aggregator/mean.d.ts +10 -15
  3. package/dist/aggregator/mean.js +36 -50
  4. package/dist/aggregator/secure.d.ts +5 -7
  5. package/dist/aggregator/secure.js +56 -44
  6. package/dist/client/federated/messages.d.ts +1 -8
  7. package/dist/client/federated/messages.js +1 -10
  8. package/dist/client/messages.d.ts +1 -3
  9. package/dist/client/messages.js +0 -2
  10. package/dist/dataset/dataset_builder.d.ts +2 -11
  11. package/dist/dataset/dataset_builder.js +22 -46
  12. package/dist/default_tasks/cifar10.d.ts +2 -0
  13. package/dist/default_tasks/{cifar10/index.js → cifar10.js} +2 -2
  14. package/dist/default_tasks/index.d.ts +3 -2
  15. package/dist/default_tasks/index.js +3 -2
  16. package/dist/default_tasks/lus_covid.js +1 -1
  17. package/dist/default_tasks/simple_face.d.ts +2 -0
  18. package/dist/default_tasks/{simple_face/index.js → simple_face.js} +3 -3
  19. package/dist/default_tasks/skin_condition.d.ts +2 -0
  20. package/dist/default_tasks/skin_condition.js +79 -0
  21. package/dist/models/gpt/config.d.ts +32 -0
  22. package/dist/models/gpt/config.js +42 -0
  23. package/dist/models/gpt/evaluate.d.ts +7 -0
  24. package/dist/models/gpt/evaluate.js +44 -0
  25. package/dist/models/gpt/index.d.ts +35 -0
  26. package/dist/models/gpt/index.js +104 -0
  27. package/dist/models/gpt/layers.d.ts +13 -0
  28. package/dist/models/gpt/layers.js +272 -0
  29. package/dist/models/gpt/model.d.ts +43 -0
  30. package/dist/models/gpt/model.js +191 -0
  31. package/dist/models/gpt/optimizers.d.ts +4 -0
  32. package/dist/models/gpt/optimizers.js +95 -0
  33. package/dist/models/index.d.ts +5 -0
  34. package/dist/models/index.js +4 -0
  35. package/dist/{default_tasks/simple_face/model.js → models/mobileNetV2_35_alpha_2_classes.js} +2 -0
  36. package/dist/{default_tasks/cifar10/model.js → models/mobileNet_v1_025_224.js} +1 -0
  37. package/dist/models/model.d.ts +51 -0
  38. package/dist/models/model.js +8 -0
  39. package/dist/models/tfjs.d.ts +24 -0
  40. package/dist/models/tfjs.js +107 -0
  41. package/dist/models/tokenizer.d.ts +14 -0
  42. package/dist/models/tokenizer.js +22 -0
  43. package/dist/validation/validator.js +8 -7
  44. package/package.json +1 -1
  45. package/dist/default_tasks/cifar10/index.d.ts +0 -2
  46. package/dist/default_tasks/simple_face/index.d.ts +0 -2
  47. package/dist/{default_tasks/simple_face/model.d.ts → models/mobileNetV2_35_alpha_2_classes.d.ts} +0 -0
  48. package/dist/{default_tasks/cifar10/model.d.ts → models/mobileNet_v1_025_224.d.ts} +0 -0
package/dist/models/gpt/model.js ADDED
@@ -0,0 +1,191 @@
+ import * as tf from '@tensorflow/tfjs';
+ import { getModelSizes, DEFAULT_CONFIG } from './config.js';
+ import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js';
+ import evaluate from './evaluate.js';
+ import { GPTArchitecture } from './layers.js';
+ /**
+  * GPTModel extends tf.LayersModel and overrides tfjs' default training loop
+  *
+  */
+ class GPTModel extends tf.LayersModel {
+     config;
+     disposalRefs; // Array storing tensors to dispose of manually
+     // Object passed down to layers to store the max memory allocated.
+     // This is an object rather than a primitive so that it is passed by reference.
+     peakMemory;
+     constructor(partialConfig) {
+         // Fill missing config parameters with default values
+         let completeConfig = { ...DEFAULT_CONFIG, ...partialConfig };
+         // Add layer sizes depending on which model has been specified
+         completeConfig = { ...completeConfig, ...getModelSizes(completeConfig.modelType) };
+         // Init the tf.LayersModel and assign it to this
+         const disposalRefs = [];
+         const peakMemory = { value: 0 };
+         const gpt = GPTArchitecture(completeConfig, disposalRefs, peakMemory);
+         const { inputs, outputs, name } = gpt;
+         super({ inputs, outputs, name });
+         this.config = completeConfig;
+         this.disposalRefs = disposalRefs;
+         this.peakMemory = peakMemory;
+     }
+     // Some tensors are not cleaned up when model.dispose is called,
+     // so we dispose of them manually
+     disposeRefs() {
+         for (const tensorContainer of this.disposalRefs) {
+             tf.dispose([tensorContainer]);
+         }
+     }
+     get getGPTConfig() {
+         return this.config;
+     }
+     compile() {
+         this.optimizer = this.config.weightDecay !== 0
+             ? getCustomAdam(this, this.config.lr, this.config.weightDecay)
+             : tf.train.adam(this.config.lr);
+         this.peakMemory.value = 0;
+     }
+     async fitDataset(dataset, trainingArgs) {
+         const callbacks = trainingArgs.callbacks;
+         const evalDataset = trainingArgs.validationData;
+         await callbacks.onTrainBegin?.();
+         for (let epoch = 1; epoch <= trainingArgs.epochs; epoch++) {
+             let averageLoss = 0;
+             let iteration = 1;
+             const iterator = await dataset.iterator();
+             let preprocessingTime = performance.now();
+             let next = await iterator.next();
+             preprocessingTime = performance.now() - preprocessingTime;
+             while (next.done !== true && iteration <= this.config.maxIter) {
+                 let weightUpdateTime = performance.now();
+                 await callbacks.onEpochBegin?.(epoch);
+                 const { xs, ys } = next.value;
+                 const lossFn = () => {
+                     const logits = this.apply(xs);
+                     if (Array.isArray(logits)) {
+                         throw new Error('model outputs too many tensors');
+                     }
+                     if (logits instanceof tf.SymbolicTensor) {
+                         throw new Error('model outputs a symbolic tensor');
+                     }
+                     return tf.losses.softmaxCrossEntropy(ys, logits);
+                 };
+                 let backwardPassMemory = 0;
+                 const lossTensor = tf.tidy(() => {
+                     const { grads, value: lossTensor } = this.optimizer.computeGradients(lossFn);
+                     const gradsClipped = clipByGlobalNormObj(grads, 1);
+                     this.optimizer.applyGradients(gradsClipped);
+                     backwardPassMemory = tf.memory().numBytes / 1024 / 1024 / 1024;
+                     return lossTensor;
+                 });
+                 const loss = await lossTensor.array();
+                 averageLoss += loss;
+                 weightUpdateTime = performance.now() - weightUpdateTime;
+                 // Probably never the case: empirically, the attention mechanism always allocates
+                 // more memory than the backward pass
+                 if (backwardPassMemory > this.peakMemory.value) {
+                     this.peakMemory.value = backwardPassMemory;
+                 }
+                 tf.dispose([xs, ys, lossTensor]);
+                 if (evalDataset !== undefined &&
+                     this.config.evaluateEvery !== undefined &&
+                     iteration % this.config.evaluateEvery == 0) {
+                     const iterationLogs = await evaluate(this, evalDataset, this.config.maxEvalBatches);
+                     console.log(iterationLogs);
+                 }
+                 console.log(`Epoch: ${epoch}`, `\tStep: ${iteration} / ${this.config.maxIter}`, `\tLoss: ${loss.toFixed(3)}`, `\tPeak memory: ${this.peakMemory.value.toFixed(2)} GB`, `\tNumber of tensors allocated: ${tf.memory().numTensors}`, `\tPreprocessing time: ${preprocessingTime.toFixed(0)} ms`, `\tWeight update time: ${weightUpdateTime.toFixed(0)} ms`);
+                 iteration++;
+                 next = await iterator.next();
+             }
+             // Avoid a memory leak: if we reached the last iteration rather than the end of the dataset, clean up the tensors
+             if (next.done != true && iteration > this.config.maxIter) {
+                 const { xs, ys } = next.value;
+                 tf.dispose([xs, ys]);
+             }
+             let logs = {
+                 'loss': averageLoss / iteration,
+                 'peakMemory': this.peakMemory.value
+             };
+             if (evalDataset !== undefined) {
+                 logs = { ...logs, ...await evaluate(this, evalDataset, this.config.maxEvalBatches) };
+                 console.log(logs);
+             }
+             await callbacks.onEpochEnd?.(epoch, logs);
+         }
+         await callbacks.onTrainEnd?.();
+         return new tf.History();
+     }
+ }
+ const defaultGenerateConfig = {
+     maxNewTokens: 20,
+     temperature: 1.0,
+     doSample: false
+ };
+ function prepareIdx(idx) {
+     return tf.tidy(() => {
+         let ret;
+         if (idx instanceof tf.Tensor) {
+             ret = idx.clone();
+         }
+         else {
+             ret = tf.tensor(idx);
+         }
+         if (ret.dtype !== 'int32') {
+             ret = ret.toInt();
+         }
+         switch (ret.shape.length) {
+             case 1:
+                 return ret.expandDims(0);
+             case 2:
+                 return ret;
+             default:
+                 throw new Error('unexpected shape');
+         }
+     });
+ }
+ /**
+  * GPTForCausalLM stands for GPT model for Causal Language Modeling: causal because it only looks at past tokens, not future ones.
+  * This class extends GPTModel and adds support for text generation.
+  *
+  */
+ export class GPTForCausalLM extends GPTModel {
+     async generate(idxRaw, conf) {
+         const config = Object.assign({}, defaultGenerateConfig, conf);
+         let idx = prepareIdx(idxRaw);
+         for (let step = 0; step < config.maxNewTokens; step++) {
+             const idxNext = this.generateOnce(this, idx, config);
+             const idxNew = idx.concat(idxNext, 1);
+             tf.dispose(idx);
+             idx = idxNew;
+             tf.dispose(idxNext);
+         }
+         const idxArr = await idx.array();
+         tf.dispose(idx);
+         return idxArr;
+     }
+     generateOnce(model, idx, config) {
+         const idxNext = tf.tidy(() => {
+             // Slice input tokens if longer than the context length
+             const blockSize = this.config.blockSize;
+             idx = idx.shape[1] <= blockSize
+                 ? idx : idx.slice([0, idx.shape[1] - blockSize]);
+             const output = model.predict(idx);
+             if (Array.isArray(output))
+                 throw new Error('The model outputs multiple values');
+             if (output.shape.length !== 3)
+                 throw new Error('The model outputs the wrong shape');
+             const logits = output;
+             const logitsScaled = logits
+                 .slice([0, idx.shape[1] - 1, 0])
+                 .reshape([logits.shape[0], logits.shape[2]])
+                 .div(tf.scalar(config.temperature));
+             const probs = logitsScaled.softmax(-1);
+             if (config.doSample) {
+                 return tf.multinomial(probs, 1);
+             }
+             else {
+                 return probs.argMax(-1).expandDims(1);
+             }
+         });
+         return idxNext;
+     }
+ }
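For context on the new `GPTForCausalLM` API, a minimal usage sketch (the model type name, hyperparameters, and import path are assumptions, not part of the diff): `generate` accepts a 1D or 2D array or tensor of token ids and returns the prompt extended with up to `maxNewTokens` new ids, as a plain nested array.

```js
import { GPTForCausalLM } from './models/gpt/model.js'; // path assumed

const model = new GPTForCausalLM({ modelType: 'gpt-nano', lr: 1e-3 }); // config values assumed
model.compile(); // picks the custom AdamW when weightDecay !== 0, plain Adam otherwise

const promptIds = [818, 257, 1748]; // an already-tokenized prompt
const generated = await model.generate(promptIds, {
    maxNewTokens: 10,
    temperature: 0.8,
    doSample: true, // sample from the softmax instead of taking a greedy argMax
});
console.log(generated[0]); // the prompt ids followed by 10 generated ids
```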
package/dist/models/gpt/optimizers.d.ts ADDED
@@ -0,0 +1,4 @@
+ import * as tf from '@tensorflow/tfjs';
+ declare function clipByGlobalNormObj(tensorsObj: Record<string, tf.Tensor>, clipNorm: number, useNorm?: tf.Tensor): Record<string, tf.Tensor>;
+ declare function getCustomAdam(model: tf.LayersModel, lr: number, weightDecay: number): tf.Optimizer;
+ export { getCustomAdam, clipByGlobalNormObj };
package/dist/models/gpt/optimizers.js ADDED
@@ -0,0 +1,95 @@
+ import * as tf from '@tensorflow/tfjs';
+ function l2Loss(tensor) {
+     return tf.div(tf.sum(tf.square(tensor)), 2);
+ }
+ function globalNorm(tensors) {
+     const halfSquaredNorms = [];
+     tensors.forEach((tensor) => {
+         halfSquaredNorms.push(l2Loss(tensor));
+     });
+     const halfSquaredNorm = tf.sum(tf.stack(halfSquaredNorms));
+     const norm = tf.sqrt(tf.mul(halfSquaredNorm, tf.scalar(2.0, halfSquaredNorm.dtype)));
+     return norm;
+ }
+ function clipByGlobalNorm(tensors, clipNorm, useNorm) {
+     return tf.tidy(() => {
+         useNorm = useNorm ?? globalNorm(tensors);
+         const scale = tf.mul(clipNorm, tf.minimum(tf.div(tf.scalar(1.0), useNorm), tf.div(tf.scalar(1.0, useNorm.dtype), clipNorm)));
+         const tensorsClipped = [];
+         tensors.forEach((tensor) => {
+             tensorsClipped.push(tf.clone(tf.mul(tensor, scale)));
+         });
+         return tensorsClipped;
+     });
+ }
+ function clipByGlobalNormObj(tensorsObj, clipNorm, useNorm) {
+     const varNames = Object.keys(tensorsObj);
+     const tensorsArr = varNames.map((n) => tensorsObj[n]);
+     const tensorsArrClipped = clipByGlobalNorm(tensorsArr, clipNorm, useNorm);
+     const tensorsObjClipped = {};
+     tensorsArrClipped.forEach((t, ti) => {
+         tensorsObjClipped[varNames[ti]] = t;
+     });
+     return tensorsObjClipped;
+ }
+ class AdamW extends tf.AdamOptimizer {
+     weightDecayRate;
+     includeInWeightDecay;
+     excludeFromWeightDecay;
+     gradientClipNorm;
+     constructor(params) {
+         console.log('Using custom AdamW optimizer');
+         const defaultParams = {
+             learningRate: 0.1,
+             beta1: 0.9,
+             beta2: 0.999,
+             epsilon: 1e-7,
+             weightDecayRate: 0,
+             includeInWeightDecay: [],
+             excludeFromWeightDecay: [],
+             gradientClipNorm: 1.0
+         };
+         const p = Object.assign({}, defaultParams, params);
+         super(p.learningRate, p.beta1, p.beta2, p.epsilon);
+         this.weightDecayRate = p.weightDecayRate;
+         this.includeInWeightDecay = p.includeInWeightDecay;
+         this.excludeFromWeightDecay = p.excludeFromWeightDecay;
+         this.gradientClipNorm = p.gradientClipNorm;
+     }
+     applyGradients(variableGradients) {
+         const varNames = Array.isArray(variableGradients)
+             ? variableGradients.map((v) => v.name)
+             : Object.keys(variableGradients);
+         varNames.forEach((name) => {
+             if (this.includeInWeightDecay.includes(name)) {
+                 const value = tf.engine().registeredVariables[name];
+                 const newValue = tf.sub(value, tf.mul(this.learningRate, tf.mul(value, this.weightDecayRate)));
+                 value.assign(newValue);
+             }
+         });
+         super.applyGradients(variableGradients);
+     }
+ }
+ function getCustomAdam(model, lr, weightDecay) {
+     const includeInWeightDecay = [];
+     const excludeFromWeightDecay = [];
+     // TODO unsafe cast
+     const namedWeights = model.getNamedWeights();
+     namedWeights.forEach((v) => {
+         if (v.name.includes('bias') ||
+             v.name.includes('normalization') ||
+             v.name.includes('emb')) {
+             excludeFromWeightDecay.push(v.name);
+         }
+         else {
+             includeInWeightDecay.push(v.name);
+         }
+     });
+     return new AdamW({
+         learningRate: lr,
+         weightDecayRate: weightDecay,
+         includeInWeightDecay,
+         excludeFromWeightDecay
+     });
+ }
+ export { getCustomAdam, clipByGlobalNormObj };
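The `AdamW` above implements decoupled weight decay: before the regular Adam update, each variable listed in `includeInWeightDecay` is shrunk by `lr * weightDecayRate`, while `getCustomAdam` excludes biases, normalization parameters, and embeddings by name. A small sketch of wiring it into any compiled tfjs model (the layer shape and import path are assumptions):

```js
import * as tf from '@tensorflow/tfjs';
import { getCustomAdam } from './models/gpt/optimizers.js'; // path assumed

const model = tf.sequential({
    layers: [tf.layers.dense({ units: 4, inputShape: [2] })],
});
// Kernel weights are decayed on every applyGradients call; bias weights
// match the 'bias' name filter and are therefore excluded from decay.
const optimizer = getCustomAdam(model, 0.001 /* lr */, 0.01 /* weightDecay */);
model.compile({ optimizer, loss: 'meanSquaredError' });
```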
package/dist/models/index.d.ts ADDED
@@ -0,0 +1,5 @@
+ export { EpochLogs, Model } from './model.js';
+ export { GPT } from './gpt/index.js';
+ export { GPTConfig } from './gpt/config.js';
+ export { TFJS } from './tfjs.js';
+ export { getTaskTokenizer } from './tokenizer.js';
package/dist/models/index.js ADDED
@@ -0,0 +1,4 @@
+ export { Model } from './model.js';
+ export { GPT } from './gpt/index.js';
+ export { TFJS } from './tfjs.js';
+ export { getTaskTokenizer } from './tokenizer.js';
package/dist/{default_tasks/simple_face/model.js → models/mobileNetV2_35_alpha_2_classes.js} RENAMED
@@ -1,3 +1,5 @@
+ // Source: https://storage.googleapis.com/deai-313515.appspot.com/models/mobileNetV2_35_alpha_2_classes/model.json
+ // This model was converted using the tensorflow.js converter
  export default {
      format: "layers-model",
      generatedBy: "keras v2.6.0",
package/dist/{default_tasks/cifar10/model.js → models/mobileNet_v1_025_224.js} RENAMED
@@ -1,3 +1,4 @@
+ // Source: https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json
  export default {
      modelTopology: {
          keras_version: "2.1.4",
package/dist/models/model.d.ts ADDED
@@ -0,0 +1,51 @@
+ /// <reference types="node" resolution-mode="require"/>
+ import type tf from "@tensorflow/tfjs";
+ import type { WeightsContainer } from "../index.js";
+ import type { Dataset } from "../dataset/index.js";
+ export interface EpochLogs {
+     epoch: number;
+     training: {
+         loss: number;
+         accuracy?: number;
+     };
+     validation?: {
+         loss: number;
+         accuracy: number;
+     };
+     peakMemory: number;
+ }
+ export type Prediction = tf.Tensor;
+ export type Sample = tf.Tensor;
+ /**
+  * Trainable predictor
+  *
+  * Allows for various implementations of models (various train functions, tensor libraries, ...)
+  **/
+ export declare abstract class Model implements Disposable {
+     /** Return training state */
+     abstract get weights(): WeightsContainer;
+     /** Set training state */
+     abstract set weights(ws: WeightsContainer);
+     /**
+      * Improve the predictor
+      *
+      * @param trainingData dataset to optimize for
+      * @param validationData dataset to measure how well it is training
+      * @param epochs number of passes over the training dataset
+      * @param tracker watch the various steps
+      * @yields on every epoch; training can be stopped by `return`ing the generator
+      */
+     abstract train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs, void>;
+     /** Predict likely values */
+     abstract predict(input: Sample): Promise<Prediction>;
+     /**
+      * This method is automatically called to clean up the memory occupied by the model
+      * when leaving the definition scope, if the instance has been defined with the `using` keyword.
+      * For example:
+      * function f() {
+      *     using model = new Model();
+      * }
+      * Calling f() will call the model's dispose method when exiting the function.
+      */
+     abstract [Symbol.dispose](): void;
+ }
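The `train` generator yields one `EpochLogs` per epoch, and the declaration above notes that training can be stopped by `return`ing the generator. A sketch of a consumer doing exactly that (function and parameter names are illustrative):

```js
// Stop training early once validation accuracy crosses a threshold; breaking
// out of for-await calls the generator's return(), ending training cleanly.
async function trainUntil(model, trainingData, validationData, targetAccuracy) {
    for await (const logs of model.train(trainingData, validationData, 50)) {
        console.log(`epoch ${logs.epoch}: training loss = ${logs.training.loss}`);
        if ((logs.validation?.accuracy ?? 0) >= targetAccuracy)
            break;
    }
}
```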
package/dist/models/model.js ADDED
@@ -0,0 +1,8 @@
+ /**
+  * Trainable predictor
+  *
+  * Allows for various implementations of models (various train functions, tensor libraries, ...)
+  **/
+ // TODO make it typesafe: same shape of data/input/weights
+ export class Model {
+ }
package/dist/models/tfjs.d.ts ADDED
@@ -0,0 +1,24 @@
+ import * as tf from '@tensorflow/tfjs';
+ import { WeightsContainer } from '../index.js';
+ import { Model } from './index.js';
+ import type { EpochLogs, Prediction, Sample } from './model.js';
+ import type { Dataset } from '../dataset/index.js';
+ /** TensorFlow JavaScript model with standard training */
+ export declare class TFJS extends Model {
+     private readonly model;
+     /** Wrap the given trainable model */
+     constructor(model: tf.LayersModel);
+     get weights(): WeightsContainer;
+     set weights(ws: WeightsContainer);
+     train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs>;
+     predict(input: Sample): Promise<Prediction>;
+     static deserialize(raw: tf.io.ModelArtifacts): Promise<Model>;
+     serialize(): Promise<tf.io.ModelArtifacts>;
+     [Symbol.dispose](): void;
+     /**
+      * Extract the wrapped model
+      *
+      * @deprecated use `Model` instead of relying on tf specifics
+      */
+     extract(): tf.LayersModel;
+ }
package/dist/models/tfjs.js ADDED
@@ -0,0 +1,107 @@
+ import * as tf from '@tensorflow/tfjs';
+ import { WeightsContainer } from '../index.js';
+ import { Model } from './index.js';
+ /** TensorFlow JavaScript model with standard training */
+ export class TFJS extends Model {
+     model;
+     /** Wrap the given trainable model */
+     constructor(model) {
+         super();
+         this.model = model;
+         if (model.loss === undefined) {
+             throw new Error('TFJS models need to be compiled to be used');
+         }
+     }
+     get weights() {
+         return new WeightsContainer(this.model.weights.map((w) => w.read()));
+     }
+     set weights(ws) {
+         this.model.setWeights(ws.weights);
+     }
+     async *train(trainingData, validationData, epochs = 1) {
+         for (let epoch = 0; epoch < epochs; epoch++) {
+             let logs;
+             let peakMemory = 0;
+             await this.model.fitDataset(trainingData, {
+                 epochs: 1,
+                 validationData,
+                 callbacks: {
+                     onBatchEnd: (_) => {
+                         const currentMemory = tf.memory().numBytes / 1024 / 1024 / 1024; // GB
+                         if (currentMemory > peakMemory) {
+                             peakMemory = currentMemory;
+                         }
+                     },
+                     onEpochEnd: (_, cur) => { logs = cur; }
+                 },
+             });
+             if (logs === undefined) {
+                 throw new Error("Epoch didn't give any logs");
+             }
+             const { loss, acc, val_acc, val_loss } = logs;
+             if (loss === undefined || isNaN(loss) || acc === undefined || isNaN(acc)) {
+                 throw new Error("Training loss is undefined or NaN");
+             }
+             const structuredLogs = {
+                 epoch,
+                 peakMemory,
+                 training: {
+                     loss: logs.loss,
+                     accuracy: logs.acc,
+                 }
+             };
+             if (validationData !== undefined) {
+                 if (val_loss === undefined || isNaN(val_loss) ||
+                     val_acc === undefined || isNaN(val_acc)) {
+                     throw new Error("Invalid validation logs");
+                 }
+                 structuredLogs.validation = {
+                     accuracy: logs.val_acc,
+                     loss: logs.val_loss
+                 };
+             }
+             yield structuredLogs;
+         }
+     }
+     predict(input) {
+         const ret = this.model.predict(input);
+         if (Array.isArray(ret)) {
+             throw new Error('prediction yielded multiple Tensors but should have returned only one');
+         }
+         return Promise.resolve(ret);
+     }
+     static async deserialize(raw) {
+         return new this(await tf.loadLayersModel({
+             load: () => Promise.resolve(raw)
+         }));
+     }
+     async serialize() {
+         let resolveArtifacts;
+         const ret = new Promise((resolve) => { resolveArtifacts = resolve; });
+         await this.model.save({
+             save: (artifacts) => {
+                 resolveArtifacts(artifacts);
+                 return Promise.resolve({
+                     modelArtifactsInfo: {
+                         dateSaved: new Date(),
+                         modelTopologyType: 'JSON'
+                     }
+                 });
+             }
+         }, {
+             includeOptimizer: true // keep the model compiled
+         });
+         return await ret;
+     }
+     [Symbol.dispose]() {
+         this.model.dispose();
+     }
+     /**
+      * Extract the wrapped model
+      *
+      * @deprecated use `Model` instead of relying on tf specifics
+      */
+     extract() {
+         return this.model;
+     }
+ }
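A sketch of the serialization round-trip implemented above (the import path is assumed): `serialize` funnels the `tf.io.ModelArtifacts` out through a custom save handler, keeping the optimizer so the model stays compiled, and `deserialize` rebuilds a wrapper from those artifacts, e.g., after sending them over the network.

```js
import { TFJS } from './models/tfjs.js'; // path assumed

async function roundTrip(model /* a TFJS instance */) {
    const artifacts = await model.serialize(); // plain tf.io.ModelArtifacts
    return await TFJS.deserialize(artifacts); // an independent, still-compiled copy
}
```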
package/dist/models/tokenizer.d.ts ADDED
@@ -0,0 +1,14 @@
+ import type { Task } from '../index.js';
+ import { PreTrainedTokenizer } from '@xenova/transformers';
+ /**
+  * A task's tokenizer is initially specified as the tokenizer name, e.g., 'Xenova/gpt2'.
+  * The first time the tokenizer is needed, this function initializes the actual tokenizer object
+  * and saves it in the task's tokenizer field to be reused in subsequent calls.
+  *
+  * We proceed this way because the task object is sent from the server to the client. Rather than
+  * sending complex objects over the network, we simply send the tokenizer name, which is then
+  * initialized client-side the first time it is needed.
+  * @param task the task object specifying which tokenizer to use
+  * @returns an initialized tokenizer object
+  */
+ export declare function getTaskTokenizer(task: Task): Promise<PreTrainedTokenizer>;
package/dist/models/tokenizer.js ADDED
@@ -0,0 +1,22 @@
+ import { AutoTokenizer } from '@xenova/transformers';
+ /**
+  * A task's tokenizer is initially specified as the tokenizer name, e.g., 'Xenova/gpt2'.
+  * The first time the tokenizer is needed, this function initializes the actual tokenizer object
+  * and saves it in the task's tokenizer field to be reused in subsequent calls.
+  *
+  * We proceed this way because the task object is sent from the server to the client. Rather than
+  * sending complex objects over the network, we simply send the tokenizer name, which is then
+  * initialized client-side the first time it is needed.
+  * @param task the task object specifying which tokenizer to use
+  * @returns an initialized tokenizer object
+  */
+ export async function getTaskTokenizer(task) {
+     let tokenizer = task.trainingInformation.tokenizer;
+     if (tokenizer === undefined)
+         throw Error('No tokenizer specified in the task training information');
+     if (typeof tokenizer == 'string') {
+         tokenizer = await AutoTokenizer.from_pretrained(tokenizer);
+         task.trainingInformation.tokenizer = tokenizer;
+     }
+     return tokenizer;
+ }
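A usage sketch of this lazy initialization (the task literal is illustrative, not a real `Task`): the first call swaps the name stored on the task for the loaded tokenizer object, so later calls resolve without reloading anything.

```js
import { getTaskTokenizer } from './models/tokenizer.js'; // path assumed

const task = { trainingInformation: { tokenizer: 'Xenova/gpt2' } }; // illustrative shape
const first = await getTaskTokenizer(task); // loads and caches on the task
const second = await getTaskTokenizer(task); // reuses the cached instance
console.log(first === second); // true
```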
package/dist/validation/validator.js CHANGED
@@ -21,14 +21,15 @@ export class Validator {
          }
      }
      async getLabel(ys) {
-         switch (ys.shape[1]) {
-             case 1:
-                 return await ys.greaterEqual(tf.scalar(0.5)).data();
-             case 2:
-                 return await ys.argMax(1).data();
-             default:
-                 throw new Error(`unable to reduce tensor of shape: ${ys.shape.toString()}`);
+         // Binary classification
+         if (ys.shape[1] == 1) {
+             return await ys.greaterEqual(tf.scalar(0.5)).data();
+         // Multi-class classification
          }
+         else {
+             return await ys.argMax(-1).data();
+         }
+         // Multi-label classification is not supported
      }
      async assess(data, useConfusionMatrix = false) {
          const batchSize = this.task.trainingInformation?.batchSize;
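The rewritten `getLabel` handles the two supported cases; a quick sketch of what each branch computes (sample values are illustrative):

```js
import * as tf from '@tensorflow/tfjs';

// Binary classification: a single sigmoid output per sample, thresholded at 0.5
const binary = tf.tensor2d([[0.2], [0.9]]); // shape [batch, 1]
console.log(await binary.greaterEqual(tf.scalar(0.5)).data()); // [0, 1]

// Multi-class classification: one probability per class, reduced with argMax
const multi = tf.tensor2d([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]]); // shape [batch, 3]
console.log(await multi.argMax(-1).data()); // [1, 0]
```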
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
      "name": "@epfml/discojs",
-     "version": "2.1.2-p20240515133413.0",
+     "version": "2.1.2-p20240531085945.0",
      "type": "module",
      "main": "dist/index.js",
      "types": "dist/index.d.ts",
package/dist/default_tasks/cifar10/index.d.ts DELETED
@@ -1,2 +0,0 @@
- import type { TaskProvider } from '../../index.js';
- export declare const cifar10: TaskProvider;
package/dist/default_tasks/simple_face/index.d.ts DELETED
@@ -1,2 +0,0 @@
- import type { TaskProvider } from '../../index.js';
- export declare const simpleFace: TaskProvider;