@epfml/discojs-web 3.0.1-p20240902100041.0 → 3.0.1-p20240904094219.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -1,2 +1 @@
1
1
  export * from "./loaders/index.js";
2
- export * from "./memory/index.js";
package/dist/index.js CHANGED
@@ -1,2 +1 @@
1
1
  export * from "./loaders/index.js";
2
- export * from "./memory/index.js";
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs-web",
3
- "version": "3.0.1-p20240902100041.0",
3
+ "version": "3.0.1-p20240904094219.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",
@@ -1 +0,0 @@
1
- export { IndexedDB } from './memory.js';
@@ -1 +0,0 @@
1
- export { IndexedDB } from './memory.js';
@@ -1,31 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import type { Model, ModelInfo, ModelSource } from '@epfml/discojs';
3
- import { Memory } from '@epfml/discojs';
4
- export declare class IndexedDB extends Memory {
5
- getModelMemoryPath(source: ModelSource): string;
6
- getModelInfo(source: ModelSource): ModelInfo;
7
- getModelMetadata(source: ModelSource): Promise<tf.io.ModelArtifactsInfo | undefined>;
8
- contains(source: ModelSource): Promise<boolean>;
9
- getModel(source: ModelSource): Promise<Model>;
10
- deleteModel(source: ModelSource): Promise<void>;
11
- loadModel(source: ModelSource): Promise<void>;
12
- /**
13
- * Saves the working model to the source.
14
- * @param source the destination
15
- * @param model the model
16
- */
17
- updateWorkingModel(source: ModelSource, model: Model): Promise<void>;
18
- /**
19
- * Creates a saved copy of the working model corresponding to the source.
20
- * @param source the source
21
- */
22
- saveWorkingModel(source: ModelSource): Promise<string>;
23
- saveModel(source: ModelSource, model: Model): Promise<string>;
24
- /**
25
- * Downloads the model corresponding to the source.
26
- * @param source the source
27
- */
28
- downloadModel(source: ModelSource): Promise<void>;
29
- latestDuplicate(source: ModelSource): Promise<number | undefined>;
30
- duplicateSource(source: ModelSource): Promise<ModelInfo>;
31
- }
@@ -1,161 +0,0 @@
1
- /**
2
- * Helper functions used to load and save TFJS models from IndexedDB. The
3
- * working model is the model currently being trained for a task. Saved models
4
- * are models that were explicitly saved to IndexedDB. The two working/ and saved/
5
- * folders are invisible to the user. The user only interacts with the saved/
6
- * folder via the model library. The working/ folder is only used by the backend.
7
- * The working model is loaded from IndexedDB for training (model.fit) only.
8
- */
9
- import { Map } from 'immutable';
10
- import * as tf from '@tensorflow/tfjs';
11
- import { Memory, models } from '@epfml/discojs';
12
- export class IndexedDB extends Memory {
13
- getModelMemoryPath(source) {
14
- if (typeof source === 'string') {
15
- return source;
16
- }
17
- const version = source.version ?? 0;
18
- return `indexeddb://${source.type}/${source.tensorBackend}/${source.taskID}/${source.name}@${version}`;
19
- }
20
- getModelInfo(source) {
21
- if (typeof source !== 'string') {
22
- return source;
23
- }
24
- const [type, tensorBackend, taskID, fullName] = source.split('/').splice(2);
25
- if (type !== 'working' && type !== 'saved') {
26
- throw Error("Unknown memory model type");
27
- }
28
- const [name, versionSuffix] = fullName.split('@');
29
- const version = versionSuffix === undefined ? 0 : Number(versionSuffix);
30
- if (tensorBackend !== 'tfjs' && tensorBackend !== 'gpt') {
31
- throw Error("Unknown tensor backend");
32
- }
33
- return { type, taskID, name, version, tensorBackend };
34
- }
35
- async getModelMetadata(source) {
36
- const models = await tf.io.listModels();
37
- return models[this.getModelMemoryPath(source)];
38
- }
39
- async contains(source) {
40
- return await this.getModelMetadata(source) !== undefined;
41
- }
42
- async getModel(source) {
43
- const layersModel = await tf.loadLayersModel(this.getModelMemoryPath(source));
44
- const tensorBackend = this.getModelInfo(source).tensorBackend;
45
- switch (tensorBackend) {
46
- case 'tfjs':
47
- return new models.TFJS(layersModel);
48
- case 'gpt':
49
- return new models.GPT(undefined, layersModel);
50
- default: {
51
- const _ = tensorBackend;
52
- throw new Error('should never happen');
53
- }
54
- }
55
- }
56
- async deleteModel(source) {
57
- await tf.io.removeModel(this.getModelMemoryPath(source));
58
- }
59
- async loadModel(source) {
60
- const src = this.getModelInfo(source);
61
- if (src.type === 'working') {
62
- // Model is already loaded
63
- return;
64
- }
65
- await tf.io.copyModel(this.getModelMemoryPath(src), this.getModelMemoryPath({ ...src, type: 'working', version: 0 }));
66
- }
67
- /**
68
- * Saves the working model to the source.
69
- * @param source the destination
70
- * @param model the model
71
- */
72
- async updateWorkingModel(source, model) {
73
- const src = this.getModelInfo(source);
74
- if (src.type !== 'working') {
75
- throw new Error('expected working type model');
76
- }
77
- // Enforce version 0 to always keep a single working model at a time
78
- const modelInfo = { ...src, type: 'working', version: 0 };
79
- let includeOptimizer;
80
- if (model instanceof models.TFJS) {
81
- modelInfo['tensorBackend'] = 'tfjs';
82
- includeOptimizer = true;
83
- }
84
- else if (model instanceof models.GPT) {
85
- modelInfo['tensorBackend'] = 'gpt';
86
- includeOptimizer = false; // true raises an error
87
- }
88
- else {
89
- throw new Error('unknown model type');
90
- }
91
- const indexedDBURL = this.getModelMemoryPath(modelInfo);
92
- await model.extract().save(indexedDBURL, { includeOptimizer });
93
- }
94
- /**
95
- * Creates a saved copy of the working model corresponding to the source.
96
- * @param source the source
97
- */
98
- async saveWorkingModel(source) {
99
- const src = this.getModelInfo(source);
100
- if (src.type !== 'working') {
101
- throw new Error('expected working type model');
102
- }
103
- const dst = this.getModelMemoryPath(await this.duplicateSource({ ...src, type: 'saved' }));
104
- await tf.io.copyModel(this.getModelMemoryPath({ ...src, type: 'working' }), dst);
105
- return dst;
106
- }
107
- async saveModel(source, model) {
108
- const src = this.getModelInfo(source);
109
- if (src.type !== 'saved') {
110
- throw new Error('expected saved type model');
111
- }
112
- const modelInfo = await this.duplicateSource({ ...src, type: 'saved' });
113
- let includeOptimizer;
114
- if (model instanceof models.TFJS) {
115
- modelInfo['tensorBackend'] = 'tfjs';
116
- includeOptimizer = true;
117
- }
118
- else if (model instanceof models.GPT) {
119
- modelInfo['tensorBackend'] = 'gpt';
120
- includeOptimizer = false; // true raises an error
121
- }
122
- else {
123
- throw new Error('unknown model type');
124
- }
125
- const indexedDBURL = this.getModelMemoryPath(modelInfo);
126
- await model.extract().save(indexedDBURL, { includeOptimizer });
127
- return indexedDBURL;
128
- }
129
- /**
130
- * Downloads the model corresponding to the source.
131
- * @param source the source
132
- */
133
- async downloadModel(source) {
134
- const src = this.getModelInfo(source);
135
- await tf.io.copyModel(this.getModelMemoryPath(source), `downloads://${src.taskID}_${src.name}`);
136
- }
137
- async latestDuplicate(source) {
138
- if (typeof source !== 'string') {
139
- source = this.getModelMemoryPath({ ...source, version: 0 });
140
- }
141
- // perform a single memory read
142
- const paths = Map(await tf.io.listModels());
143
- if (!paths.has(source)) {
144
- return undefined;
145
- }
146
- const latest = Map(paths).keySeq().toList()
147
- .map((p) => this.getModelInfo(p).version).max();
148
- if (latest === undefined) {
149
- return 0;
150
- }
151
- return latest;
152
- }
153
- async duplicateSource(source) {
154
- const latestDuplicate = await this.latestDuplicate(source);
155
- source = this.getModelInfo(source);
156
- if (latestDuplicate === undefined) {
157
- return source;
158
- }
159
- return { ...source, version: latestDuplicate + 1 };
160
- }
161
- }