@epfml/discojs 2.1.2-p20240513140724.0 → 2.1.2-p20240515132210.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,24 @@
1
import * as tf from '@tensorflow/tfjs';
import { WeightsContainer } from '../index.js';
import { Model } from './index.js';
import type { EpochLogs, Prediction, Sample } from './model.js';
import type { Dataset } from '../dataset/index.js';
/** TensorFlow JavaScript model with standard training */
export declare class TFJS extends Model {
    private readonly model;
    /** Wrap the given trainable model */
    constructor(model: tf.LayersModel);
    /** Read the wrapped model's current weights as a WeightsContainer */
    get weights(): WeightsContainer;
    /** Overwrite the wrapped model's weights from the given container */
    set weights(ws: WeightsContainer);
    /**
     * Train the wrapped model, yielding structured logs after every epoch.
     *
     * @param trainingData dataset to fit on
     * @param validationData optional dataset evaluated after each epoch
     * @param epochs number of epochs to run (defaults to 1 in the implementation)
     */
    train(trainingData: Dataset, validationData?: Dataset, epochs?: number): AsyncGenerator<EpochLogs>;
    /** Run inference on a single sample; rejects if the model outputs multiple tensors */
    predict(input: Sample): Promise<Prediction>;
    /** Rebuild a TFJS model from serialized TensorFlow.js artifacts */
    static deserialize(raw: tf.io.ModelArtifacts): Promise<Model>;
    /** Serialize the wrapped model to TensorFlow.js artifacts (optimizer included) */
    serialize(): Promise<tf.io.ModelArtifacts>;
    /** Dispose the wrapped model's tensors */
    [Symbol.dispose](): void;
    /**
     * extract wrapped model
     *
     * @deprecated use `Model` instead of relying on tf specifics
     */
    extract(): tf.LayersModel;
}
@@ -0,0 +1,107 @@
1
import * as tf from '@tensorflow/tfjs';
import { WeightsContainer } from '../index.js';
import { Model } from './index.js';
/** TensorFlow JavaScript model with standard training */
export class TFJS extends Model {
    model;
    /** Wrap the given trainable model */
    constructor(model) {
        super();
        this.model = model;
        // An uncompiled LayersModel has no loss configured; fail fast with a
        // clear message instead of erroring later inside fitDataset.
        if (model.loss === undefined) {
            throw new Error('TFJS models need to be compiled to be used');
        }
    }
    /** Current weights of the wrapped model */
    get weights() {
        return new WeightsContainer(this.model.weights.map((w) => w.read()));
    }
    /** Replace the wrapped model's weights */
    set weights(ws) {
        this.model.setWeights(ws.weights);
    }
    /**
     * Train for the given number of epochs, yielding structured logs
     * (epoch index, peak tf.js memory in GB, training and optional
     * validation metrics) after each epoch.
     *
     * @param trainingData dataset to fit on
     * @param validationData optional dataset evaluated each epoch
     * @param epochs number of epochs to run (default 1)
     * @throws Error when TFJS reports no logs or NaN/undefined metrics
     */
    async *train(trainingData, validationData, epochs = 1) {
        for (let epoch = 0; epoch < epochs; epoch++) {
            let logs;
            let peakMemory = 0;
            // Run one tf.js epoch per loop iteration so we can yield
            // structured logs between epochs.
            await this.model.fitDataset(trainingData, {
                epochs: 1,
                validationData,
                callbacks: {
                    onBatchEnd: () => {
                        // Track the high-water mark of tf.js memory usage, in GB.
                        const currentMemory = tf.memory().numBytes / 1024 / 1024 / 1024;
                        if (currentMemory > peakMemory) {
                            peakMemory = currentMemory;
                        }
                    },
                    onEpochEnd: (_, cur) => { logs = cur; }
                },
            });
            if (logs === undefined) {
                // fixed grammar: was "Epoch didn't gave any logs"
                throw new Error("Epoch didn't give any logs");
            }
            const { loss, acc, val_acc, val_loss } = logs;
            if (loss === undefined || isNaN(loss) || acc === undefined || isNaN(acc)) {
                // message now mentions accuracy too, since both are checked
                throw new Error("Training loss or accuracy is undefined or NaN");
            }
            const structuredLogs = {
                epoch,
                peakMemory,
                training: {
                    loss,
                    accuracy: acc,
                }
            };
            if (validationData !== undefined) {
                if (val_loss === undefined || isNaN(val_loss) ||
                    val_acc === undefined || isNaN(val_acc)) {
                    throw new Error("Invalid validation logs");
                }
                structuredLogs.validation = {
                    accuracy: val_acc,
                    loss: val_loss
                };
            }
            yield structuredLogs;
        }
    }
    /**
     * Run inference on a single sample.
     * @throws Error if the underlying model returns multiple tensors
     */
    predict(input) {
        const ret = this.model.predict(input);
        if (Array.isArray(ret)) {
            // fixed grammar: was "prediction yield many Tensors..."
            throw new Error('prediction yielded many Tensors but should have only returned one');
        }
        return Promise.resolve(ret);
    }
    /** Rebuild a model from serialized artifacts */
    static async deserialize(raw) {
        return new this(await tf.loadLayersModel({
            load: () => Promise.resolve(raw)
        }));
    }
    /** Serialize the wrapped model, keeping the optimizer so it stays compiled */
    async serialize() {
        // tf.js only exposes artifacts through a save handler, so capture
        // them via a promise resolved from inside the handler.
        let resolveArtifacts;
        const ret = new Promise((resolve) => { resolveArtifacts = resolve; });
        await this.model.save({
            save: (artifacts) => {
                resolveArtifacts(artifacts);
                return Promise.resolve({
                    modelArtifactsInfo: {
                        dateSaved: new Date(),
                        modelTopologyType: 'JSON'
                    }
                });
            }
        }, {
            includeOptimizer: true // keep model compiled
        });
        return await ret;
    }
    /** Dispose the wrapped model's tensors */
    [Symbol.dispose]() {
        this.model.dispose();
    }
    /**
     * extract wrapped model
     *
     * @deprecated use `Model` instead of relying on tf specifics
     */
    extract() {
        return this.model;
    }
}
@@ -0,0 +1,14 @@
1
import type { Task } from '../index.js';
import { PreTrainedTokenizer } from '@xenova/transformers';
/**
 * A task's tokenizer is initially specified as the tokenizer name, e.g., 'Xenova/gpt2'.
 * The first time the tokenizer is needed, this function initializes the actual tokenizer object
 * and saves it in the task's tokenizer field to be reused in subsequent calls.
 *
 * We are proceeding as such because the task object is sent from the server to the client.
 * Rather than sending complex objects through the network, we simply send the tokenizer name,
 * which is then initialized client-side the first time it is called.
 * @param task the task object specifying which tokenizer to use
 * @returns an initialized tokenizer object
 */
export declare function getTaskTokenizer(task: Task): Promise<PreTrainedTokenizer>;
@@ -0,0 +1,23 @@
1
import { AutoTokenizer, env } from '@xenova/transformers';
/**
 * A task's tokenizer is initially specified as the tokenizer name, e.g., 'Xenova/gpt2'.
 * The first time the tokenizer is needed, this function initializes the actual tokenizer object
 * and saves it in the task's tokenizer field to be reused in subsequent calls.
 *
 * We are proceeding as such because the task object is sent from the server to the client.
 * Rather than sending complex objects through the network, we simply send the tokenizer name,
 * which is then initialized client-side the first time it is called.
 * @param task the task object specifying which tokenizer to use
 * @returns an initialized tokenizer object
 * @throws Error when the task's training information specifies no tokenizer
 */
export async function getTaskTokenizer(task) {
    let tokenizer = task.trainingInformation.tokenizer;
    if (tokenizer === undefined)
        // was `throw Error(...)`; `new Error` is the idiomatic form
        throw new Error('No tokenizer specified in the task training information');
    // A string means the tokenizer hasn't been initialized yet (only its
    // name was shipped over the network).
    if (typeof tokenizer === 'string') {
        // Fetch tokenizers from the hub only, never from the local filesystem.
        env.allowLocalModels = false;
        tokenizer = await AutoTokenizer.from_pretrained(tokenizer);
        // Cache the initialized tokenizer on the task for subsequent calls.
        task.trainingInformation.tokenizer = tokenizer;
    }
    return tokenizer;
}
@@ -1,4 +1,4 @@
1
- import { ModelType } from '../../index.js';
1
+ import { StoredModelType } from '../../index.js';
2
2
  import { DistributedTrainer } from './distributed_trainer.js';
3
3
  import { LocalTrainer } from './local_trainer.js';
4
4
  /**
@@ -36,7 +36,7 @@ export class TrainerBuilder {
36
36
  if (modelID === undefined) {
37
37
  throw new TypeError('model ID is undefined');
38
38
  }
39
- const info = { type: ModelType.WORKING, taskID: this.task.id, name: modelID };
39
+ const info = { type: StoredModelType.WORKING, taskID: this.task.id, name: modelID };
40
40
  const model = await (await this.memory.contains(info) ? this.memory.getModel(info) : client.getLatestModel());
41
41
  return model;
42
42
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "2.1.2-p20240513140724.0",
3
+ "version": "2.1.2-p20240515132210.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",