@epfml/discojs-web 2.1.2-p20240515132210.0 → 2.1.2-p20240515133413.0

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -2,8 +2,8 @@ import * as tf from '@tensorflow/tfjs';
2
2
  import type { Path, Model, ModelInfo, ModelSource } from '@epfml/discojs';
3
3
  import { Memory } from '@epfml/discojs';
4
4
  export declare class IndexedDB extends Memory {
5
- getModelMemoryPath(source: ModelSource): Path;
6
- getModelInfo(source: ModelSource): ModelInfo;
5
+ pathFor(source: ModelSource): Path;
6
+ infoFor(source: ModelSource): ModelInfo;
7
7
  getModelMetadata(source: ModelSource): Promise<tf.io.ModelArtifactsInfo | undefined>;
8
8
  contains(source: ModelSource): Promise<boolean>;
9
9
  getModel(source: ModelSource): Promise<Model>;
@@ -8,9 +8,9 @@
8
8
  */
9
9
  import { Map } from 'immutable';
10
10
  import * as tf from '@tensorflow/tfjs';
11
- import { Memory, StoredModelType, models } from '@epfml/discojs';
11
+ import { Memory, ModelType, models } from '@epfml/discojs';
12
12
  export class IndexedDB extends Memory {
13
- getModelMemoryPath(source) {
13
+ pathFor(source) {
14
14
  if (typeof source === 'string') {
15
15
  return source;
16
16
  }
@@ -20,41 +20,36 @@ export class IndexedDB extends Memory {
20
20
  const version = source.version ?? 0;
21
21
  return `indexeddb://${source.type}/${source.taskID}/${source.name}@${version}`;
22
22
  }
23
- getModelInfo(source) {
23
+ infoFor(source) {
24
24
  if (typeof source !== 'string') {
25
25
  return source;
26
26
  }
27
27
  const [stringType, taskID, fullName] = source.split('/').splice(2);
28
- const type = stringType === 'working' ? StoredModelType.WORKING : StoredModelType.SAVED;
28
+ const type = stringType === 'working' ? ModelType.WORKING : ModelType.SAVED;
29
29
  const [name, versionSuffix] = fullName.split('@');
30
30
  const version = versionSuffix === undefined ? 0 : Number(versionSuffix);
31
31
  return { type, taskID, name, version };
32
32
  }
33
33
  async getModelMetadata(source) {
34
34
  const models = await tf.io.listModels();
35
- return models[this.getModelMemoryPath(source)];
35
+ return models[this.pathFor(source)];
36
36
  }
37
37
  async contains(source) {
38
38
  return await this.getModelMetadata(source) !== undefined;
39
39
  }
40
40
  async getModel(source) {
41
- console.log("source", source);
42
- console.log("memory path", this.getModelMemoryPath(source));
43
- const layersModel = await tf.loadLayersModel(this.getModelMemoryPath(source));
44
- console.log("layers model", layersModel);
45
- return new models.TFJS(layersModel);
41
+ return new models.TFJS(await tf.loadLayersModel(this.pathFor(source)));
46
42
  }
47
43
  async deleteModel(source) {
48
- await tf.io.removeModel(this.getModelMemoryPath(source));
44
+ await tf.io.removeModel(this.pathFor(source));
49
45
  }
50
46
  async loadModel(source) {
51
- console.log("Loading model");
52
- const src = this.getModelInfo(source);
53
- if (src.type === StoredModelType.WORKING) {
47
+ const src = this.infoFor(source);
48
+ if (src.type === ModelType.WORKING) {
54
49
  // Model is already loaded
55
50
  return;
56
51
  }
57
- await tf.io.copyModel(this.getModelMemoryPath(src), this.getModelMemoryPath({ ...src, type: StoredModelType.WORKING, version: 0 }));
52
+ await tf.io.copyModel(this.pathFor(src), this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }));
58
53
  }
59
54
  /**
60
55
  * Saves the working model to the source.
@@ -62,16 +57,12 @@ export class IndexedDB extends Memory {
62
57
  * @param model the model
63
58
  */
64
59
  async updateWorkingModel(source, model) {
65
- const src = this.getModelInfo(source);
66
- if (src.type !== undefined && src.type !== StoredModelType.WORKING) {
60
+ const src = this.infoFor(source);
61
+ if (src.type !== undefined && src.type !== ModelType.WORKING) {
67
62
  throw new Error('expected working model');
68
63
  }
69
- const indexedDBURL = this.getModelMemoryPath({ ...src, type: StoredModelType.WORKING, version: 0 });
70
64
  if (model instanceof models.TFJS) {
71
- await model.extract().save(indexedDBURL, { includeOptimizer: true });
72
- }
73
- else if (model instanceof models.GPT) {
74
- await model.extract().save(indexedDBURL, { includeOptimizer: false }); // true raises an error
65
+ await model.extract().save(this.pathFor({ ...src, type: ModelType.WORKING, version: 0 }), { includeOptimizer: true });
75
66
  }
76
67
  else {
77
68
  throw new Error('unknown model type');
@@ -83,42 +74,39 @@ export class IndexedDB extends Memory {
83
74
  * @param source the source
84
75
  */
85
76
  async saveWorkingModel(source) {
86
- const src = this.getModelInfo(source);
87
- if (src.type !== undefined && src.type !== StoredModelType.WORKING) {
77
+ const src = this.infoFor(source);
78
+ if (src.type !== undefined && src.type !== ModelType.WORKING) {
88
79
  throw new Error('expected working model');
89
80
  }
90
- const dst = this.getModelMemoryPath(await this.duplicateSource({ ...src, type: StoredModelType.SAVED }));
91
- await tf.io.copyModel(this.getModelMemoryPath({ ...src, type: StoredModelType.WORKING }), dst);
81
+ const dst = this.pathFor(await this.duplicateSource({ ...src, type: ModelType.SAVED }));
82
+ await tf.io.copyModel(this.pathFor({ ...src, type: ModelType.WORKING }), dst);
92
83
  return dst;
93
84
  }
94
85
  async saveModel(source, model) {
95
- const src = this.getModelInfo(source);
96
- if (src.type !== undefined && src.type !== StoredModelType.SAVED) {
86
+ const src = this.infoFor(source);
87
+ if (src.type !== undefined && src.type !== ModelType.SAVED) {
97
88
  throw new Error('expected saved model');
98
89
  }
99
- const indexedDBURL = this.getModelMemoryPath(await this.duplicateSource({ ...src, type: StoredModelType.SAVED }));
90
+ const dst = this.pathFor(await this.duplicateSource({ ...src, type: ModelType.SAVED }));
100
91
  if (model instanceof models.TFJS) {
101
- await model.extract().save(indexedDBURL, { includeOptimizer: true });
102
- }
103
- else if (model instanceof models.GPT) {
104
- await model.extract().save(indexedDBURL, { includeOptimizer: false }); // true raises an error
92
+ await model.extract().save(dst, { includeOptimizer: true });
105
93
  }
106
94
  else {
107
95
  throw new Error('unknown model type');
108
96
  }
109
- return indexedDBURL;
97
+ return dst;
110
98
  }
111
99
  /**
112
100
  * Downloads the model corresponding to the source.
113
101
  * @param source the source
114
102
  */
115
103
  async downloadModel(source) {
116
- const src = this.getModelInfo(source);
117
- await tf.io.copyModel(this.getModelMemoryPath(source), `downloads://${src.taskID}_${src.name}`);
104
+ const src = this.infoFor(source);
105
+ await tf.io.copyModel(this.pathFor(source), `downloads://${src.taskID}_${src.name}`);
118
106
  }
119
107
  async latestDuplicate(source) {
120
108
  if (typeof source !== 'string') {
121
- source = this.getModelMemoryPath({ ...source, version: 0 });
109
+ source = this.pathFor({ ...source, version: 0 });
122
110
  }
123
111
  // perform a single memory read
124
112
  const paths = Map(await tf.io.listModels());
@@ -128,7 +116,7 @@ export class IndexedDB extends Memory {
128
116
  const latest = Map(paths)
129
117
  .keySeq()
130
118
  .toList()
131
- .map((p) => this.getModelInfo(p).version)
119
+ .map((p) => this.infoFor(p).version)
132
120
  .max();
133
121
  if (latest === undefined) {
134
122
  return 0;
@@ -137,7 +125,7 @@ export class IndexedDB extends Memory {
137
125
  }
138
126
  async duplicateSource(source) {
139
127
  const latestDuplicate = await this.latestDuplicate(source);
140
- source = this.getModelInfo(source);
128
+ source = this.infoFor(source);
141
129
  if (latestDuplicate === undefined) {
142
130
  return source;
143
131
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs-web",
3
- "version": "2.1.2-p20240515132210.0",
3
+ "version": "2.1.2-p20240515133413.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",