localm-web 0.1.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +154 -0
- package/README.md +3 -3
- package/dist/assets/index-ChQoBCqA.js +23168 -0
- package/dist/assets/index-ChQoBCqA.js.map +1 -0
- package/dist/assets/inference.worker-CwvQtobb.js +330 -0
- package/dist/assets/inference.worker-CwvQtobb.js.map +1 -0
- package/dist/index.d.ts +634 -0
- package/dist/index.js +807 -3
- package/dist/index.js.map +1 -1
- package/package.json +9 -2
package/dist/index.js
CHANGED
|
@@ -1,3 +1,10 @@
|
|
|
1
|
+
// Keyword heuristics for bucketing backend progress-report text into phases.
// Network/cache activity (weight fetches, cache lookups, parameter shards).
const DOWNLOAD_PATTERN = /\b(fetch|download|loading from cache|cache hit|param)/i;
// Compilation / device-initialization activity (shader & kernel builds, etc.).
const COMPILE_PATTERN = /\b(compil|shader|kernel|tensor|init|allocat|warm)/i;
/**
 * Classify a free-form progress message into a coarse load phase.
 *
 * Download keywords take precedence over compile keywords; anything that
 * matches neither falls back to the generic "loading" phase.
 *
 * @param {string} text - Progress text emitted by the backend.
 * @returns {"downloading" | "compiling" | "loading"} Best-guess phase.
 */
function classifyLoadPhase(text) {
  const phases = [
    [DOWNLOAD_PATTERN, "downloading"],
    [COMPILE_PATTERN, "compiling"]
  ];
  for (const [pattern, phase] of phases) {
    if (pattern.test(text)) {
      return phase;
    }
  }
  return "loading";
}
|
|
1
8
|
class LocalmWebError extends Error {
|
|
2
9
|
/**
|
|
3
10
|
* @param message - Human-readable description of the error.
|
|
@@ -73,10 +80,18 @@ class WebLLMEngine {
|
|
|
73
80
|
progress: report.progress,
|
|
74
81
|
text: report.text,
|
|
75
82
|
loaded: 0,
|
|
76
|
-
total: 0
|
|
83
|
+
total: 0,
|
|
84
|
+
phase: classifyLoadPhase(report.text)
|
|
77
85
|
});
|
|
78
86
|
}
|
|
79
87
|
});
|
|
88
|
+
onProgress?.({
|
|
89
|
+
progress: 1,
|
|
90
|
+
text: "Model ready.",
|
|
91
|
+
loaded: 0,
|
|
92
|
+
total: 0,
|
|
93
|
+
phase: "ready"
|
|
94
|
+
});
|
|
80
95
|
} catch (err) {
|
|
81
96
|
throw new ModelLoadError(`Failed to load model "${modelId}".`, err);
|
|
82
97
|
}
|
|
@@ -130,6 +145,55 @@ class WebLLMEngine {
|
|
|
130
145
|
throw new ModelLoadError("Streaming generation failed.", err);
|
|
131
146
|
}
|
|
132
147
|
}
|
|
148
|
+
async complete(prompt, options = {}) {
|
|
149
|
+
const engine = this.requireEngine();
|
|
150
|
+
if (options.signal?.aborted) {
|
|
151
|
+
throw new GenerationAbortedError("Generation aborted before start.");
|
|
152
|
+
}
|
|
153
|
+
const completion = await engine.completions.create({
|
|
154
|
+
...buildSamplingParams(options),
|
|
155
|
+
prompt,
|
|
156
|
+
stream: false
|
|
157
|
+
});
|
|
158
|
+
return completion.choices[0]?.text ?? "";
|
|
159
|
+
}
|
|
160
|
+
async *streamCompletion(prompt, options = {}) {
|
|
161
|
+
const engine = this.requireEngine();
|
|
162
|
+
if (options.signal?.aborted) {
|
|
163
|
+
throw new GenerationAbortedError("Generation aborted before start.");
|
|
164
|
+
}
|
|
165
|
+
const completion = await engine.completions.create({
|
|
166
|
+
...buildSamplingParams(options),
|
|
167
|
+
prompt,
|
|
168
|
+
stream: true
|
|
169
|
+
});
|
|
170
|
+
let index = 0;
|
|
171
|
+
let finished = false;
|
|
172
|
+
try {
|
|
173
|
+
for await (const chunk of completion) {
|
|
174
|
+
if (options.signal?.aborted) {
|
|
175
|
+
throw new GenerationAbortedError("Generation aborted by signal.");
|
|
176
|
+
}
|
|
177
|
+
const choice = chunk.choices[0];
|
|
178
|
+
const delta = choice?.text ?? "";
|
|
179
|
+
if (delta) {
|
|
180
|
+
yield { text: delta, index, done: false };
|
|
181
|
+
index += 1;
|
|
182
|
+
}
|
|
183
|
+
if (choice?.finish_reason) {
|
|
184
|
+
finished = true;
|
|
185
|
+
yield { text: "", index, done: true };
|
|
186
|
+
index += 1;
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
if (!finished) {
|
|
190
|
+
yield { text: "", index, done: true };
|
|
191
|
+
}
|
|
192
|
+
} catch (err) {
|
|
193
|
+
if (err instanceof GenerationAbortedError) throw err;
|
|
194
|
+
throw new ModelLoadError("Streaming completion failed.", err);
|
|
195
|
+
}
|
|
196
|
+
}
|
|
133
197
|
async unload() {
|
|
134
198
|
if (this.engine) {
|
|
135
199
|
await this.engine.unload();
|
|
@@ -143,6 +207,283 @@ class WebLLMEngine {
|
|
|
143
207
|
return this.engine;
|
|
144
208
|
}
|
|
145
209
|
}
|
|
210
|
+
/**
 * Produce a structured-clone-safe copy of generation options for
 * postMessage: everything is copied shallowly except `signal`, which is not
 * transferable across the worker boundary.
 *
 * @param {object} [options] - Generation options, possibly carrying a signal.
 * @returns {object} The same options without the `signal` key.
 */
function toSerializableOptions(options = {}) {
  const { signal: omittedSignal, ...serializable } = options;
  return serializable;
}
|
|
214
|
+
/**
 * Engine proxy that forwards all inference calls to a dedicated Web Worker
 * over a small request/response protocol. Each outgoing request carries a
 * unique numeric id; worker replies are routed back to their pending promise
 * or stream by that id in handleMessage().
 *
 * Fix/cleanup vs. previous version: the promise-based ops (generate/complete)
 * and the streaming ops (stream/streamCompletion) each shared ~40 duplicated
 * lines, now factored into requestText()/requestStream(); abort listeners are
 * registered with `{ once: true }` so they do not accumulate when callers
 * reuse a long-lived AbortSignal across many requests.
 */
class WorkerEngine {
  /**
   * @param worker - Worker (or worker-like object with postMessage /
   *   addEventListener / removeEventListener / terminate) running the
   *   inference protocol.
   */
  constructor(worker) {
    this.worker = worker;
    this.listener = (event) => this.handleMessage(event.data);
    this.worker.addEventListener("message", this.listener);
  }
  nextId = 1;
  loaded = false;
  currentLoad = null;
  currentLoadId = 0;
  currentLoadProgress = void 0;
  currentUnload = null;
  currentUnloadId = 0;
  pendingGenerates = /* @__PURE__ */ new Map();
  pendingStreams = /* @__PURE__ */ new Map();
  listener;
  /** Whether a model has finished loading inside the worker. */
  isLoaded() {
    return this.loaded;
  }
  /**
   * Ask the worker to load a model; resolves when the worker reports
   * "loaded". Only one load may be in flight at a time.
   *
   * @throws ModelLoadError when another load is already in progress.
   */
  async load(modelId, onProgress) {
    if (this.currentLoad) {
      throw new ModelLoadError("Another load is already in progress.");
    }
    const id = this.allocateId();
    this.currentLoadId = id;
    this.currentLoadProgress = onProgress;
    return new Promise((resolve, reject) => {
      this.currentLoad = { resolve, reject };
      this.send({ op: "load", id, modelId });
    });
  }
  /** Chat-style generation; resolves with the full response text. */
  async generate(messages, options = {}) {
    return this.requestText({ op: "generate", messages }, options);
  }
  /** Chat-style streaming generation; yields token chunks as they arrive. */
  async *stream(messages, options = {}) {
    yield* this.requestStream({ op: "stream", messages }, options);
  }
  /** Raw-prompt completion; resolves with the full continuation text. */
  async complete(prompt, options = {}) {
    return this.requestText({ op: "complete", prompt }, options);
  }
  /** Raw-prompt streaming completion; yields token chunks as they arrive. */
  async *streamCompletion(prompt, options = {}) {
    yield* this.requestStream({ op: "stream-completion", prompt }, options);
  }
  /**
   * Ask the worker to unload the current model. No-op when nothing is
   * loaded.
   *
   * @throws ModelLoadError when another unload is already in progress.
   */
  async unload() {
    if (!this.loaded) return;
    if (this.currentUnload) {
      throw new ModelLoadError("Another unload is already in progress.");
    }
    const id = this.allocateId();
    this.currentUnloadId = id;
    return new Promise((resolve, reject) => {
      this.currentUnload = { resolve, reject };
      this.send({ op: "unload", id });
    });
  }
  /** Tear down the underlying worker. The engine is unusable after this. */
  terminate() {
    this.worker.removeEventListener("message", this.listener);
    this.worker.terminate();
    this.loaded = false;
  }
  allocateId() {
    const id = this.nextId;
    this.nextId += 1;
    return id;
  }
  send(req) {
    this.worker.postMessage(req);
  }
  /**
   * Shared promise-based request path for the "generate"/"complete" ops:
   * registers the pending resolver, posts the request, and forwards an abort
   * to the worker if the caller's signal fires.
   */
  requestText(fields, options) {
    const id = this.allocateId();
    return new Promise((resolve, reject) => {
      this.pendingGenerates.set(id, { resolve, reject });
      this.send({ ...fields, id, options: toSerializableOptions(options) });
      // { once: true }: a signal aborts at most once, and this keeps
      // long-lived shared signals from accumulating dead listeners.
      options.signal?.addEventListener("abort", () => this.send({ op: "abort", id }), { once: true });
    });
  }
  /**
   * Shared streaming request path for the "stream"/"stream-completion" ops.
   * Bridges worker push messages into an async generator through a simple
   * queue + single-waiter wakeup channel registered in pendingStreams.
   */
  async *requestStream(fields, options) {
    const id = this.allocateId();
    const queue = [];
    let done = false;
    let error = null;
    let notify = null;
    const wakeup = () => {
      if (notify) {
        const fn = notify;
        notify = null;
        fn();
      }
    };
    this.pendingStreams.set(id, {
      push: (chunk) => {
        queue.push(chunk);
        wakeup();
      },
      end: () => {
        done = true;
        wakeup();
      },
      fail: (err) => {
        error = err;
        done = true;
        wakeup();
      }
    });
    this.send({ ...fields, id, options: toSerializableOptions(options) });
    options.signal?.addEventListener("abort", () => this.send({ op: "abort", id }), { once: true });
    try {
      while (true) {
        if (queue.length > 0) {
          const chunk = queue.shift();
          if (chunk) yield chunk;
          continue;
        }
        if (error) throw error;
        if (done) return;
        // Park until push/end/fail wakes us up.
        await new Promise((r) => {
          notify = r;
        });
      }
    } finally {
      // Always deregister, including on consumer break/throw.
      this.pendingStreams.delete(id);
    }
  }
  /** Route a worker response to the pending promise/stream it belongs to. */
  handleMessage(msg) {
    switch (msg.op) {
      case "loaded":
        if (this.currentLoad && msg.id === this.currentLoadId) {
          this.loaded = true;
          this.currentLoad.resolve();
          this.currentLoad = null;
          this.currentLoadProgress = void 0;
        }
        return;
      case "progress":
        if (msg.id === this.currentLoadId) {
          this.currentLoadProgress?.(msg.payload);
        }
        return;
      case "generated": {
        const pending = this.pendingGenerates.get(msg.id);
        if (pending) {
          pending.resolve(msg.text);
          this.pendingGenerates.delete(msg.id);
        }
        return;
      }
      case "token": {
        const stream = this.pendingStreams.get(msg.id);
        stream?.push(msg.chunk);
        return;
      }
      case "stream-end": {
        const stream = this.pendingStreams.get(msg.id);
        stream?.end();
        return;
      }
      case "unloaded":
        if (this.currentUnload && msg.id === this.currentUnloadId) {
          this.loaded = false;
          this.currentUnload.resolve();
          this.currentUnload = null;
        }
        return;
      case "is-loaded":
        return;
      case "error": {
        const err = mapError(msg.name, msg.message);
        if (this.currentLoad && msg.id === this.currentLoadId) {
          this.currentLoad.reject(err);
          this.currentLoad = null;
          this.currentLoadProgress = void 0;
          return;
        }
        if (this.currentUnload && msg.id === this.currentUnloadId) {
          this.currentUnload.reject(err);
          this.currentUnload = null;
          return;
        }
        const generate = this.pendingGenerates.get(msg.id);
        if (generate) {
          generate.reject(err);
          this.pendingGenerates.delete(msg.id);
          return;
        }
        const stream = this.pendingStreams.get(msg.id);
        if (stream) {
          stream.fail(err);
        }
        return;
      }
    }
  }
}
|
|
472
|
+
/**
 * Rehydrate a typed error from the serialized (name, message) pair that
 * crossed the worker boundary. Known SDK error names map back to their
 * concrete classes; anything else becomes a plain Error whose `name` is
 * preserved.
 *
 * @param {string} name - Serialized error class name.
 * @param {string} message - Serialized error message.
 * @returns {Error} A reconstructed error instance.
 */
function mapError(name, message) {
  if (name === "ModelLoadError") {
    return new ModelLoadError(message);
  }
  if (name === "ModelNotLoadedError") {
    return new ModelNotLoadedError(message);
  }
  if (name === "GenerationAbortedError") {
    return new GenerationAbortedError(message);
  }
  const generic = new Error(message);
  generic.name = name;
  return generic;
}
|
|
146
487
|
const MODEL_PRESETS = Object.freeze({
|
|
147
488
|
"phi-3.5-mini-int4": {
|
|
148
489
|
id: "phi-3.5-mini-int4",
|
|
@@ -183,6 +524,71 @@ function resolveModelPreset(modelId) {
|
|
|
183
524
|
function listSupportedModels() {
|
|
184
525
|
return Object.keys(MODEL_PRESETS);
|
|
185
526
|
}
|
|
527
|
+
// Registry of embedding models the SDK knows how to load, keyed by the
// friendly id exposed to callers.
const EMBEDDING_PRESETS = Object.freeze({
  "bge-small-en-v1.5": {
    id: "bge-small-en-v1.5",
    family: "BGE",
    dimension: 384,
    maxTokens: 512,
    transformersId: "Xenova/bge-small-en-v1.5",
    quantization: "fp32",
    description: "BAAI BGE small English v1.5, 384-dim sentence embeddings."
  },
  "bge-base-en-v1.5": {
    id: "bge-base-en-v1.5",
    family: "BGE",
    dimension: 768,
    maxTokens: 512,
    transformersId: "Xenova/bge-base-en-v1.5",
    quantization: "fp32",
    description: "BAAI BGE base English v1.5, 768-dim sentence embeddings."
  }
});
/**
 * Look up an embedding preset by friendly id.
 *
 * @param {string} modelId - Friendly id from the embedding registry.
 * @returns The matching preset.
 * @throws UnknownModelError listing the supported ids when not found.
 */
function resolveEmbeddingPreset(modelId) {
  const preset = EMBEDDING_PRESETS[modelId];
  if (preset) {
    return preset;
  }
  const available = Object.keys(EMBEDDING_PRESETS).join(", ");
  throw new UnknownModelError(
    `Unknown embedding model "${modelId}". Available models: ${available}.`
  );
}
/** Friendly ids of every registered embedding model. */
function listSupportedEmbeddingModels() {
  return Object.keys(EMBEDDING_PRESETS);
}
|
|
560
|
+
// Registry of cross-encoder reranker models, keyed by friendly id.
const RERANKER_PRESETS = Object.freeze({
  "bge-reranker-base": {
    id: "bge-reranker-base",
    family: "BGE Reranker",
    maxTokens: 512,
    transformersId: "Xenova/bge-reranker-base",
    quantization: "fp32",
    description: "BAAI BGE reranker base — multilingual cross-encoder."
  }
});
/**
 * Look up a reranker preset by friendly id.
 *
 * @param {string} modelId - Friendly id from the reranker registry.
 * @returns The matching preset.
 * @throws UnknownModelError listing the supported ids when not found.
 */
function resolveRerankerPreset(modelId) {
  const preset = RERANKER_PRESETS[modelId];
  if (preset) {
    return preset;
  }
  const available = Object.keys(RERANKER_PRESETS).join(", ");
  throw new UnknownModelError(
    `Unknown reranker model "${modelId}". Available models: ${available}.`
  );
}
/** Friendly ids of every registered reranker model. */
function listSupportedRerankerModels() {
  return Object.keys(RERANKER_PRESETS);
}
|
|
583
|
+
/**
 * Spawn the bundled inference worker as a module worker. The asset URL is
 * resolved relative to this module so the worker ships with the package's
 * dist output.
 *
 * @returns {Worker} A module-type Worker running the inference protocol.
 */
function createInferenceWorker() {
  const workerUrl = new URL(
    /* @vite-ignore */
    "/assets/inference.worker-CwvQtobb.js",
    import.meta.url
  );
  return new Worker(workerUrl, { type: "module" });
}
|
|
186
592
|
class LMTask {
|
|
187
593
|
constructor(engine, preset) {
|
|
188
594
|
this.engine = engine;
|
|
@@ -198,12 +604,19 @@ class LMTask {
|
|
|
198
604
|
*/
|
|
199
605
|
static async createEngine(modelId, options = {}) {
|
|
200
606
|
const preset = resolveModelPreset(modelId);
|
|
201
|
-
const engine = options.engine ??
|
|
607
|
+
const engine = options.engine ?? LMTask.defaultEngine(options);
|
|
202
608
|
if (!engine.isLoaded()) {
|
|
203
609
|
await engine.load(preset.webllmId, options.onProgress);
|
|
204
610
|
}
|
|
205
611
|
return { engine, preset };
|
|
206
612
|
}
|
|
613
|
+
static defaultEngine(options) {
|
|
614
|
+
const useWorker = options.inWorker ?? true;
|
|
615
|
+
if (useWorker) {
|
|
616
|
+
return new WorkerEngine(createInferenceWorker());
|
|
617
|
+
}
|
|
618
|
+
return new WebLLMEngine();
|
|
619
|
+
}
|
|
207
620
|
/** Release engine resources. Safe to call multiple times. */
|
|
208
621
|
async unload() {
|
|
209
622
|
await this.engine.unload();
|
|
@@ -221,6 +634,14 @@ class ChatReply {
|
|
|
221
634
|
this.finishReason = finishReason;
|
|
222
635
|
}
|
|
223
636
|
}
|
|
637
|
+
/**
 * Immutable-by-convention value object describing the outcome of a
 * non-streaming text completion.
 */
class CompletionResult {
  /**
   * @param text - Generated continuation text.
   * @param prompt - The prompt that was fed to the model.
   * @param tokensGenerated - Token count reported by the caller.
   * @param finishReason - Why generation stopped (e.g. "stop").
   */
  constructor(text, prompt, tokensGenerated, finishReason) {
    this.text = text;
    this.prompt = prompt;
    this.tokensGenerated = tokensGenerated;
    this.finishReason = finishReason;
  }
}
|
|
224
645
|
class Chat extends LMTask {
|
|
225
646
|
history = [];
|
|
226
647
|
systemPrompt = null;
|
|
@@ -300,6 +721,376 @@ class Chat extends LMTask {
|
|
|
300
721
|
return messages;
|
|
301
722
|
}
|
|
302
723
|
}
|
|
724
|
+
/**
 * Raw-prompt text-completion task. Wraps an engine with a prompt-in /
 * continuation-out API (no chat templating).
 */
class Completion extends LMTask {
  constructor(engine, preset) {
    super(engine, preset);
  }
  /**
   * Create and load a `Completion` task for the given model.
   *
   * @param modelId - Friendly model id from the registry (e.g. `"qwen2.5-1.5b-int4"`).
   * @param options - Optional creation options (progress callback, engine override).
   */
  static async create(modelId, options = {}) {
    const created = await LMTask.createEngine(modelId, options);
    return new Completion(created.engine, created.preset);
  }
  /**
   * Generate a continuation for the given prompt.
   *
   * @param prompt - Raw text fed to the model.
   * @param options - Generation options.
   * @returns A {@link CompletionResult} with the generated continuation.
   */
  async predict(prompt, options = {}) {
    const text = await this.engine.complete(prompt, options);
    // Token count / finish reason are not surfaced by the engine's
    // complete(); report 0 and "stop" as the previous behavior did.
    return new CompletionResult(text, prompt, 0, "stop");
  }
  /**
   * Stream a continuation for the given prompt as an async iterable of
   * token chunks.
   *
   * @param prompt - Raw text fed to the model.
   * @param options - Generation options including an optional `signal`.
   */
  async *stream(prompt, options = {}) {
    yield* this.engine.streamCompletion(prompt, options);
  }
}
|
|
762
|
+
// Memoized in-flight import of @huggingface/transformers (embeddings path).
let transformersModulePromise$1 = null;
/** Lazily import the transformers module, caching the promise so the
 *  dynamic import happens at most once. */
async function loadTransformers$1() {
  transformersModulePromise$1 ??= import("@huggingface/transformers");
  return transformersModulePromise$1;
}
|
|
769
|
+
/**
 * Build the default embedding pipeline for a preset using the
 * feature-extraction pipeline from @huggingface/transformers.
 *
 * Progress reports from transformers.js are normalized into the SDK's
 * load-progress shape (progress scaled to 0..1, phase "downloading").
 *
 * @param preset - Embedding preset with the backing `transformersId`.
 * @param onProgress - Optional progress callback.
 * @throws ModelLoadError when the underlying pipeline fails to load.
 */
async function buildDefaultPipeline$1(preset, onProgress) {
  const transformers = await loadTransformers$1();
  const reportProgress = (report) => {
    if (!onProgress) return;
    const r = report;
    onProgress({
      progress: typeof r.progress === "number" ? r.progress / 100 : 0,
      text: r.status ?? "",
      loaded: 0,
      total: 0,
      phase: "downloading"
    });
  };
  try {
    const pipe = await transformers.pipeline("feature-extraction", preset.transformersId, {
      progress_callback: reportProgress
    });
    return {
      async embed(texts, options) {
        const output = await pipe(texts, {
          pooling: options.pooling,
          normalize: options.normalize
        });
        return output.tolist();
      },
      async unload() {
        if (typeof pipe.dispose === "function") {
          await pipe.dispose();
        }
      }
    };
  } catch (err) {
    throw new ModelLoadError(`Failed to load embedding model "${preset.id}".`, err);
  }
}
|
|
803
|
+
/**
 * Sentence-embedding task: encodes strings into dense vectors via an
 * injected (or default transformers.js-backed) pipeline.
 */
class Embeddings {
  constructor(pipeline, preset) {
    this.pipeline = pipeline;
    this.preset = preset;
  }
  /**
   * Create and load an `Embeddings` task for the given model.
   *
   * @param modelId - Friendly id from the embedding registry.
   * @param options - Optional creation options.
   * @throws UnknownModelError if `modelId` is not in the registry.
   * @throws ModelLoadError if the underlying pipeline fails to load.
   */
  static async create(modelId, options = {}) {
    const preset = resolveEmbeddingPreset(modelId);
    const pipeline = options.pipeline ?? await buildDefaultPipeline$1(preset, options.onProgress);
    return new Embeddings(pipeline, preset);
  }
  /**
   * Encode an array of strings into dense vectors — one vector per input,
   * same order. An empty input array returns an empty array (no error).
   *
   * @param texts - Input strings.
   * @param options - Pooling + normalization; defaults `pooling: "mean"`,
   *   `normalize: true`.
   * @throws ModelNotLoadedError when no pipeline is attached.
   */
  async embed(texts, options = {}) {
    if (texts.length === 0) {
      return [];
    }
    if (!this.pipeline) {
      throw new ModelNotLoadedError("Embeddings pipeline not initialized.");
    }
    const merged = {
      normalize: options.normalize ?? true,
      pooling: options.pooling ?? "mean"
    };
    return this.pipeline.embed(texts, merged);
  }
  /**
   * Convenience: encode a single string and return its vector.
   *
   * @param text - Input string.
   * @param options - Forwarded to {@link Embeddings.embed}.
   * @throws ModelLoadError when the pipeline yields no vector.
   */
  async embedSingle(text, options = {}) {
    const vectors = await this.embed([text], options);
    const first = vectors[0];
    if (!first) {
      throw new ModelLoadError("Embedding pipeline returned no result.");
    }
    return first;
  }
  /** Embedding dimension exposed by the loaded model. */
  get dimension() {
    return this.preset.dimension;
  }
  /** Release pipeline resources. Safe to call multiple times. */
  async unload() {
    await this.pipeline.unload?.();
  }
}
|
|
863
|
+
// Memoized in-flight import of @huggingface/transformers (reranker path).
let transformersModulePromise = null;
/** Lazily import the transformers module, caching the promise so the
 *  dynamic import happens at most once. */
async function loadTransformers() {
  transformersModulePromise ??= import("@huggingface/transformers");
  return transformersModulePromise;
}
|
|
870
|
+
/**
 * Logistic sigmoid: maps any real logit into the open interval (0, 1).
 *
 * @param {number} x - Raw logit.
 * @returns {number} 1 / (1 + e^-x).
 */
function sigmoidValue(x) {
  const expNeg = Math.exp(-x);
  return 1 / (1 + expNeg);
}
|
|
873
|
+
/**
 * Build the default reranker pipeline: a cross-encoder tokenizer + sequence
 * classification model pair from @huggingface/transformers that scores
 * (query, doc) pairs.
 *
 * Cleanup vs. previous version: the tokenizer and model `progress_callback`
 * bodies were duplicated verbatim; they now share one adapter that
 * normalizes transformers.js progress reports into the SDK's shape.
 *
 * @param preset - Reranker preset with the backing `transformersId`.
 * @param onProgress - Optional progress callback.
 * @throws ModelLoadError when the tokenizer or model fails to load.
 */
async function buildDefaultPipeline(preset, onProgress) {
  const transformers = await loadTransformers();
  // Shared progress adapter for both downloads below.
  const progress_callback = (report) => {
    if (!onProgress) return;
    const r = report;
    onProgress({
      progress: typeof r.progress === "number" ? r.progress / 100 : 0,
      text: r.status ?? "",
      loaded: 0,
      total: 0,
      phase: "downloading"
    });
  };
  try {
    const tokenizer = await transformers.AutoTokenizer.from_pretrained(preset.transformersId, {
      progress_callback
    });
    const model = await transformers.AutoModelForSequenceClassification.from_pretrained(
      preset.transformersId,
      { progress_callback }
    );
    return {
      /** Score each doc against the query; one raw logit per doc, in order. */
      async score(query, docs) {
        if (docs.length === 0) return [];
        // Cross-encoder input: the same query paired with every doc.
        const queries = docs.map(() => query);
        const tokenize = tokenizer;
        const inputs = tokenize(queries, {
          text_pair: docs,
          padding: true,
          truncation: true,
          max_length: preset.maxTokens
        });
        const callModel = model;
        const outputs = await callModel(inputs);
        const logits = outputs.logits.tolist();
        return logits.map((row) => row[0] ?? 0);
      },
      async unload() {
        const m = model;
        if (typeof m.dispose === "function") await m.dispose();
      }
    };
  } catch (err) {
    throw new ModelLoadError(`Failed to load reranker model "${preset.id}".`, err);
  }
}
|
|
930
|
+
/**
 * Cross-encoder reranking task: scores documents against a query via an
 * injected (or default transformers.js-backed) pipeline.
 */
class Reranker {
  constructor(pipeline, preset) {
    this.pipeline = pipeline;
    this.preset = preset;
  }
  /**
   * Create and load a `Reranker` task for the given model.
   *
   * @param modelId - Friendly id from the reranker registry.
   * @param options - Optional creation options.
   * @throws UnknownModelError if `modelId` is not in the registry.
   * @throws ModelLoadError if the underlying pipeline fails to load.
   */
  static async create(modelId, options = {}) {
    const preset = resolveRerankerPreset(modelId);
    const pipeline = options.pipeline ?? await buildDefaultPipeline(preset, options.onProgress);
    return new Reranker(pipeline, preset);
  }
  /**
   * Score each document against the query — one score per doc, same order.
   * Empty `docs` returns `[]` (no error).
   *
   * @param query - Query string.
   * @param docs - Documents to score.
   * @param options - `sigmoid: true` maps logits into `[0, 1]`.
   * @throws ModelNotLoadedError when no pipeline is attached.
   */
  async score(query, docs, options = {}) {
    if (docs.length === 0) {
      return [];
    }
    if (!this.pipeline) {
      throw new ModelNotLoadedError("Reranker pipeline not initialized.");
    }
    const raw = await this.pipeline.score(query, docs);
    if (!options.sigmoid) {
      return raw;
    }
    return raw.map(sigmoidValue);
  }
  /**
   * Score and sort documents by descending score. Each result carries the
   * document text, its score, and its original index.
   *
   * @param query - Query string.
   * @param docs - Documents to rank.
   * @param options - Forwarded to {@link Reranker.score}.
   */
  async rank(query, docs, options = {}) {
    const scores = await this.score(query, docs, options);
    const ranked = scores.map((score, index) => ({
      text: docs[index] ?? "",
      score,
      index
    }));
    return ranked.sort((a, b) => b.score - a.score);
  }
  /** Release pipeline resources. Safe to call multiple times. */
  async unload() {
    await this.pipeline.unload?.();
  }
}
|
|
986
|
+
// Memoized in-flight import of the WebLLM cache helper functions.
let webllmCachePromise = null;
/** Lazily import hasModelInCache / deleteModelInCache from @mlc-ai/web-llm,
 *  caching the promise so the dynamic import happens at most once. */
async function loadWebLLMCacheHelpers() {
  if (webllmCachePromise === null) {
    webllmCachePromise = import("@mlc-ai/web-llm").then(({ hasModelInCache, deleteModelInCache }) => ({
      hasModelInCache,
      deleteModelInCache
    }));
  }
  return webllmCachePromise;
}
|
|
996
|
+
/**
 * Query origin-wide storage usage via the StorageManager API.
 * Falls back to zeros in environments without `navigator.storage.estimate`
 * (e.g. Node, older browsers).
 *
 * @returns {Promise<{usage: number, quota: number}>} Bytes used / available.
 */
async function defaultEstimate() {
  const canEstimate = typeof navigator !== "undefined" && Boolean(navigator.storage?.estimate);
  if (!canEstimate) {
    return { usage: 0, quota: 0 };
  }
  const estimate = await navigator.storage.estimate();
  return {
    usage: estimate.usage ?? 0,
    quota: estimate.quota ?? 0
  };
}
|
|
1006
|
+
/**
 * Inspect and manage the browser-side cache of chat-model weights.
 * All hooks are injectable for testing; by default the WebLLM cache helpers
 * and `navigator.storage.estimate` are used.
 */
class ModelCache {
  hasModelHook;
  deleteModelHook;
  estimateHook;
  constructor(options = {}) {
    this.hasModelHook = options.hasModel;
    this.deleteModelHook = options.deleteModel;
    this.estimateHook = options.estimate ?? defaultEstimate;
  }
  /**
   * Whether the model's weights are present in the browser cache.
   *
   * @param modelId - Friendly id from the registry.
   * @throws UnknownModelError if `modelId` is not in the registry.
   */
  async has(modelId) {
    const backendId = resolveModelPreset(modelId).webllmId;
    const hasModel = this.hasModelHook ?? (await loadWebLLMCacheHelpers()).hasModelInCache;
    return hasModel(backendId);
  }
  /**
   * Delete a single model's weights from the browser cache. No-op when the
   * model is not cached.
   *
   * @param modelId - Friendly id from the registry.
   * @throws UnknownModelError if `modelId` is not in the registry.
   */
  async delete(modelId) {
    const backendId = resolveModelPreset(modelId).webllmId;
    const deleteModel = this.deleteModelHook ?? (await loadWebLLMCacheHelpers()).deleteModelInCache;
    await deleteModel(backendId);
  }
  /**
   * List the registry models that are currently cached.
   *
   * Iterates `MODEL_PRESETS` and probes each one. Only returns models known
   * to the SDK — models cached by external WebLLM calls outside our registry
   * are not included.
   *
   * @returns Empty list when nothing is cached.
   */
  async list() {
    const hasModel = this.hasModelHook ?? (await loadWebLLMCacheHelpers()).hasModelInCache;
    const probes = await Promise.all(
      Object.values(MODEL_PRESETS).map(async (preset) => {
        const cached = await hasModel(preset.webllmId);
        if (!cached) {
          return null;
        }
        return {
          id: preset.id,
          backendId: preset.webllmId,
          family: preset.family,
          parameters: preset.parameters
        };
      })
    );
    return probes.filter((entry) => entry !== null);
  }
  /**
   * Delete every registry model from the cache. Useful for logout flows or
   * "reset" buttons. Models cached outside the registry are not touched.
   */
  async clear() {
    const deleteModel = this.deleteModelHook ?? (await loadWebLLMCacheHelpers()).deleteModelInCache;
    await Promise.all(Object.values(MODEL_PRESETS).map((p) => deleteModel(p.webllmId)));
  }
  /**
   * Aggregate storage stats from the browser. Returned numbers cover the
   * entire origin (Cache API + IndexedDB + Service Workers + OPFS), not
   * just our model cache — use it for "you have X of Y available" hints.
   */
  async estimateUsage() {
    return this.estimateHook();
  }
  /**
   * Throw a descriptive error if the given id is not in the registry.
   * Exposed for code paths that want to validate before calling other
   * methods (those already throw on their own).
   *
   * @throws UnknownModelError
   */
  static assertKnown(modelId) {
    if (modelId in MODEL_PRESETS) {
      return;
    }
    const available = Object.keys(MODEL_PRESETS).join(", ");
    throw new UnknownModelError(`Unknown model "${modelId}". Available models: ${available}.`);
  }
}
|
|
303
1094
|
async function collectStream(stream) {
|
|
304
1095
|
let acc = "";
|
|
305
1096
|
for await (const chunk of stream) {
|
|
@@ -313,24 +1104,37 @@ async function* tap(stream, onChunk) {
|
|
|
313
1104
|
yield chunk;
|
|
314
1105
|
}
|
|
315
1106
|
}
|
|
316
|
-
const VERSION = "0.
|
|
1107
|
+
const VERSION = "0.3.0";
|
|
317
1108
|
export {
|
|
318
1109
|
BackendNotAvailableError,
|
|
319
1110
|
Chat,
|
|
320
1111
|
ChatReply,
|
|
1112
|
+
Completion,
|
|
1113
|
+
CompletionResult,
|
|
1114
|
+
EMBEDDING_PRESETS,
|
|
1115
|
+
Embeddings,
|
|
321
1116
|
GenerationAbortedError,
|
|
322
1117
|
LMTask,
|
|
323
1118
|
LocalmWebError,
|
|
324
1119
|
MODEL_PRESETS,
|
|
1120
|
+
ModelCache,
|
|
325
1121
|
ModelLoadError,
|
|
326
1122
|
ModelNotLoadedError,
|
|
327
1123
|
QuotaExceededError,
|
|
1124
|
+
RERANKER_PRESETS,
|
|
1125
|
+
Reranker,
|
|
328
1126
|
UnknownModelError,
|
|
329
1127
|
VERSION,
|
|
330
1128
|
WebGPUUnavailableError,
|
|
1129
|
+
WorkerEngine,
|
|
331
1130
|
collectStream,
|
|
1131
|
+
createInferenceWorker,
|
|
1132
|
+
listSupportedEmbeddingModels,
|
|
332
1133
|
listSupportedModels,
|
|
1134
|
+
listSupportedRerankerModels,
|
|
1135
|
+
resolveEmbeddingPreset,
|
|
333
1136
|
resolveModelPreset,
|
|
1137
|
+
resolveRerankerPreset,
|
|
334
1138
|
tap
|
|
335
1139
|
};
|
|
336
1140
|
//# sourceMappingURL=index.js.map
|