@epfml/discojs-web 2.1.2-p20240603114517.0 → 2.1.2-p20240617070831.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/data/text_loader.js +1 -1
- package/dist/memory/memory.d.ts +2 -2
- package/dist/memory/memory.js +68 -41
- package/package.json +1 -1
package/dist/data/text_loader.js
CHANGED
@@ -3,7 +3,7 @@ import { data } from '@epfml/discojs';
 export class TextLoader extends data.TextLoader {
     loadDatasetFrom(source) {
         const file = new tf.data.FileDataSource(source);
-        const dataset = new tf.data.TextLineDataset(file).filter(s => s
+        const dataset = new tf.data.TextLineDataset(file).filter(s => s !== ' '); // newline creates empty strings
         return Promise.resolve(dataset);
     }
 }
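For context, a minimal sketch of what the changed filter does, using tf.data.array in place of the file source (the sample strings and logging are illustrative, not part of the package):

import * as tf from '@tensorflow/tfjs';

// Splitting text into lines can yield blank entries (e.g. around a
// trailing newline); the predicate drops them before they reach training.
const lines = tf.data.array(['first line', ' ', 'second line']);
const filtered = lines.filter((s) => s !== ' ');
filtered.forEachAsync((s) => console.log(s))
    .catch(console.error); // logs only the two non-blank lines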
package/dist/memory/memory.d.ts
CHANGED
@@ -2,8 +2,8 @@ import * as tf from '@tensorflow/tfjs';
 import type { Path, Model, ModelInfo, ModelSource } from '@epfml/discojs';
 import { Memory } from '@epfml/discojs';
 export declare class IndexedDB extends Memory {
-
-
+    getModelMemoryPath(source: ModelSource): Path;
+    getModelInfo(source: ModelSource): ModelInfo;
     getModelMetadata(source: ModelSource): Promise<tf.io.ModelArtifactsInfo | undefined>;
     contains(source: ModelSource): Promise<boolean>;
     getModel(source: ModelSource): Promise<Model>;
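These two methods were previously internal helpers (the old compiled code called infoFor; see memory.js below) and are now part of the public declaration. A minimal usage sketch, assuming IndexedDB takes no constructor arguments and with an illustrative ModelInfo:

import { IndexedDB } from '@epfml/discojs-web';

const memory = new IndexedDB();
// ModelInfo -> indexeddb:// path; all field values here are made up.
const path = memory.getModelMemoryPath({
    type: 'working',
    tensorBackend: 'tfjs',
    taskID: 'titanic',
    name: 'my-model',
    version: 0,
});
console.log(path); // indexeddb://working/tfjs/titanic/my-model@0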
package/dist/memory/memory.js
CHANGED
@@ -8,48 +8,61 @@
  */
 import { Map } from 'immutable';
 import * as tf from '@tensorflow/tfjs';
-import { Memory,
+import { Memory, models } from '@epfml/discojs';
 export class IndexedDB extends Memory {
-
+    getModelMemoryPath(source) {
         if (typeof source === 'string') {
             return source;
         }
-        if (source.type === undefined || source.taskID === undefined || source.name === undefined) {
-            throw new TypeError('source incomplete');
-        }
         const version = source.version ?? 0;
-        return `indexeddb://${source.type}/${source.taskID}/${source.name}@${version}`;
+        return `indexeddb://${source.type}/${source.tensorBackend}/${source.taskID}/${source.name}@${version}`;
     }
-
+    getModelInfo(source) {
         if (typeof source !== 'string') {
             return source;
         }
-        const [
-
+        const [type, tensorBackend, taskID, fullName] = source.split('/').splice(2);
+        if (type !== 'working' && type !== 'saved') {
+            throw Error("Unknown memory model type");
+        }
         const [name, versionSuffix] = fullName.split('@');
         const version = versionSuffix === undefined ? 0 : Number(versionSuffix);
-
+        if (tensorBackend !== 'tfjs' && tensorBackend !== 'gpt') {
+            throw Error("Unknown tensor backend");
+        }
+        return { type, taskID, name, version, tensorBackend };
     }
     async getModelMetadata(source) {
         const models = await tf.io.listModels();
-        return models[this.
+        return models[this.getModelMemoryPath(source)];
     }
     async contains(source) {
         return await this.getModelMetadata(source) !== undefined;
     }
     async getModel(source) {
-
+        const layersModel = await tf.loadLayersModel(this.getModelMemoryPath(source));
+        const tensorBackend = this.getModelInfo(source).tensorBackend;
+        switch (tensorBackend) {
+            case 'tfjs':
+                return new models.TFJS(layersModel);
+            case 'gpt':
+                return new models.GPT(undefined, layersModel);
+            default: {
+                const _ = tensorBackend;
+                throw new Error('should never happen');
+            }
+        }
     }
     async deleteModel(source) {
-        await tf.io.removeModel(this.
+        await tf.io.removeModel(this.getModelMemoryPath(source));
     }
     async loadModel(source) {
-        const src = this.
-        if (src.type ===
+        const src = this.getModelInfo(source);
+        if (src.type === 'working') {
             // Model is already loaded
             return;
         }
-        await tf.io.copyModel(this.
+        await tf.io.copyModel(this.getModelMemoryPath(src), this.getModelMemoryPath({ ...src, type: 'working', version: 0 }));
     }
     /**
      * Saves the working model to the source.
@@ -57,67 +70,81 @@ export class IndexedDB extends Memory {
      * @param model the model
      */
     async updateWorkingModel(source, model) {
-        const src = this.
-        if (src.type !==
-            throw new Error('expected working model');
+        const src = this.getModelInfo(source);
+        if (src.type !== 'working') {
+            throw new Error('expected working type model');
         }
+        // Enforce version 0 to always keep a single working model at a time
+        const modelInfo = { ...src, type: 'working', version: 0 };
+        let includeOptimizer;
         if (model instanceof models.TFJS) {
-
+            modelInfo['tensorBackend'] = 'tfjs';
+            includeOptimizer = true;
+        }
+        else if (model instanceof models.GPT) {
+            modelInfo['tensorBackend'] = 'gpt';
+            includeOptimizer = false; // true raises an error
         }
         else {
             throw new Error('unknown model type');
         }
-
+        const indexedDBURL = this.getModelMemoryPath(modelInfo);
+        await model.extract().save(indexedDBURL, { includeOptimizer });
     }
     /**
      * Creates a saved copy of the working model corresponding to the source.
      * @param source the source
      */
     async saveWorkingModel(source) {
-        const src = this.
-        if (src.type !==
-            throw new Error('expected working model');
+        const src = this.getModelInfo(source);
+        if (src.type !== 'working') {
+            throw new Error('expected working type model');
         }
-        const dst = this.
-        await tf.io.copyModel(this.
+        const dst = this.getModelMemoryPath(await this.duplicateSource({ ...src, type: 'saved' }));
+        await tf.io.copyModel(this.getModelMemoryPath({ ...src, type: 'working' }), dst);
         return dst;
     }
     async saveModel(source, model) {
-        const src = this.
-        if (src.type !==
-            throw new Error('expected saved model');
+        const src = this.getModelInfo(source);
+        if (src.type !== 'saved') {
+            throw new Error('expected saved type model');
         }
-        const
+        const modelInfo = await this.duplicateSource({ ...src, type: 'saved' });
+        let includeOptimizer;
         if (model instanceof models.TFJS) {
-
+            modelInfo['tensorBackend'] = 'tfjs';
+            includeOptimizer = true;
+        }
+        else if (model instanceof models.GPT) {
+            modelInfo['tensorBackend'] = 'gpt';
+            includeOptimizer = false; // true raises an error
         }
         else {
             throw new Error('unknown model type');
         }
-
+        const indexedDBURL = this.getModelMemoryPath(modelInfo);
+        await model.extract().save(indexedDBURL, { includeOptimizer });
+        return indexedDBURL;
     }
     /**
      * Downloads the model corresponding to the source.
      * @param source the source
     */
     async downloadModel(source) {
-        const src = this.
-        await tf.io.copyModel(this.
+        const src = this.getModelInfo(source);
+        await tf.io.copyModel(this.getModelMemoryPath(source), `downloads://${src.taskID}_${src.name}`);
     }
     async latestDuplicate(source) {
         if (typeof source !== 'string') {
-            source = this.
+            source = this.getModelMemoryPath({ ...source, version: 0 });
         }
         // perform a single memory read
         const paths = Map(await tf.io.listModels());
         if (!paths.has(source)) {
             return undefined;
         }
-        const latest = Map(paths)
-            .
-            .toList()
-            .map((p) => this.infoFor(p).version)
-            .max();
+        const latest = Map(paths).keySeq().toList()
+            .map((p) => this.getModelInfo(p).version).max();
         if (latest === undefined) {
             return 0;
         }
@@ -125,7 +152,7 @@ export class IndexedDB extends Memory {
     }
     async duplicateSource(source) {
         const latestDuplicate = await this.latestDuplicate(source);
-        source = this.
+        source = this.getModelInfo(source);
         if (latestDuplicate === undefined) {
             return source;
         }
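The tensor backend is now encoded in the IndexedDB path itself, which is what lets getModel pick the right wrapper (models.TFJS vs models.GPT) without an extra metadata read. A small round-trip sketch with an illustrative path (same constructor assumption as above):

import { IndexedDB } from '@epfml/discojs-web';

const memory = new IndexedDB();
// Path -> ModelInfo: split('/').splice(2) drops 'indexeddb:' and the empty
// segment after '//', leaving [type, tensorBackend, taskID, name@version].
const info = memory.getModelInfo('indexeddb://saved/gpt/wikitext/my-model@2');
console.log(info);
// => { type: 'saved', taskID: 'wikitext', name: 'my-model',
//      version: 2, tensorBackend: 'gpt' }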