@epfml/discojs-web 2.1.2-p20240603114517.0 → 2.1.2-p20240617140649.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@ import { data } from '@epfml/discojs';
3
3
  export class TextLoader extends data.TextLoader {
4
4
  loadDatasetFrom(source) {
5
5
  const file = new tf.data.FileDataSource(source);
6
- const dataset = new tf.data.TextLineDataset(file).filter(s => s != ' '); // newline creates empty strings
6
+ const dataset = new tf.data.TextLineDataset(file).filter(s => s !== ' '); // newline creates empty strings
7
7
  return Promise.resolve(dataset);
8
8
  }
9
9
  }
@@ -2,8 +2,8 @@ import * as tf from '@tensorflow/tfjs';
2
2
  import type { Path, Model, ModelInfo, ModelSource } from '@epfml/discojs';
3
3
  import { Memory } from '@epfml/discojs';
4
4
  export declare class IndexedDB extends Memory {
5
- pathFor(source: ModelSource): Path;
6
- infoFor(source: ModelSource): ModelInfo;
5
+ getModelMemoryPath(source: ModelSource): Path;
6
+ getModelInfo(source: ModelSource): ModelInfo;
7
7
  getModelMetadata(source: ModelSource): Promise<tf.io.ModelArtifactsInfo | undefined>;
8
8
  contains(source: ModelSource): Promise<boolean>;
9
9
  getModel(source: ModelSource): Promise<Model>;
@@ -8,48 +8,61 @@
8
8
  */
9
9
  import { Map } from 'immutable';
10
10
  import * as tf from '@tensorflow/tfjs';
11
- import { Memory, ModelType, models } from '@epfml/discojs';
11
+ import { Memory, models } from '@epfml/discojs';
12
12
  export class IndexedDB extends Memory {
13
- pathFor(source) {
13
+ getModelMemoryPath(source) {
14
14
  if (typeof source === 'string') {
15
15
  return source;
16
16
  }
17
- if (source.type === undefined || source.taskID === undefined || source.name === undefined) {
18
- throw new TypeError('source incomplete');
19
- }
20
17
  const version = source.version ?? 0;
21
- return `indexeddb://${source.type}/${source.taskID}/${source.name}@${version}`;
18
+ return `indexeddb://${source.type}/${source.tensorBackend}/${source.taskID}/${source.name}@${version}`;
22
19
  }
23
- infoFor(source) {
20
+ getModelInfo(source) {
24
21
  if (typeof source !== 'string') {
25
22
  return source;
26
23
  }
27
- const [stringType, taskID, fullName] = source.split('/').splice(2);
28
- const type = stringType === 'working' ? ModelType.WORKING : ModelType.SAVED;
24
+ const [type, tensorBackend, taskID, fullName] = source.split('/').splice(2);
25
+ if (type !== 'working' && type !== 'saved') {
26
+ throw Error("Unknown memory model type");
27
+ }
29
28
  const [name, versionSuffix] = fullName.split('@');
30
29
  const version = versionSuffix === undefined ? 0 : Number(versionSuffix);
31
- return { type, taskID, name, version };
30
+ if (tensorBackend !== 'tfjs' && tensorBackend !== 'gpt') {
31
+ throw Error("Unknown tensor backend");
32
+ }
33
+ return { type, taskID, name, version, tensorBackend };
32
34
  }
33
35
  async getModelMetadata(source) {
34
36
  const models = await tf.io.listModels();
35
- return models[this.pathFor(source)];
37
+ return models[this.getModelMemoryPath(source)];
36
38
  }
37
39
  async contains(source) {
38
40
  return await this.getModelMetadata(source) !== undefined;
39
41
  }
40
42
  async getModel(source) {
41
- return new models.TFJS(await tf.loadLayersModel(this.pathFor(source)));
43
+ const layersModel = await tf.loadLayersModel(this.getModelMemoryPath(source));
44
+ const tensorBackend = this.getModelInfo(source).tensorBackend;
45
+ switch (tensorBackend) {
46
+ case 'tfjs':
47
+ return new models.TFJS(layersModel);
48
+ case 'gpt':
49
+ return new models.GPT(undefined, layersModel);
50
+ default: {
51
+ const _ = tensorBackend;
52
+ throw new Error('should never happen');
53
+ }
54
+ }
42
55
  }
43
56
  async deleteModel(source) {
44
- await tf.io.removeModel(this.pathFor(source));
57
+ await tf.io.removeModel(this.getModelMemoryPath(source));
45
58
  }
46
59
  async loadModel(source) {
47
- const src = this.infoFor(source);
48
- if (src.type === ModelType.WORKING) {
60
+ const src = this.getModelInfo(source);
61
+ if (src.type === 'working') {
49
62
  // Model is already loaded
50
63
  return;
51
64
  }
52
- await tf.io.copyModel(this.pathFor(src), this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }));
65
+ await tf.io.copyModel(this.getModelMemoryPath(src), this.getModelMemoryPath({ ...src, type: 'working', version: 0 }));
53
66
  }
54
67
  /**
55
68
  * Saves the working model to the source.
@@ -57,67 +70,81 @@ export class IndexedDB extends Memory {
57
70
  * @param model the model
58
71
  */
59
72
  async updateWorkingModel(source, model) {
60
- const src = this.infoFor(source);
61
- if (src.type !== undefined && src.type !== ModelType.WORKING) {
62
- throw new Error('expected working model');
73
+ const src = this.getModelInfo(source);
74
+ if (src.type !== 'working') {
75
+ throw new Error('expected working type model');
63
76
  }
77
+ // Enforce version 0 to always keep a single working model at a time
78
+ const modelInfo = { ...src, type: 'working', version: 0 };
79
+ let includeOptimizer;
64
80
  if (model instanceof models.TFJS) {
65
- await model.extract().save(this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }), { includeOptimizer: true });
81
+ modelInfo['tensorBackend'] = 'tfjs';
82
+ includeOptimizer = true;
83
+ }
84
+ else if (model instanceof models.GPT) {
85
+ modelInfo['tensorBackend'] = 'gpt';
86
+ includeOptimizer = false; // true raises an error
66
87
  }
67
88
  else {
68
89
  throw new Error('unknown model type');
69
90
  }
70
- // Enforce version 0 to always keep a single working model at a time
91
+ const indexedDBURL = this.getModelMemoryPath(modelInfo);
92
+ await model.extract().save(indexedDBURL, { includeOptimizer });
71
93
  }
72
94
  /**
73
95
  * Creates a saved copy of the working model corresponding to the source.
74
96
  * @param source the source
75
97
  */
76
98
  async saveWorkingModel(source) {
77
- const src = this.infoFor(source);
78
- if (src.type !== undefined && src.type !== ModelType.WORKING) {
79
- throw new Error('expected working model');
99
+ const src = this.getModelInfo(source);
100
+ if (src.type !== 'working') {
101
+ throw new Error('expected working type model');
80
102
  }
81
- const dst = this.pathFor(await this.duplicateSource({ ...src, type: ModelType.SAVED }));
82
- await tf.io.copyModel(this.pathFor({ ...src, type: ModelType.WORKING }), dst);
103
+ const dst = this.getModelMemoryPath(await this.duplicateSource({ ...src, type: 'saved' }));
104
+ await tf.io.copyModel(this.getModelMemoryPath({ ...src, type: 'working' }), dst);
83
105
  return dst;
84
106
  }
85
107
  async saveModel(source, model) {
86
- const src = this.infoFor(source);
87
- if (src.type !== undefined && src.type !== ModelType.SAVED) {
88
- throw new Error('expected saved model');
108
+ const src = this.getModelInfo(source);
109
+ if (src.type !== 'saved') {
110
+ throw new Error('expected saved type model');
89
111
  }
90
- const dst = this.pathFor(await this.duplicateSource({ ...src, type: ModelType.SAVED }));
112
+ const modelInfo = await this.duplicateSource({ ...src, type: 'saved' });
113
+ let includeOptimizer;
91
114
  if (model instanceof models.TFJS) {
92
- await model.extract().save(dst, { includeOptimizer: true });
115
+ modelInfo['tensorBackend'] = 'tfjs';
116
+ includeOptimizer = true;
117
+ }
118
+ else if (model instanceof models.GPT) {
119
+ modelInfo['tensorBackend'] = 'gpt';
120
+ includeOptimizer = false; // true raises an error
93
121
  }
94
122
  else {
95
123
  throw new Error('unknown model type');
96
124
  }
97
- return dst;
125
+ const indexedDBURL = this.getModelMemoryPath(modelInfo);
126
+ await model.extract().save(indexedDBURL, { includeOptimizer });
127
+ return indexedDBURL;
98
128
  }
99
129
  /**
100
130
  * Downloads the model corresponding to the source.
101
131
  * @param source the source
102
132
  */
103
133
  async downloadModel(source) {
104
- const src = this.infoFor(source);
105
- await tf.io.copyModel(this.pathFor(source), `downloads://${src.taskID}_${src.name}`);
134
+ const src = this.getModelInfo(source);
135
+ await tf.io.copyModel(this.getModelMemoryPath(source), `downloads://${src.taskID}_${src.name}`);
106
136
  }
107
137
  async latestDuplicate(source) {
108
138
  if (typeof source !== 'string') {
109
- source = this.pathFor({ ...source, version: 0 });
139
+ source = this.getModelMemoryPath({ ...source, version: 0 });
110
140
  }
111
141
  // perform a single memory read
112
142
  const paths = Map(await tf.io.listModels());
113
143
  if (!paths.has(source)) {
114
144
  return undefined;
115
145
  }
116
- const latest = Map(paths)
117
- .keySeq()
118
- .toList()
119
- .map((p) => this.infoFor(p).version)
120
- .max();
146
+ const latest = Map(paths).keySeq().toList()
147
+ .map((p) => this.getModelInfo(p).version).max();
121
148
  if (latest === undefined) {
122
149
  return 0;
123
150
  }
@@ -125,7 +152,7 @@ export class IndexedDB extends Memory {
125
152
  }
126
153
  async duplicateSource(source) {
127
154
  const latestDuplicate = await this.latestDuplicate(source);
128
- source = this.infoFor(source);
155
+ source = this.getModelInfo(source);
129
156
  if (latestDuplicate === undefined) {
130
157
  return source;
131
158
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs-web",
3
- "version": "2.1.2-p20240603114517.0",
3
+ "version": "2.1.2-p20240617140649.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",