@epfml/discojs-web 2.1.2-p20240513140724.0 → 2.1.2-p20240515132210.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,8 +2,8 @@ import * as tf from '@tensorflow/tfjs';
2
2
  import type { Path, Model, ModelInfo, ModelSource } from '@epfml/discojs';
3
3
  import { Memory } from '@epfml/discojs';
4
4
  export declare class IndexedDB extends Memory {
5
- pathFor(source: ModelSource): Path;
6
- infoFor(source: ModelSource): ModelInfo;
5
+ getModelMemoryPath(source: ModelSource): Path;
6
+ getModelInfo(source: ModelSource): ModelInfo;
7
7
  getModelMetadata(source: ModelSource): Promise<tf.io.ModelArtifactsInfo | undefined>;
8
8
  contains(source: ModelSource): Promise<boolean>;
9
9
  getModel(source: ModelSource): Promise<Model>;
@@ -8,9 +8,9 @@
8
8
  */
9
9
  import { Map } from 'immutable';
10
10
  import * as tf from '@tensorflow/tfjs';
11
- import { Memory, ModelType, models } from '@epfml/discojs';
11
+ import { Memory, StoredModelType, models } from '@epfml/discojs';
12
12
  export class IndexedDB extends Memory {
13
- pathFor(source) {
13
+ getModelMemoryPath(source) {
14
14
  if (typeof source === 'string') {
15
15
  return source;
16
16
  }
@@ -20,36 +20,41 @@ export class IndexedDB extends Memory {
20
20
  const version = source.version ?? 0;
21
21
  return `indexeddb://${source.type}/${source.taskID}/${source.name}@${version}`;
22
22
  }
23
- infoFor(source) {
23
+ getModelInfo(source) {
24
24
  if (typeof source !== 'string') {
25
25
  return source;
26
26
  }
27
27
  const [stringType, taskID, fullName] = source.split('/').splice(2);
28
- const type = stringType === 'working' ? ModelType.WORKING : ModelType.SAVED;
28
+ const type = stringType === 'working' ? StoredModelType.WORKING : StoredModelType.SAVED;
29
29
  const [name, versionSuffix] = fullName.split('@');
30
30
  const version = versionSuffix === undefined ? 0 : Number(versionSuffix);
31
31
  return { type, taskID, name, version };
32
32
  }
33
33
  async getModelMetadata(source) {
34
34
  const models = await tf.io.listModels();
35
- return models[this.pathFor(source)];
35
+ return models[this.getModelMemoryPath(source)];
36
36
  }
37
37
  async contains(source) {
38
38
  return await this.getModelMetadata(source) !== undefined;
39
39
  }
40
40
  async getModel(source) {
41
- return new models.TFJS(await tf.loadLayersModel(this.pathFor(source)));
41
+ console.log("source", source);
42
+ console.log("memory path", this.getModelMemoryPath(source));
43
+ const layersModel = await tf.loadLayersModel(this.getModelMemoryPath(source));
44
+ console.log("layers model", layersModel);
45
+ return new models.TFJS(layersModel);
42
46
  }
43
47
  async deleteModel(source) {
44
- await tf.io.removeModel(this.pathFor(source));
48
+ await tf.io.removeModel(this.getModelMemoryPath(source));
45
49
  }
46
50
  async loadModel(source) {
47
- const src = this.infoFor(source);
48
- if (src.type === ModelType.WORKING) {
51
+ console.log("Loading model");
52
+ const src = this.getModelInfo(source);
53
+ if (src.type === StoredModelType.WORKING) {
49
54
  // Model is already loaded
50
55
  return;
51
56
  }
52
- await tf.io.copyModel(this.pathFor(src), this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }));
57
+ await tf.io.copyModel(this.getModelMemoryPath(src), this.getModelMemoryPath({ ...src, type: StoredModelType.WORKING, version: 0 }));
53
58
  }
54
59
  /**
55
60
  * Saves the working model to the source.
@@ -57,12 +62,16 @@ export class IndexedDB extends Memory {
57
62
  * @param model the model
58
63
  */
59
64
  async updateWorkingModel(source, model) {
60
- const src = this.infoFor(source);
61
- if (src.type !== undefined && src.type !== ModelType.WORKING) {
65
+ const src = this.getModelInfo(source);
66
+ if (src.type !== undefined && src.type !== StoredModelType.WORKING) {
62
67
  throw new Error('expected working model');
63
68
  }
69
+ const indexedDBURL = this.getModelMemoryPath({ ...src, type: StoredModelType.WORKING, version: 0 });
64
70
  if (model instanceof models.TFJS) {
65
- await model.extract().save(this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }), { includeOptimizer: true });
71
+ await model.extract().save(indexedDBURL, { includeOptimizer: true });
72
+ }
73
+ else if (model instanceof models.GPT) {
74
+ await model.extract().save(indexedDBURL, { includeOptimizer: false }); // true raises an error
66
75
  }
67
76
  else {
68
77
  throw new Error('unknown model type');
@@ -74,39 +83,42 @@ export class IndexedDB extends Memory {
74
83
  * @param source the source
75
84
  */
76
85
  async saveWorkingModel(source) {
77
- const src = this.infoFor(source);
78
- if (src.type !== undefined && src.type !== ModelType.WORKING) {
86
+ const src = this.getModelInfo(source);
87
+ if (src.type !== undefined && src.type !== StoredModelType.WORKING) {
79
88
  throw new Error('expected working model');
80
89
  }
81
- const dst = this.pathFor(await this.duplicateSource({ ...src, type: ModelType.SAVED }));
82
- await tf.io.copyModel(this.pathFor({ ...src, type: ModelType.WORKING }), dst);
90
+ const dst = this.getModelMemoryPath(await this.duplicateSource({ ...src, type: StoredModelType.SAVED }));
91
+ await tf.io.copyModel(this.getModelMemoryPath({ ...src, type: StoredModelType.WORKING }), dst);
83
92
  return dst;
84
93
  }
85
94
  async saveModel(source, model) {
86
- const src = this.infoFor(source);
87
- if (src.type !== undefined && src.type !== ModelType.SAVED) {
95
+ const src = this.getModelInfo(source);
96
+ if (src.type !== undefined && src.type !== StoredModelType.SAVED) {
88
97
  throw new Error('expected saved model');
89
98
  }
90
- const dst = this.pathFor(await this.duplicateSource({ ...src, type: ModelType.SAVED }));
99
+ const indexedDBURL = this.getModelMemoryPath(await this.duplicateSource({ ...src, type: StoredModelType.SAVED }));
91
100
  if (model instanceof models.TFJS) {
92
- await model.extract().save(dst, { includeOptimizer: true });
101
+ await model.extract().save(indexedDBURL, { includeOptimizer: true });
102
+ }
103
+ else if (model instanceof models.GPT) {
104
+ await model.extract().save(indexedDBURL, { includeOptimizer: false }); // true raises an error
93
105
  }
94
106
  else {
95
107
  throw new Error('unknown model type');
96
108
  }
97
- return dst;
109
+ return indexedDBURL;
98
110
  }
99
111
  /**
100
112
  * Downloads the model corresponding to the source.
101
113
  * @param source the source
102
114
  */
103
115
  async downloadModel(source) {
104
- const src = this.infoFor(source);
105
- await tf.io.copyModel(this.pathFor(source), `downloads://${src.taskID}_${src.name}`);
116
+ const src = this.getModelInfo(source);
117
+ await tf.io.copyModel(this.getModelMemoryPath(source), `downloads://${src.taskID}_${src.name}`);
106
118
  }
107
119
  async latestDuplicate(source) {
108
120
  if (typeof source !== 'string') {
109
- source = this.pathFor({ ...source, version: 0 });
121
+ source = this.getModelMemoryPath({ ...source, version: 0 });
110
122
  }
111
123
  // perform a single memory read
112
124
  const paths = Map(await tf.io.listModels());
@@ -116,7 +128,7 @@ export class IndexedDB extends Memory {
116
128
  const latest = Map(paths)
117
129
  .keySeq()
118
130
  .toList()
119
- .map((p) => this.infoFor(p).version)
131
+ .map((p) => this.getModelInfo(p).version)
120
132
  .max();
121
133
  if (latest === undefined) {
122
134
  return 0;
@@ -125,7 +137,7 @@ export class IndexedDB extends Memory {
125
137
  }
126
138
  async duplicateSource(source) {
127
139
  const latestDuplicate = await this.latestDuplicate(source);
128
- source = this.infoFor(source);
140
+ source = this.getModelInfo(source);
129
141
  if (latestDuplicate === undefined) {
130
142
  return source;
131
143
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs-web",
3
- "version": "2.1.2-p20240513140724.0",
3
+ "version": "2.1.2-p20240515132210.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",