edgeflowjs 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +200 -66
- package/dist/backends/index.d.ts +9 -2
- package/dist/backends/index.d.ts.map +1 -1
- package/dist/backends/index.js +13 -13
- package/dist/backends/index.js.map +1 -1
- package/dist/backends/onnx.d.ts +11 -4
- package/dist/backends/onnx.d.ts.map +1 -1
- package/dist/backends/onnx.js +97 -78
- package/dist/backends/onnx.js.map +1 -1
- package/dist/backends/transformers-adapter.d.ts +99 -0
- package/dist/backends/transformers-adapter.d.ts.map +1 -0
- package/dist/backends/transformers-adapter.js +171 -0
- package/dist/backends/transformers-adapter.js.map +1 -0
- package/dist/backends/webgpu.d.ts +7 -5
- package/dist/backends/webgpu.d.ts.map +1 -1
- package/dist/backends/webgpu.js +7 -5
- package/dist/backends/webgpu.js.map +1 -1
- package/dist/backends/webnn.d.ts +6 -5
- package/dist/backends/webnn.d.ts.map +1 -1
- package/dist/backends/webnn.js +6 -5
- package/dist/backends/webnn.js.map +1 -1
- package/dist/core/composer.d.ts +118 -0
- package/dist/core/composer.d.ts.map +1 -0
- package/dist/core/composer.js +163 -0
- package/dist/core/composer.js.map +1 -0
- package/dist/core/device-profiler.d.ts +75 -0
- package/dist/core/device-profiler.d.ts.map +1 -0
- package/dist/core/device-profiler.js +131 -0
- package/dist/core/device-profiler.js.map +1 -0
- package/dist/core/index.d.ts +4 -0
- package/dist/core/index.d.ts.map +1 -1
- package/dist/core/index.js +8 -0
- package/dist/core/index.js.map +1 -1
- package/dist/core/memory.d.ts +22 -2
- package/dist/core/memory.d.ts.map +1 -1
- package/dist/core/memory.js +49 -13
- package/dist/core/memory.js.map +1 -1
- package/dist/core/plugin.d.ts +100 -0
- package/dist/core/plugin.d.ts.map +1 -0
- package/dist/core/plugin.js +106 -0
- package/dist/core/plugin.js.map +1 -0
- package/dist/core/runtime.d.ts +4 -0
- package/dist/core/runtime.d.ts.map +1 -1
- package/dist/core/runtime.js +18 -0
- package/dist/core/runtime.js.map +1 -1
- package/dist/core/scheduler.d.ts +17 -0
- package/dist/core/scheduler.d.ts.map +1 -1
- package/dist/core/scheduler.js +101 -3
- package/dist/core/scheduler.js.map +1 -1
- package/dist/core/types.d.ts +14 -0
- package/dist/core/types.d.ts.map +1 -1
- package/dist/core/types.js.map +1 -1
- package/dist/core/worker.d.ts +202 -0
- package/dist/core/worker.d.ts.map +1 -0
- package/dist/core/worker.js +477 -0
- package/dist/core/worker.js.map +1 -0
- package/dist/edgeflow.browser.js +9770 -4383
- package/dist/edgeflow.browser.js.map +4 -4
- package/dist/edgeflow.browser.min.js +435 -5
- package/dist/edgeflow.browser.min.js.map +4 -4
- package/dist/index.d.ts +7 -4
- package/dist/index.d.ts.map +1 -1
- package/dist/index.js +28 -10
- package/dist/index.js.map +1 -1
- package/dist/pipelines/automatic-speech-recognition.d.ts +63 -0
- package/dist/pipelines/automatic-speech-recognition.d.ts.map +1 -0
- package/dist/pipelines/automatic-speech-recognition.js +269 -0
- package/dist/pipelines/automatic-speech-recognition.js.map +1 -0
- package/dist/pipelines/base.d.ts +6 -1
- package/dist/pipelines/base.d.ts.map +1 -1
- package/dist/pipelines/base.js +12 -2
- package/dist/pipelines/base.js.map +1 -1
- package/dist/pipelines/feature-extraction.d.ts +5 -40
- package/dist/pipelines/feature-extraction.d.ts.map +1 -1
- package/dist/pipelines/feature-extraction.js +44 -63
- package/dist/pipelines/feature-extraction.js.map +1 -1
- package/dist/pipelines/image-classification.d.ts +4 -36
- package/dist/pipelines/image-classification.d.ts.map +1 -1
- package/dist/pipelines/image-classification.js +22 -60
- package/dist/pipelines/image-classification.js.map +1 -1
- package/dist/pipelines/image-segmentation.d.ts +221 -0
- package/dist/pipelines/image-segmentation.d.ts.map +1 -0
- package/dist/pipelines/image-segmentation.js +535 -0
- package/dist/pipelines/image-segmentation.js.map +1 -0
- package/dist/pipelines/index.d.ts +18 -0
- package/dist/pipelines/index.d.ts.map +1 -1
- package/dist/pipelines/index.js +51 -2
- package/dist/pipelines/index.js.map +1 -1
- package/dist/pipelines/object-detection.d.ts +44 -0
- package/dist/pipelines/object-detection.d.ts.map +1 -0
- package/dist/pipelines/object-detection.js +218 -0
- package/dist/pipelines/object-detection.js.map +1 -0
- package/dist/pipelines/question-answering.d.ts +41 -0
- package/dist/pipelines/question-answering.d.ts.map +1 -0
- package/dist/pipelines/question-answering.js +164 -0
- package/dist/pipelines/question-answering.js.map +1 -0
- package/dist/pipelines/text-classification.d.ts +3 -39
- package/dist/pipelines/text-classification.d.ts.map +1 -1
- package/dist/pipelines/text-classification.js +29 -67
- package/dist/pipelines/text-classification.js.map +1 -1
- package/dist/pipelines/text-generation.d.ts +281 -0
- package/dist/pipelines/text-generation.d.ts.map +1 -0
- package/dist/pipelines/text-generation.js +766 -0
- package/dist/pipelines/text-generation.js.map +1 -0
- package/dist/pipelines/zero-shot-classification.d.ts +45 -0
- package/dist/pipelines/zero-shot-classification.d.ts.map +1 -0
- package/dist/pipelines/zero-shot-classification.js +140 -0
- package/dist/pipelines/zero-shot-classification.js.map +1 -0
- package/dist/tools/benchmark.d.ts +92 -0
- package/dist/tools/benchmark.d.ts.map +1 -0
- package/dist/tools/benchmark.js +213 -0
- package/dist/tools/benchmark.js.map +1 -0
- package/dist/tools/debugger.d.ts +258 -0
- package/dist/tools/debugger.d.ts.map +1 -0
- package/dist/tools/debugger.js +624 -0
- package/dist/tools/debugger.js.map +1 -0
- package/dist/tools/index.d.ts +8 -0
- package/dist/tools/index.d.ts.map +1 -1
- package/dist/tools/index.js +16 -0
- package/dist/tools/index.js.map +1 -1
- package/dist/tools/monitor.d.ts +284 -0
- package/dist/tools/monitor.d.ts.map +1 -0
- package/dist/tools/monitor.js +921 -0
- package/dist/tools/monitor.js.map +1 -0
- package/dist/tools/quantization.d.ts +235 -0
- package/dist/tools/quantization.d.ts.map +1 -0
- package/dist/tools/quantization.js +830 -0
- package/dist/tools/quantization.js.map +1 -0
- package/dist/utils/hub.d.ts +162 -0
- package/dist/utils/hub.d.ts.map +1 -0
- package/dist/utils/hub.js +311 -0
- package/dist/utils/hub.js.map +1 -0
- package/dist/utils/index.d.ts +3 -1
- package/dist/utils/index.d.ts.map +1 -1
- package/dist/utils/index.js +5 -1
- package/dist/utils/index.js.map +1 -1
- package/dist/utils/model-loader.d.ts.map +1 -1
- package/dist/utils/model-loader.js +106 -30
- package/dist/utils/model-loader.js.map +1 -1
- package/dist/utils/offline.d.ts +147 -0
- package/dist/utils/offline.d.ts.map +1 -0
- package/dist/utils/offline.js +405 -0
- package/dist/utils/offline.js.map +1 -0
- package/dist/utils/preprocessor.d.ts +82 -6
- package/dist/utils/preprocessor.d.ts.map +1 -1
- package/dist/utils/preprocessor.js +278 -21
- package/dist/utils/preprocessor.js.map +1 -1
- package/dist/utils/tokenizer.d.ts +197 -72
- package/dist/utils/tokenizer.d.ts.map +1 -1
- package/dist/utils/tokenizer.js +558 -274
- package/dist/utils/tokenizer.js.map +1 -1
- package/package.json +26 -11
|
@@ -0,0 +1,830 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* edgeFlow.js - Model Compression & Quantization Tools
|
|
3
|
+
*
|
|
4
|
+
* In-browser model quantization and compression utilities.
|
|
5
|
+
* Supports dynamic quantization (no calibration data needed).
|
|
6
|
+
*/
|
|
7
|
+
import { EdgeFlowTensor } from '../core/index.js';
|
|
8
|
+
// ============================================================================
|
|
9
|
+
// Quantization Core
|
|
10
|
+
// ============================================================================
|
|
11
|
+
/**
 * Derive the scale / zero-point pair(s) that map float values onto a
 * `bits`-wide integer grid.
 *
 * - symmetric: grid is [-(2^(bits-1)), 2^(bits-1) - 1] and the zero point is 0.
 * - asymmetric: grid is [0, 2^bits - 1] and the zero point maps `min` near 0.
 *
 * Per-channel mode is used only when `perChannel` is set AND `shape` has more
 * than one dimension. The data is then split into `shape[channelAxis]` equal,
 * contiguous slices, with one scale/zero point per slice. Otherwise a single
 * pair covers the whole tensor. A computed scale of 0 is replaced by 1.
 *
 * @param {ArrayLike<number>} data - Values to analyse (missing entries read as 0).
 * @param {number} bits - Integer width of the target grid.
 * @param {boolean} symmetric - Use a zero-centred grid with zero point 0.
 * @param {boolean} perChannel - Request one parameter pair per channel.
 * @param {number} [channelAxis=0] - Index into `shape` that gives the channel count.
 *   NOTE(review): slices are always taken contiguously, which matches the memory
 *   layout only for axis 0.
 * @param {number[]} [shape=[]] - Tensor shape.
 * @returns {{scale: number|Float32Array, zeroPoint: number|Int32Array, min: number, max: number}}
 *   Per-channel results use typed arrays; `min`/`max` always span the whole tensor.
 */
function calculateQuantParams(data, bits, symmetric, perChannel, channelAxis = 0, shape = []) {
  const half = 1 << (bits - 1);
  const qmin = symmetric ? -half : 0;
  const qmax = symmetric ? half - 1 : (1 << bits) - 1;

  // Raw scale for an observed [lo, hi] range (before the zero-scale guard).
  const rawScale = (lo, hi) =>
    symmetric ? Math.max(Math.abs(lo), Math.abs(hi)) / qmax : (hi - lo) / (qmax - qmin);

  // Zero point for a range starting at `lo`, given the scale actually stored.
  const zeroPointFor = (lo, s) => (symmetric ? 0 : Math.round(qmin - lo / (s || 1)));

  // Min/max over `count` consecutive entries beginning at `start`.
  const rangeOf = (start, count) => {
    let lo = Infinity;
    let hi = -Infinity;
    for (let k = 0; k < count; k++) {
      const v = data[start + k] ?? 0;
      lo = Math.min(lo, v);
      hi = Math.max(hi, v);
    }
    return [lo, hi];
  };

  if (!(perChannel && shape.length > 1)) {
    // Per-tensor: one pair for everything.
    const [min, max] = rangeOf(0, data.length);
    let scale = rawScale(min, max);
    const zeroPoint = zeroPointFor(min, scale);
    if (scale === 0) {
      scale = 1; // keeps later divisions finite
    }
    return { scale, zeroPoint, min, max };
  }

  const numChannels = shape[channelAxis] ?? 1;
  const scales = new Float32Array(numChannels);
  const zeroPoints = new Int32Array(numChannels);
  const channelSize = data.length / numChannels;
  let globalMin = Infinity;
  let globalMax = -Infinity;

  for (let c = 0; c < numChannels; c++) {
    const [lo, hi] = rangeOf(c * channelSize, channelSize);
    globalMin = Math.min(globalMin, lo);
    globalMax = Math.max(globalMax, hi);
    scales[c] = rawScale(lo, hi);
    // Read the scale back from the Float32Array so the zero point uses the float32-rounded value.
    zeroPoints[c] = zeroPointFor(lo, scales[c]);
    if (scales[c] === 0) {
      scales[c] = 1;
    }
  }

  return { scale: scales, zeroPoint: zeroPoints, min: globalMin, max: globalMax };
}
|
|
77
|
+
/**
 * Quantize float values to signed 8-bit integers:
 * q = clamp(round(x / scale + zeroPoint), -128, 127).
 *
 * @param {ArrayLike<number>} data - Values to quantize (missing entries read as 0).
 * @param {number|Float32Array} scale - Per-tensor scale, or one scale per channel.
 * @param {number|Int32Array} zeroPoint - Per-tensor zero point, or one per channel.
 * @param {boolean} perChannel - Use per-channel parameters (only when `scale` is a Float32Array).
 * @param {number} [channelSize] - Contiguous elements per channel. In per-channel mode it
 *   defaults to `data.length / scale.length`, so each channel uses its own parameters.
 * @returns {Int8Array} Quantized values, same length as `data`.
 */
function quantizeToInt8(
  data,
  scale,
  zeroPoint,
  perChannel,
  channelSize = data.length / (perChannel && scale instanceof Float32Array ? scale.length : 1),
) {
  const result = new Int8Array(data.length);
  const toInt8 = (value, s, zp) => Math.max(-128, Math.min(127, Math.round(value / s + zp)));

  if (perChannel && scale instanceof Float32Array) {
    for (let c = 0; c < scale.length; c++) {
      const s = scale[c] ?? 1;
      const zp = zeroPoint[c] ?? 0;
      for (let i = 0; i < channelSize; i++) {
        const idx = c * channelSize + i;
        result[idx] = toInt8(data[idx] ?? 0, s, zp);
      }
    }
    return result;
  }

  for (let i = 0; i < data.length; i++) {
    result[i] = toInt8(data[i] ?? 0, scale, zeroPoint);
  }
  return result;
}
|
|
104
|
+
/**
 * Quantize float values to unsigned 8-bit integers:
 * q = clamp(round(x / scale + zeroPoint), 0, 255).
 *
 * @param {ArrayLike<number>} data - Values to quantize (missing entries read as 0).
 * @param {number|Float32Array} scale - Per-tensor scale, or one scale per channel.
 * @param {number|Int32Array} zeroPoint - Per-tensor zero point, or one per channel.
 * @param {boolean} perChannel - Use per-channel parameters (only when `scale` is a Float32Array).
 * @param {number} [channelSize] - Contiguous elements per channel. In per-channel mode it
 *   defaults to `data.length / scale.length`, so each channel uses its own parameters.
 * @returns {Uint8Array} Quantized values, same length as `data`.
 */
function quantizeToUint8(
  data,
  scale,
  zeroPoint,
  perChannel,
  channelSize = data.length / (perChannel && scale instanceof Float32Array ? scale.length : 1),
) {
  const result = new Uint8Array(data.length);
  const toUint8 = (value, s, zp) => Math.max(0, Math.min(255, Math.round(value / s + zp)));

  if (perChannel && scale instanceof Float32Array) {
    for (let c = 0; c < scale.length; c++) {
      const s = scale[c] ?? 1;
      const zp = zeroPoint[c] ?? 0;
      for (let i = 0; i < channelSize; i++) {
        const idx = c * channelSize + i;
        result[idx] = toUint8(data[idx] ?? 0, s, zp);
      }
    }
    return result;
  }

  for (let i = 0; i < data.length; i++) {
    result[i] = toUint8(data[i] ?? 0, scale, zeroPoint);
  }
  return result;
}
|
|
131
|
+
/**
 * Quantize float values to 4-bit integers, packed two per byte.
 *
 * Each value is mapped to the signed grid via round(x / scale + zeroPoint). It is
 * then offset by +8 and clamped into the nibble range [0, 15]. Element 2k goes to
 * the high nibble of byte k and element 2k+1 to the low nibble. When `data.length`
 * is odd, the final low nibble holds the quantization of 0.
 *
 * `scale` and `zeroPoint` may be scalars. They may also be typed arrays with one
 * entry per contiguous channel of `data.length / scale.length` elements. Without
 * the per-channel lookup, a typed-array scale would yield NaN, which packs as 0.
 *
 * @param {ArrayLike<number>} data - Values to quantize (missing entries read as 0).
 * @param {number|Float32Array} scale - Per-tensor scale, or one per channel.
 * @param {number|Int32Array} zeroPoint - Per-tensor zero point, or one per channel.
 * @returns {Uint8Array} Packed bytes, length ceil(data.length / 2).
 */
function quantizeToInt4(data, scale, zeroPoint) {
  const result = new Uint8Array(Math.ceil(data.length / 2));
  const perChannel = scale instanceof Float32Array;
  const channelSize = perChannel ? data.length / scale.length : data.length;

  // Quantized nibble for element `i`.
  const nibble = (i) => {
    let s = scale;
    let zp = zeroPoint;
    if (perChannel) {
      // Clamp so the odd-length padding slot reuses the last channel's parameters.
      const c = Math.min(scale.length - 1, Math.floor(i / channelSize));
      s = scale[c] ?? 1;
      zp = zeroPoint[c] ?? 0;
    }
    const value = data[i] ?? 0;
    return Math.max(0, Math.min(15, Math.round(value / s + zp + 8)));
  };

  for (let i = 0; i < data.length; i += 2) {
    result[i >> 1] = (nibble(i) << 4) | nibble(i + 1);
  }
  return result;
}
|
|
148
|
+
/**
 * Encode each value as an IEEE 754 half-precision bit pattern.
 *
 * @param {ArrayLike<number>} data - Values to encode (missing entries read as 0).
 * @returns {Uint16Array} One 16-bit pattern per input element.
 */
function quantizeToFloat16(data) {
  return Uint16Array.from(data, (value) => float32ToFloat16(value ?? 0));
}
|
|
158
|
+
// Scratch views reused by float32ToFloat16 so each call does not allocate a buffer.
const F16_SCRATCH_F32 = new Float32Array(1);
const F16_SCRATCH_U32 = new Uint32Array(F16_SCRATCH_F32.buffer);

/**
 * Shift `value` right by `shift` bits, rounding to nearest with ties to even.
 * The result may carry into the next bit, e.g. a mantissa overflowing into the exponent.
 */
function roundShiftRightEven(value, shift) {
  const kept = value >>> shift;
  const remainder = value & ((1 << shift) - 1);
  const halfway = 1 << (shift - 1);
  const roundUp = remainder > halfway || (remainder === halfway && (kept & 1) === 1);
  return roundUp ? kept + 1 : kept;
}

/**
 * Convert a number to the bit pattern of the nearest IEEE 754 binary16 value.
 *
 * The value is first rounded to float32 by storing it in a Float32Array. It is
 * then rounded to half precision with round-to-nearest, ties-to-even:
 * - magnitudes too large for binary16 become ±Infinity;
 * - magnitudes too small become ±0;
 * - subnormals are produced where representable;
 * - NaN stays NaN (as a quiet NaN).
 *
 * @param {number} value
 * @returns {number} Unsigned 16-bit pattern (sign | 5-bit exponent | 10-bit mantissa).
 */
function float32ToFloat16(value) {
  F16_SCRATCH_F32[0] = value;
  const bits = F16_SCRATCH_U32[0];
  const sign = (bits >>> 16) & 0x8000;
  const exponent32 = (bits >>> 23) & 0xff;
  const mantissa = bits & 0x7fffff;

  if (exponent32 === 0xff) {
    // Infinity keeps an empty mantissa; any NaN maps to a quiet NaN rather than Infinity.
    return sign | 0x7c00 | (mantissa !== 0 ? 0x0200 : 0);
  }

  const exponent = exponent32 - 127 + 15;
  if (exponent >= 0x1f) {
    return sign | 0x7c00; // overflow → ±Infinity
  }
  if (exponent <= 0) {
    if (exponent < -10) {
      return sign; // below half the smallest subnormal → ±0
    }
    // Subnormal: restore the implicit leading 1 and shift it into the 10-bit field.
    return sign | roundShiftRightEven(mantissa | 0x800000, 14 - exponent);
  }
  // Normal: adding (rather than OR-ing) the rounded mantissa lets a carry bump the
  // exponent, which also rounds the largest values up to Infinity correctly.
  return sign | ((exponent << 10) + roundShiftRightEven(mantissa, 13));
}
|
|
183
|
+
/**
 * Dequantize signed 8-bit values back to float32: x = (q - zeroPoint) * scale.
 *
 * @param {ArrayLike<number>} data - Quantized values (missing entries read as 0).
 * @param {number|Float32Array} scale - Per-tensor scale, or one scale per channel.
 * @param {number|Int32Array} zeroPoint - Per-tensor zero point, or one per channel.
 * @param {boolean} [perChannel=false] - Use per-channel parameters (only when `scale` is a Float32Array).
 * @param {number} [channelSize] - Contiguous elements per channel. In per-channel mode it
 *   defaults to `data.length / scale.length`, so each channel uses its own parameters.
 * @returns {Float32Array} Dequantized values, same length as `data`.
 */
export function dequantizeInt8(
  data,
  scale,
  zeroPoint,
  perChannel = false,
  channelSize = data.length / (perChannel && scale instanceof Float32Array ? scale.length : 1),
) {
  const result = new Float32Array(data.length);

  if (perChannel && scale instanceof Float32Array) {
    for (let c = 0; c < scale.length; c++) {
      const s = scale[c] ?? 1;
      const zp = zeroPoint[c] ?? 0;
      for (let i = 0; i < channelSize; i++) {
        const idx = c * channelSize + i;
        result[idx] = ((data[idx] ?? 0) - zp) * s;
      }
    }
    return result;
  }

  for (let i = 0; i < data.length; i++) {
    result[i] = ((data[i] ?? 0) - zeroPoint) * scale;
  }
  return result;
}
|
|
208
|
+
/**
 * Dequantize unsigned 8-bit values back to float32: x = (q - zeroPoint) * scale.
 *
 * @param {ArrayLike<number>} data - Quantized values (missing entries read as 0).
 * @param {number|Float32Array} scale - Per-tensor scale, or one scale per channel.
 * @param {number|Int32Array} zeroPoint - Per-tensor zero point, or one per channel.
 * @param {boolean} [perChannel=false] - Use per-channel parameters (only when `scale` is a Float32Array).
 * @param {number} [channelSize] - Contiguous elements per channel. In per-channel mode it
 *   defaults to `data.length / scale.length`, so each channel uses its own parameters.
 * @returns {Float32Array} Dequantized values, same length as `data`.
 */
export function dequantizeUint8(
  data,
  scale,
  zeroPoint,
  perChannel = false,
  channelSize = data.length / (perChannel && scale instanceof Float32Array ? scale.length : 1),
) {
  const result = new Float32Array(data.length);

  if (perChannel && scale instanceof Float32Array) {
    for (let c = 0; c < scale.length; c++) {
      const s = scale[c] ?? 1;
      const zp = zeroPoint[c] ?? 0;
      for (let i = 0; i < channelSize; i++) {
        const idx = c * channelSize + i;
        result[idx] = ((data[idx] ?? 0) - zp) * s;
      }
    }
    return result;
  }

  for (let i = 0; i < data.length; i++) {
    result[i] = ((data[i] ?? 0) - zeroPoint) * scale;
  }
  return result;
}
|
|
233
|
+
/**
 * Decode an IEEE 754 binary16 bit pattern into a JavaScript number.
 *
 * @param {number} value - 16-bit pattern (sign | 5-bit exponent | 10-bit mantissa).
 * @returns {number} The decoded value, including ±0, subnormals, ±Infinity and NaN.
 */
export function float16ToFloat32(value) {
  const negative = (value & 0x8000) !== 0;
  const exp = (value >> 10) & 0x1f;
  const frac = value & 0x03ff;
  const sign = negative ? -1 : 1;

  switch (exp) {
    case 0:
      // Zero or subnormal: no implicit leading 1 and a fixed exponent of -14.
      if (frac === 0) {
        return negative ? -0 : 0;
      }
      return sign * Math.pow(2, -14) * (frac / 1024);
    case 31:
      // All-ones exponent: Infinity when the mantissa is empty, NaN otherwise.
      return frac === 0 ? sign * Infinity : NaN;
    default:
      return sign * Math.pow(2, exp - 15) * (1 + frac / 1024);
  }
}
|
|
255
|
+
/**
 * Decode an array of binary16 bit patterns into float32 values.
 *
 * @param {ArrayLike<number>} data - 16-bit patterns (missing entries read as 0).
 * @returns {Float32Array} Decoded values, same length as `data`.
 */
export function dequantizeFloat16(data) {
  return Float32Array.from(data, (bits) => float16ToFloat32(bits ?? 0));
}
|
|
265
|
+
/**
 * Placeholder weight extraction: treat the raw model bytes as one flat float32
 * tensor named `model_weights`.
 *
 * NOTE: This does not parse the ONNX protobuf structure at all. Header and graph
 * bytes are reinterpreted as weight data, so the result is not the model's real
 * parameters. It only lets the quantization pipeline run end to end.
 *
 * @param {ArrayBuffer|ArrayBufferView} modelData - Raw model bytes.
 * @returns {Array<{name: string, data: Float32Array, shape: number[], dtype: string}>}
 */
function parseModelWeights(modelData) {
  // A typed-array view goes through the Float32Array copy constructor, as before.
  // A raw buffer is reinterpreted in place. Trailing bytes that do not fill a whole
  // float32 are ignored; `new Float32Array(buffer)` would throw a RangeError on them.
  const data = ArrayBuffer.isView(modelData)
    ? new Float32Array(modelData)
    : new Float32Array(modelData, 0, Math.floor(modelData.byteLength / Float32Array.BYTES_PER_ELEMENT));

  return [
    {
      name: 'model_weights',
      data,
      shape: [data.length],
      dtype: 'float32',
    },
  ];
}
|
|
286
|
+
/**
 * Serialize a quantized model into a little-endian binary container.
 *
 * Header (20 bytes):
 *   version u32 |
 *   quantization type index u32 (int8=0, uint8=1, int4=2, float16=3, dynamic=4;
 *     an unknown type is written as 0xFFFFFFFF) |
 *   originalSize u64 (low u32, then high u32) |
 *   weight count u32
 *
 * Then, for each weight:
 *   name          - u32 byte length + UTF-8 bytes
 *   shape         - u32 rank + one i32 per dimension
 *   dtype         - u32 byte length + UTF-8 bytes
 *   originalDtype - u32 byte length + UTF-8 bytes
 *   scale         - u8 present flag; if present, u32 count + one f32 per value
 *   zeroPoint     - u8 present flag; if present, u32 count + one i32 per value
 *   data          - u64 byte length + raw bytes
 *
 * A scalar scale or zeroPoint is written as a one-element list (8 bytes). The
 * size pass must account for that, or the writer runs past the end of the
 * buffer and throws a RangeError.
 *
 * @param {{version: number, quantizationType: string, originalSize: number, weights: Array<object>}} model
 * @returns {ArrayBuffer}
 */
function serializeQuantizedModel(model) {
  const encoder = new TextEncoder();

  // Encode each weight's strings once; both the sizing and the writing pass use them.
  const entries = model.weights.map((weight) => ({
    weight,
    name: encoder.encode(weight.name),
    dtype: encoder.encode(weight.dtype),
    originalDtype: encoder.encode(weight.originalDtype),
  }));

  // Bytes taken by an optional scale/zeroPoint field, including its presence flag.
  const optionalListSize = (value) => {
    if (value === undefined) {
      return 1;
    }
    const count = Array.isArray(value) ? value.length : 1;
    return 1 + 4 + count * 4;
  };

  let totalSize = 20; // header
  for (const { weight, name, dtype, originalDtype } of entries) {
    totalSize += 4 + name.length;
    totalSize += 4 + weight.shape.length * 4;
    totalSize += 4 + dtype.length;
    totalSize += 4 + originalDtype.length;
    totalSize += optionalListSize(weight.scale);
    totalSize += optionalListSize(weight.zeroPoint);
    totalSize += 8 + weight.data.byteLength;
  }

  const buffer = new ArrayBuffer(totalSize);
  const view = new DataView(buffer);
  const bytes = new Uint8Array(buffer);
  let offset = 0;

  const writeU8 = (v) => {
    view.setUint8(offset, v);
    offset += 1;
  };
  const writeU32 = (v) => {
    view.setUint32(offset, v, true);
    offset += 4;
  };
  const writeI32 = (v) => {
    view.setInt32(offset, v, true);
    offset += 4;
  };
  const writeF32 = (v) => {
    view.setFloat32(offset, v, true);
    offset += 4;
  };
  // 64-bit length as two u32 halves; exact up to Number.MAX_SAFE_INTEGER.
  const writeU64 = (v) => {
    writeU32(v >>> 0);
    writeU32((v / 0x100000000) >>> 0);
  };
  const writeString = (encoded) => {
    writeU32(encoded.length);
    bytes.set(encoded, offset);
    offset += encoded.length;
  };
  const writeOptionalList = (value, writeItem) => {
    if (value === undefined) {
      writeU8(0);
      return;
    }
    writeU8(1);
    const items = Array.isArray(value) ? value : [value];
    writeU32(items.length);
    for (const item of items) {
      writeItem(item);
    }
  };

  writeU32(model.version);
  writeU32(['int8', 'uint8', 'int4', 'float16', 'dynamic'].indexOf(model.quantizationType));
  writeU64(model.originalSize);
  writeU32(model.weights.length);

  for (const { weight, name, dtype, originalDtype } of entries) {
    writeString(name);
    writeU32(weight.shape.length);
    for (const dim of weight.shape) {
      writeI32(dim);
    }
    writeString(dtype);
    writeString(originalDtype);
    writeOptionalList(weight.scale, writeF32);
    writeOptionalList(weight.zeroPoint, writeI32);
    writeU64(weight.data.byteLength);
    bytes.set(new Uint8Array(weight.data), offset);
    offset += weight.data.byteLength;
  }

  return buffer;
}
|
|
416
|
+
/**
 * Quantize a model buffer and repackage it in edgeFlow's quantized container format.
 *
 * Weights come from `parseModelWeights`, which currently treats the whole buffer
 * as a single float32 tensor. Each weight is either kept as-is or quantized to
 * `type`. It is kept when it has fewer than `minTensorSize` elements or its name
 * matches a `skipPatterns` entry. 'dynamic' and unrecognized types use int8.
 *
 * @param {ArrayBuffer} modelData - Raw model bytes.
 * @param {object} options
 * @param {'int8'|'uint8'|'int4'|'float16'|'dynamic'} options.type - Target format.
 * @param {Array<string|RegExp>} [options.skipPatterns=[]] - Names to leave unquantized
 *   (substring match for strings, `test()` for RegExps).
 * @param {boolean} [options.perChannel=false] - Per-channel parameters along axis 0
 *   (only applies to weights with more than one dimension).
 * @param {boolean} [options.symmetric=true] - Zero-centred quantization.
 * @param {Function} [options.onProgress] - Receives `{stage, current, total, percent, layerName?}`.
 * @param {number} [options.minTensorSize=100] - Tensors with fewer elements are skipped.
 * @returns {Promise<object>} Serialized `data`, size and compression figures,
 *   per-layer `layerStats` and aggregate `stats`.
 */
export async function quantizeModel(modelData, options) {
  const {
    type,
    skipPatterns = [],
    perChannel = false,
    symmetric = true,
    onProgress,
    minTensorSize = 100,
  } = options;
  const originalSize = modelData.byteLength;
  const layerStats = [];
  const quantizedWeights = [];
  const scales = [];
  let tensorsQuantized = 0;
  let tensorsSkipped = 0;
  let totalParams = 0;
  let quantizedParams = 0;

  // Min/max via a loop: spreading a large typed array into Math.min/Math.max
  // exceeds the engine's argument limit and throws a RangeError.
  const rangeOf = (values) => {
    let min = Infinity;
    let max = -Infinity;
    for (const value of values) {
      min = Math.min(min, value);
      max = Math.max(max, value);
    }
    return { min, max };
  };

  onProgress?.({ stage: 'analyzing', current: 0, total: 1, percent: 0 });
  const weights = parseModelWeights(modelData);

  for (let i = 0; i < weights.length; i++) {
    const weight = weights[i];
    onProgress?.({
      stage: 'quantizing',
      current: i + 1,
      total: weights.length,
      percent: ((i + 1) / weights.length) * 100,
      layerName: weight.name,
    });
    totalParams += weight.data.length;

    const tooSmall = weight.data.length < minTensorSize;
    const shouldSkip =
      tooSmall ||
      skipPatterns.some((pattern) =>
        typeof pattern === 'string' ? weight.name.includes(pattern) : pattern.test(weight.name),
      );

    if (shouldSkip) {
      tensorsSkipped++;
      const { min, max } = rangeOf(weight.data);
      layerStats.push({
        name: weight.name,
        originalDtype: weight.dtype,
        quantizedDtype: weight.dtype,
        originalSize: weight.data.byteLength,
        quantizedSize: weight.data.byteLength,
        scale: 1,
        zeroPoint: 0,
        minValue: min,
        maxValue: max,
        skipped: true,
        skipReason: tooSmall ? 'Tensor too small' : 'Matched skip pattern',
      });
      quantizedWeights.push({
        name: weight.name,
        // Copy only the bytes this view covers, not the whole underlying buffer.
        data: weight.data.buffer.slice(weight.data.byteOffset, weight.data.byteOffset + weight.data.byteLength),
        shape: weight.shape,
        dtype: weight.dtype,
        originalDtype: weight.dtype,
      });
      continue;
    }

    const bits = type === 'int4' ? 4 : 8;
    const params = calculateQuantParams(weight.data, bits, symmetric, perChannel, 0, weight.shape);

    // calculateQuantParams targets [-2^(bits-1), 2^(bits-1)-1] when symmetric and
    // [0, 2^bits-1] when asymmetric. Shift the zero point so that grid matches the
    // storage type. uint8 is unsigned; int8, int4 and 'dynamic' (stored as int8)
    // are signed. Without the shift, symmetric uint8 clamps every negative weight
    // to 0, and asymmetric int8/int4 clamp the upper half of the range.
    if (type !== 'float16') {
      const half = 1 << (bits - 1);
      const unsignedStorage = type === 'uint8';
      let shift = 0;
      if (symmetric && unsignedStorage) {
        shift = half;
      } else if (!symmetric && !unsignedStorage) {
        shift = -half;
      }
      if (shift !== 0) {
        params.zeroPoint =
          params.zeroPoint instanceof Int32Array
            ? params.zeroPoint.map((zp) => zp + shift)
            : params.zeroPoint + shift;
      }
    }

    const channelSize = perChannel ? weight.data.length / (weight.shape[0] ?? 1) : weight.data.length;
    let packed;
    let quantizedDtype;
    switch (type) {
      case 'uint8':
        packed = quantizeToUint8(weight.data, params.scale, params.zeroPoint, perChannel, channelSize);
        quantizedDtype = 'uint8';
        break;
      case 'int4':
        packed = quantizeToInt4(weight.data, params.scale, params.zeroPoint);
        quantizedDtype = 'int4';
        break;
      case 'float16':
        packed = quantizeToFloat16(weight.data);
        quantizedDtype = 'float16';
        break;
      case 'int8':
      case 'dynamic':
      default:
        // Dynamic quantization stores weights as int8.
        packed = quantizeToInt8(weight.data, params.scale, params.zeroPoint, perChannel, channelSize);
        quantizedDtype = 'int8';
        break;
    }
    // `packed` is a freshly allocated, unshared typed array, so its buffer can be used directly.
    const quantizedData = packed.buffer;

    tensorsQuantized++;
    quantizedParams += weight.data.length;

    const scaleValue = params.scale instanceof Float32Array ? Array.from(params.scale) : params.scale;
    const zpValue = params.zeroPoint instanceof Int32Array ? Array.from(params.zeroPoint) : params.zeroPoint;
    if (typeof scaleValue === 'number') {
      scales.push(scaleValue);
    } else {
      for (const s of scaleValue) {
        scales.push(s);
      }
    }

    layerStats.push({
      name: weight.name,
      originalDtype: weight.dtype,
      quantizedDtype,
      originalSize: weight.data.byteLength,
      quantizedSize: quantizedData.byteLength,
      scale: scaleValue,
      zeroPoint: zpValue,
      minValue: params.min,
      maxValue: params.max,
      skipped: false,
    });
    quantizedWeights.push({
      name: weight.name,
      data: quantizedData,
      shape: weight.shape,
      dtype: quantizedDtype,
      originalDtype: weight.dtype,
      scale: scaleValue,
      zeroPoint: zpValue,
    });
  }

  onProgress?.({ stage: 'packing', current: 0, total: 1, percent: 0 });
  const serialized = serializeQuantizedModel({
    version: 1,
    quantizationType: type,
    originalSize,
    weights: quantizedWeights,
  });
  onProgress?.({ stage: 'complete', current: 1, total: 1, percent: 100 });

  // Aggregate scale statistics (all default to 1 when nothing was quantized).
  let avgScale = 1;
  let minScale = 1;
  let maxScale = 1;
  if (scales.length > 0) {
    let sum = 0;
    minScale = Infinity;
    maxScale = -Infinity;
    for (const s of scales) {
      sum += s;
      minScale = Math.min(minScale, s);
      maxScale = Math.max(maxScale, s);
    }
    avgScale = sum / scales.length;
  }

  // Very rough heuristic, not a measured error.
  const bitsReduction = type === 'int4' ? 8 : type === 'float16' ? 2 : 4;
  const errorEstimate = avgScale / bitsReduction;

  return {
    data: serialized,
    originalSize,
    quantizedSize: serialized.byteLength,
    compressionRatio: originalSize / serialized.byteLength,
    tensorsQuantized,
    tensorsSkipped,
    layerStats,
    stats: {
      totalParameters: totalParams,
      quantizedParameters: quantizedParams,
      averageScale: avgScale,
      minScale,
      maxScale,
      errorEstimate,
    },
  };
}
|
|
586
|
+
// ============================================================================
|
|
587
|
+
// Tensor Quantization (for individual tensors)
|
|
588
|
+
// ============================================================================
|
|
589
|
+
/**
 * Quantize a single EdgeFlowTensor.
 *
 * The quantized values are returned in a new EdgeFlowTensor with the same shape.
 * - 'int8', 'uint8', 'int4', 'dynamic' and unknown types store the integer codes
 *   with dtype 'int32'. For 'int4' the codes lie in the signed 4-bit range.
 * - 'float16' stores the raw binary16 bit patterns as plain numbers with dtype
 *   'float32'; they are bit patterns, not float values.
 *
 * @param {EdgeFlowTensor} tensor - Source tensor (read via `toFloat32Array()`).
 * @param {string} type - Target quantization type.
 * @param {{symmetric?: boolean, perChannel?: boolean}} [options={}] - `symmetric`
 *   defaults to true; `perChannel` defaults to false and applies along axis 0.
 * @returns {{tensor: EdgeFlowTensor, scale: number|number[], zeroPoint: number|number[]}}
 */
export function quantizeTensor(tensor, type, options = {}) {
  const { symmetric = true, perChannel = false } = options;
  const data = tensor.toFloat32Array();
  const shape = tensor.shape;
  const bits = type === 'int4' ? 4 : 8;
  const params = calculateQuantParams(data, bits, symmetric, perChannel, 0, shape);

  // calculateQuantParams targets a signed grid when symmetric and an unsigned grid
  // otherwise. Shift the zero point to match the storage type; without it,
  // symmetric uint8 clamps negatives to 0 and asymmetric signed types clip the top half.
  if (type !== 'float16') {
    const half = 1 << (bits - 1);
    const unsignedStorage = type === 'uint8';
    let shift = 0;
    if (symmetric && unsignedStorage) {
      shift = half;
    } else if (!symmetric && !unsignedStorage) {
      shift = -half;
    }
    if (shift !== 0) {
      params.zeroPoint =
        params.zeroPoint instanceof Int32Array
          ? params.zeroPoint.map((zp) => zp + shift)
          : params.zeroPoint + shift;
    }
  }

  // Pass the per-channel slice length explicitly so each channel uses its own scale.
  const channelSize =
    params.scale instanceof Float32Array ? data.length / params.scale.length : data.length;

  let quantizedData;
  let dtype;
  switch (type) {
    case 'uint8':
      quantizedData = quantizeToUint8(data, params.scale, params.zeroPoint, perChannel, channelSize);
      dtype = 'int32';
      break;
    case 'float16':
      quantizedData = quantizeToFloat16(data);
      dtype = 'float32'; // holds raw binary16 bit patterns
      break;
    case 'int8':
    default:
      quantizedData = quantizeToInt8(data, params.scale, params.zeroPoint, perChannel, channelSize);
      dtype = 'int32'; // no int8 tensor dtype, so codes are widened to int32
      break;
  }

  const scaleValue = params.scale instanceof Float32Array ? Array.from(params.scale) : params.scale;
  const zpValue = params.zeroPoint instanceof Int32Array ? Array.from(params.zeroPoint) : params.zeroPoint;

  return {
    tensor: new EdgeFlowTensor(Array.from(quantizedData), shape, dtype),
    scale: scaleValue,
    zeroPoint: zpValue,
  };
}
|
|
629
|
+
/**
 * Dequantize a tensor back to float32.
 *
 * @param tensor - Quantized tensor; its values are read via `toArray()` and
 *   coerced with `Number` before being packed into the integer view that
 *   matches `type`.
 * @param scale - A single scale, or per-channel scales as a plain array or a
 *   typed array. Plain arrays are converted to Float32Array.
 * @param zeroPoint - A single zero point, or per-channel zero points as a
 *   plain array or a typed array. Plain arrays are converted to Int32Array.
 * @param type - 'int8', 'uint8' or 'float16'. Any other value is decoded as int8.
 * @returns A new 'float32' EdgeFlowTensor with the original shape.
 */
export function dequantizeTensor(tensor, scale, zeroPoint, type) {
    const data = tensor.toArray();
    const shape = tensor.shape;
    // Per-channel parameters may be plain arrays (the form quantizeTensor
    // returns) or typed arrays such as Float32Array. Checking only
    // Array.isArray would misreport a typed-array scale as per-tensor.
    const perChannel = Array.isArray(scale) || ArrayBuffer.isView(scale);
    const scaleArr = Array.isArray(scale) ? new Float32Array(scale) : scale;
    const zpArr = Array.isArray(zeroPoint) ? new Int32Array(zeroPoint) : zeroPoint;
    const values = data.map(Number);
    let dequantizedData;
    switch (type) {
        case 'uint8':
            dequantizedData = dequantizeUint8(new Uint8Array(values), scaleArr, zpArr, perChannel);
            break;
        case 'float16':
            dequantizedData = dequantizeFloat16(new Uint16Array(values));
            break;
        case 'int8':
        default:
            dequantizedData = dequantizeInt8(new Int8Array(values), scaleArr, zpArr, perChannel);
    }
    return new EdgeFlowTensor(Array.from(dequantizedData), shape, 'float32');
}
|
|
654
|
+
/**
 * Prune a tensor using magnitude-based or random pruning.
 *
 * Options:
 *   - ratio (default 0.5): fraction of elements to prune.
 *       * 'magnitude' without `threshold`: exactly floor(n * ratio)
 *         smallest-magnitude elements are pruned (ratio clamped to [0, 1]).
 *         Ties are broken by index, lowest first.
 *       * 'random': each element is pruned when Math.random() <= ratio.
 *   - method (default 'magnitude'): 'magnitude' or 'random'. Any other value
 *     leaves the tensor unpruned, so the returned data agrees with the
 *     reported sparsity of 0.
 *   - threshold: with 'magnitude', prune every element with |x| <= threshold.
 *     `ratio` is then ignored.
 *
 * @returns `{ tensor, mask, sparsity }`:
 *   - `tensor` holds the pruned values.
 *   - `mask` is 1 for kept elements and 0 for pruned ones.
 *   - `sparsity` is the fraction of elements pruned (0 for an empty tensor).
 */
export function pruneTensor(tensor, options = {}) {
    const { ratio = 0.5, method = 'magnitude', threshold } = options;
    const data = tensor.toFloat32Array();
    const shape = tensor.shape;
    const length = data.length;
    // mask[i] === 1 keeps element i; 0 prunes it.
    const mask = new Float32Array(length);
    if (method === 'magnitude') {
        if (threshold != null) {
            for (let i = 0; i < length; i++) {
                mask[i] = Math.abs(data[i]) > threshold ? 1 : 0;
            }
        }
        else {
            // Prune an exact count instead of comparing against the k-th
            // sorted value: that comparison prunes k+1 elements, and still
            // prunes the minimum when ratio is 0.
            const clampedRatio = Math.min(Math.max(ratio, 0), 1);
            const pruneCount = Math.floor(length * clampedRatio);
            const magnitudes = Float32Array.from(data, Math.abs);
            const order = Array.from({ length }, (_, i) => i)
                .sort((a, b) => magnitudes[a] - magnitudes[b]); // stable sort
            mask.fill(1);
            for (let k = 0; k < pruneCount; k++) {
                mask[order[k]] = 0;
            }
        }
    }
    else if (method === 'random') {
        for (let i = 0; i < length; i++) {
            mask[i] = Math.random() > ratio ? 1 : 0;
        }
    }
    else {
        // Unsupported method: keep everything. Previously this returned an
        // all-zero tensor while reporting sparsity 0.
        mask.fill(1);
    }
    const prunedData = new Float32Array(length);
    let prunedCount = 0;
    for (let i = 0; i < length; i++) {
        if (mask[i] === 1) {
            prunedData[i] = data[i];
        }
        else {
            prunedCount++;
        }
    }
    return {
        tensor: new EdgeFlowTensor(Array.from(prunedData), shape, 'float32'),
        mask: new EdgeFlowTensor(Array.from(mask), shape, 'float32'),
        sparsity: length === 0 ? 0 : prunedCount / length,
    };
}
|
|
700
|
+
/**
 * Prune a model (simplified).
 *
 * Each weight returned by parseModelWeights is pruned with pruneTensor to
 * measure sparsity. The input bytes are returned unchanged; no sparse
 * format is written yet.
 *
 * @param modelData - Model bytes, passed to parseModelWeights; `byteLength`
 *   is reported as both the original and the pruned size.
 * @param options - Forwarded to pruneTensor (ratio / method / threshold).
 *   `onProgress`, if given, is called at 0% and 100%.
 * @returns Pruning statistics. `sparsity` is 0 when the model has no parameters.
 */
export async function pruneModel(modelData, options = {}) {
    const { onProgress } = options;
    onProgress?.({ current: 0, total: 1, percent: 0 });
    // This is a simplified implementation
    // Real implementation would parse the model properly
    const weights = parseModelWeights(modelData);
    let totalParams = 0;
    let prunedParams = 0;
    for (const weight of weights) {
        const count = weight.data.length;
        totalParams += count;
        const tensor = new EdgeFlowTensor(Array.from(weight.data), weight.shape, 'float32');
        const { sparsity } = pruneTensor(tensor, options);
        // sparsity is prunedCount / count. Round, not floor, because float
        // error can land just below the integer: (1 / 49) * 49 === 0.999...
        prunedParams += Math.round(count * sparsity);
    }
    onProgress?.({ current: 1, total: 1, percent: 100 });
    return {
        data: modelData, // In a real implementation, we'd create a sparse format
        originalSize: modelData.byteLength,
        prunedSize: modelData.byteLength, // Would be smaller with sparse format
        sparsity: totalParams === 0 ? 0 : prunedParams / totalParams,
        parametersPruned: prunedParams,
        totalParameters: totalParams,
    };
}
|
|
727
|
+
/**
 * Analyze a model's weights and size.
 *
 * @param modelData - Model bytes, passed to parseModelWeights; `byteLength`
 *   is used as the total size.
 * @returns An analysis object:
 *   - totalSize: modelData.byteLength.
 *   - tensorCount / totalParameters: number of weights / total element count.
 *   - dtypeBreakdown: per-dtype tensor count and byte size.
 *   - largestTensors: the 10 largest tensors by byte size.
 *   - estimatedMemory: totalParameters * 4 (float32 at runtime).
 *   - recommendedQuantization: 'int4' above 500 MiB, 'int8' above 100 MiB,
 *     'float16' above 50 MiB, otherwise 'dynamic'.
 *   - estimatedQuantizedSizes: rough per-type sizes.
 */
export async function analyzeModel(modelData) {
    // Element size per dtype name. Unlisted dtypes fall back to 4 bytes,
    // matching the previous default. uint8 was previously counted as 4 bytes.
    const bytesPerElementByDtype = new Map([
        ['float64', 8],
        ['int64', 8],
        ['float32', 4],
        ['int32', 4],
        ['float16', 2],
        ['int8', 1],
        ['uint8', 1],
    ]);
    const weights = parseModelWeights(modelData);
    const totalSize = modelData.byteLength;
    const dtypeBreakdown = {};
    let totalParams = 0;
    const tensorInfos = [];
    for (const weight of weights) {
        totalParams += weight.data.length;
        const bytesPerElement = bytesPerElementByDtype.get(weight.dtype) ?? 4;
        const size = weight.data.length * bytesPerElement;
        if (!dtypeBreakdown[weight.dtype]) {
            dtypeBreakdown[weight.dtype] = { count: 0, size: 0 };
        }
        dtypeBreakdown[weight.dtype].count++;
        dtypeBreakdown[weight.dtype].size += size;
        tensorInfos.push({
            name: weight.name,
            size,
            shape: weight.shape,
        });
    }
    // Sort by size and get top 10
    tensorInfos.sort((a, b) => b.size - a.size);
    const largestTensors = tensorInfos.slice(0, 10);
    // Rough estimates that assume the source model is stored as float32.
    const estimatedQuantizedSizes = {
        int8: Math.ceil(totalSize / 4),
        uint8: Math.ceil(totalSize / 4),
        int4: Math.ceil(totalSize / 8),
        float16: Math.ceil(totalSize / 2),
        dynamic: Math.ceil(totalSize / 4),
    };
    // Recommend quantization based on model size
    let recommendedQuantization = 'dynamic';
    if (totalSize > 500 * 1024 * 1024) {
        recommendedQuantization = 'int4';
    }
    else if (totalSize > 100 * 1024 * 1024) {
        recommendedQuantization = 'int8';
    }
    else if (totalSize > 50 * 1024 * 1024) {
        recommendedQuantization = 'float16';
    }
    return {
        totalSize,
        tensorCount: weights.length,
        totalParameters: totalParams,
        dtypeBreakdown,
        largestTensors,
        estimatedMemory: totalParams * 4, // Assuming float32 at runtime
        recommendedQuantization,
        estimatedQuantizedSizes,
    };
}
|
|
787
|
+
/**
 * Export a model to a target format.
 *
 * Placeholder: no format conversion is implemented yet. Every format,
 * recognised or not, receives the (optionally quantized) data unchanged.
 *
 * @param modelData - Source model data; passed to quantizeModel when quantizing.
 * @param options - `format` is the target format name ('edgeflow', 'onnx',
 *   'tflite', ...). `quantize`, when truthy, is forwarded to quantizeModel as
 *   `type` and the result's `data` is exported instead of `modelData`.
 * @returns The quantized data when `quantize` is set, otherwise `modelData`.
 */
export async function exportModel(modelData, options) {
    const { format, quantize } = options;
    const payload = quantize
        ? (await quantizeModel(modelData, { type: quantize })).data
        : modelData;
    const passThrough = (bytes) => bytes;
    const converters = new Map([
        ['edgeflow', passThrough],
        ['onnx', passThrough], // TODO: convert to ONNX format
        ['tflite', passThrough], // TODO: convert to TFLite format
    ]);
    const convert = converters.get(format) ?? passThrough;
    return convert(payload);
}
|
|
814
|
+
// ============================================================================
|
|
815
|
+
// Exports
|
|
816
|
+
// ============================================================================
|
|
817
|
+
// Default export: a single object bundling the module's public helpers, for
// consumers that import the module as a whole rather than by name.
const quantizationToolkit = {
    // Quantization
    quantizeModel,
    quantizeTensor,
    dequantizeTensor,
    // Pruning
    pruneModel,
    pruneTensor,
    // Inspection and export
    analyzeModel,
    exportModel,
    // Low-level decoding helpers
    dequantizeInt8,
    dequantizeUint8,
    dequantizeFloat16,
    float16ToFloat32,
};
export default quantizationToolkit;
|
|
830
|
+
//# sourceMappingURL=quantization.js.map
|