localm-web 0.1.0 → 0.2.0

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
package/dist/index.js CHANGED
@@ -1,3 +1,10 @@
+const DOWNLOAD_PATTERN = /\b(fetch|download|loading from cache|cache hit|param)/i;
+const COMPILE_PATTERN = /\b(compil|shader|kernel|tensor|init|allocat|warm)/i;
+function classifyLoadPhase(text) {
+  if (DOWNLOAD_PATTERN.test(text)) return "downloading";
+  if (COMPILE_PATTERN.test(text)) return "compiling";
+  return "loading";
+}
 class LocalmWebError extends Error {
   /**
    * @param message - Human-readable description of the error.
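
The new classifyLoadPhase helper buckets WebLLM's free-form progress strings into coarse phases by keyword. A quick illustration (the sample strings are invented for the example, not verbatim WebLLM output):

classifyLoadPhase("Fetching param shard 3 of 12"); // "downloading"
classifyLoadPhase("Compiling WebGPU kernels");     // "compiling"
classifyLoadPhase("Preparing model");              // "loading" (fallback)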
@@ -73,10 +80,18 @@ class WebLLMEngine {
             progress: report.progress,
             text: report.text,
             loaded: 0,
-            total: 0
+            total: 0,
+            phase: classifyLoadPhase(report.text)
           });
         }
       });
+      onProgress?.({
+        progress: 1,
+        text: "Model ready.",
+        loaded: 0,
+        total: 0,
+        phase: "ready"
+      });
     } catch (err) {
       throw new ModelLoadError(`Failed to load model "${modelId}".`, err);
     }
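
Progress reports now carry a phase field, and a synthetic "ready" event fires at progress 1 once the load finishes. A minimal consumer sketch, assuming the task-level create helpers forward onProgress the same way Completion.create does later in this diff:

const chat = await Chat.create("phi-3.5-mini-int4", {
  onProgress: (p) => {
    // p.phase is "downloading" | "compiling" | "loading" | "ready"
    console.log(`[${p.phase}] ${Math.round(p.progress * 100)}% ${p.text}`);
  }
});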
@@ -130,6 +145,55 @@ class WebLLMEngine {
       throw new ModelLoadError("Streaming generation failed.", err);
     }
   }
+  async complete(prompt, options = {}) {
+    const engine = this.requireEngine();
+    if (options.signal?.aborted) {
+      throw new GenerationAbortedError("Generation aborted before start.");
+    }
+    const completion = await engine.completions.create({
+      ...buildSamplingParams(options),
+      prompt,
+      stream: false
+    });
+    return completion.choices[0]?.text ?? "";
+  }
+  async *streamCompletion(prompt, options = {}) {
+    const engine = this.requireEngine();
+    if (options.signal?.aborted) {
+      throw new GenerationAbortedError("Generation aborted before start.");
+    }
+    const completion = await engine.completions.create({
+      ...buildSamplingParams(options),
+      prompt,
+      stream: true
+    });
+    let index = 0;
+    let finished = false;
+    try {
+      for await (const chunk of completion) {
+        if (options.signal?.aborted) {
+          throw new GenerationAbortedError("Generation aborted by signal.");
+        }
+        const choice = chunk.choices[0];
+        const delta = choice?.text ?? "";
+        if (delta) {
+          yield { text: delta, index, done: false };
+          index += 1;
+        }
+        if (choice?.finish_reason) {
+          finished = true;
+          yield { text: "", index, done: true };
+          index += 1;
+        }
+      }
+      if (!finished) {
+        yield { text: "", index, done: true };
+      }
+    } catch (err) {
+      if (err instanceof GenerationAbortedError) throw err;
+      throw new ModelLoadError("Streaming completion failed.", err);
+    }
+  }
   async unload() {
     if (this.engine) {
       await this.engine.unload();
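
WebLLMEngine now supports raw prompt completions alongside chat. A sketch of consuming streamCompletion with AbortSignal cancellation (the prompt text and timeout are arbitrary):

const ac = new AbortController();
const timer = setTimeout(() => ac.abort(), 5000); // give up after 5s
let out = "";
try {
  for await (const chunk of engine.streamCompletion("Once upon a time", { signal: ac.signal })) {
    if (!chunk.done) out += chunk.text; // stream ends with { text: "", done: true }
  }
} catch (err) {
  if (!(err instanceof GenerationAbortedError)) throw err; // aborts surface as GenerationAbortedError
} finally {
  clearTimeout(timer);
}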
@@ -143,6 +207,283 @@ class WebLLMEngine {
     return this.engine;
   }
 }
+function toSerializableOptions(options = {}) {
+  const { signal: _signal, ...rest } = options;
+  return rest;
+}
+class WorkerEngine {
+  constructor(worker) {
+    this.worker = worker;
+    this.listener = (event) => this.handleMessage(event.data);
+    this.worker.addEventListener("message", this.listener);
+  }
+  nextId = 1;
+  loaded = false;
+  currentLoad = null;
+  currentLoadId = 0;
+  currentLoadProgress = void 0;
+  currentUnload = null;
+  currentUnloadId = 0;
+  pendingGenerates = /* @__PURE__ */ new Map();
+  pendingStreams = /* @__PURE__ */ new Map();
+  listener;
+  isLoaded() {
+    return this.loaded;
+  }
+  async load(modelId, onProgress) {
+    if (this.currentLoad) {
+      throw new ModelLoadError("Another load is already in progress.");
+    }
+    const id = this.allocateId();
+    this.currentLoadId = id;
+    this.currentLoadProgress = onProgress;
+    return new Promise((resolve, reject) => {
+      this.currentLoad = { resolve, reject };
+      this.send({ op: "load", id, modelId });
+    });
+  }
+  async generate(messages, options = {}) {
+    const id = this.allocateId();
+    return new Promise((resolve, reject) => {
+      this.pendingGenerates.set(id, { resolve, reject });
+      this.send({
+        op: "generate",
+        id,
+        messages,
+        options: toSerializableOptions(options)
+      });
+      options.signal?.addEventListener("abort", () => this.send({ op: "abort", id }));
+    });
+  }
+  async *stream(messages, options = {}) {
+    const id = this.allocateId();
+    const queue = [];
+    let done = false;
+    let error = null;
+    let notify = null;
+    const wakeup = () => {
+      if (notify) {
+        const fn = notify;
+        notify = null;
+        fn();
+      }
+    };
+    this.pendingStreams.set(id, {
+      push: (chunk) => {
+        queue.push(chunk);
+        wakeup();
+      },
+      end: () => {
+        done = true;
+        wakeup();
+      },
+      fail: (err) => {
+        error = err;
+        done = true;
+        wakeup();
+      }
+    });
+    this.send({
+      op: "stream",
+      id,
+      messages,
+      options: toSerializableOptions(options)
+    });
+    options.signal?.addEventListener("abort", () => this.send({ op: "abort", id }));
+    try {
+      while (true) {
+        if (queue.length > 0) {
+          const chunk = queue.shift();
+          if (chunk) yield chunk;
+          continue;
+        }
+        if (error) throw error;
+        if (done) return;
+        await new Promise((r) => {
+          notify = r;
+        });
+      }
+    } finally {
+      this.pendingStreams.delete(id);
+    }
+  }
+  async complete(prompt, options = {}) {
+    const id = this.allocateId();
+    return new Promise((resolve, reject) => {
+      this.pendingGenerates.set(id, { resolve, reject });
+      this.send({
+        op: "complete",
+        id,
+        prompt,
+        options: toSerializableOptions(options)
+      });
+      options.signal?.addEventListener("abort", () => this.send({ op: "abort", id }));
+    });
+  }
+  async *streamCompletion(prompt, options = {}) {
+    const id = this.allocateId();
+    const queue = [];
+    let done = false;
+    let error = null;
+    let notify = null;
+    const wakeup = () => {
+      if (notify) {
+        const fn = notify;
+        notify = null;
+        fn();
+      }
+    };
+    this.pendingStreams.set(id, {
+      push: (chunk) => {
+        queue.push(chunk);
+        wakeup();
+      },
+      end: () => {
+        done = true;
+        wakeup();
+      },
+      fail: (err) => {
+        error = err;
+        done = true;
+        wakeup();
+      }
+    });
+    this.send({
+      op: "stream-completion",
+      id,
+      prompt,
+      options: toSerializableOptions(options)
+    });
+    options.signal?.addEventListener("abort", () => this.send({ op: "abort", id }));
+    try {
+      while (true) {
+        if (queue.length > 0) {
+          const chunk = queue.shift();
+          if (chunk) yield chunk;
+          continue;
+        }
+        if (error) throw error;
+        if (done) return;
+        await new Promise((r) => {
+          notify = r;
+        });
+      }
+    } finally {
+      this.pendingStreams.delete(id);
+    }
+  }
+  async unload() {
+    if (!this.loaded) return;
+    if (this.currentUnload) {
+      throw new ModelLoadError("Another unload is already in progress.");
+    }
+    const id = this.allocateId();
+    this.currentUnloadId = id;
+    return new Promise((resolve, reject) => {
+      this.currentUnload = { resolve, reject };
+      this.send({ op: "unload", id });
+    });
+  }
+  /** Tear down the underlying worker. The engine is unusable after this. */
+  terminate() {
+    this.worker.removeEventListener("message", this.listener);
+    this.worker.terminate();
+    this.loaded = false;
+  }
+  allocateId() {
+    const id = this.nextId;
+    this.nextId += 1;
+    return id;
+  }
+  send(req) {
+    this.worker.postMessage(req);
+  }
+  handleMessage(msg) {
+    switch (msg.op) {
+      case "loaded":
+        if (this.currentLoad && msg.id === this.currentLoadId) {
+          this.loaded = true;
+          this.currentLoad.resolve();
+          this.currentLoad = null;
+          this.currentLoadProgress = void 0;
+        }
+        return;
+      case "progress":
+        if (msg.id === this.currentLoadId) {
+          this.currentLoadProgress?.(msg.payload);
+        }
+        return;
+      case "generated": {
+        const pending = this.pendingGenerates.get(msg.id);
+        if (pending) {
+          pending.resolve(msg.text);
+          this.pendingGenerates.delete(msg.id);
+        }
+        return;
+      }
+      case "token": {
+        const stream = this.pendingStreams.get(msg.id);
+        stream?.push(msg.chunk);
+        return;
+      }
+      case "stream-end": {
+        const stream = this.pendingStreams.get(msg.id);
+        stream?.end();
+        return;
+      }
+      case "unloaded":
+        if (this.currentUnload && msg.id === this.currentUnloadId) {
+          this.loaded = false;
+          this.currentUnload.resolve();
+          this.currentUnload = null;
+        }
+        return;
+      case "is-loaded":
+        return;
+      case "error": {
+        const err = mapError(msg.name, msg.message);
+        if (this.currentLoad && msg.id === this.currentLoadId) {
+          this.currentLoad.reject(err);
+          this.currentLoad = null;
+          this.currentLoadProgress = void 0;
+          return;
+        }
+        if (this.currentUnload && msg.id === this.currentUnloadId) {
+          this.currentUnload.reject(err);
+          this.currentUnload = null;
+          return;
+        }
+        const generate = this.pendingGenerates.get(msg.id);
+        if (generate) {
+          generate.reject(err);
+          this.pendingGenerates.delete(msg.id);
+          return;
+        }
+        const stream = this.pendingStreams.get(msg.id);
+        if (stream) {
+          stream.fail(err);
+          return;
+        }
+        return;
+      }
+    }
+  }
+}
+function mapError(name, message) {
+  switch (name) {
+    case "ModelLoadError":
+      return new ModelLoadError(message);
+    case "ModelNotLoadedError":
+      return new ModelNotLoadedError(message);
+    case "GenerationAbortedError":
+      return new GenerationAbortedError(message);
+    default: {
+      const err = new Error(message);
+      err.name = name;
+      return err;
+    }
+  }
+}
 const MODEL_PRESETS = Object.freeze({
   "phi-3.5-mini-int4": {
     id: "phi-3.5-mini-int4",
@@ -183,6 +524,15 @@ function resolveModelPreset(modelId) {
 function listSupportedModels() {
   return Object.keys(MODEL_PRESETS);
 }
+function createInferenceWorker() {
+  return new Worker(new URL(
+    /* @vite-ignore */
+    "/assets/inference.worker-CwvQtobb.js",
+    import.meta.url
+  ), {
+    type: "module"
+  });
+}
 class LMTask {
   constructor(engine, preset) {
     this.engine = engine;
@@ -198,12 +548,18 @@ class LMTask {
    */
   static async createEngine(modelId, options = {}) {
     const preset = resolveModelPreset(modelId);
-    const engine = options.engine ?? new WebLLMEngine();
+    const engine = options.engine ?? LMTask.defaultEngine(options);
     if (!engine.isLoaded()) {
       await engine.load(preset.webllmId, options.onProgress);
     }
     return { engine, preset };
   }
+  static defaultEngine(options) {
+    if (options.inWorker) {
+      return new WorkerEngine(createInferenceWorker());
+    }
+    return new WebLLMEngine();
+  }
   /** Release engine resources. Safe to call multiple times. */
   async unload() {
     await this.engine.unload();
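
createEngine now builds a WorkerEngine when callers opt in via the new inWorker flag, moving inference off the main thread. A sketch, assuming the task-level create calls pass their options through to LMTask.createEngine:

const chat = await Chat.create("phi-3.5-mini-int4", { inWorker: true });
// the task's engine is now a WorkerEngine wrapping the bundled inference worker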
@@ -221,6 +577,14 @@ class ChatReply {
     this.finishReason = finishReason;
   }
 }
+class CompletionResult {
+  constructor(text, prompt, tokensGenerated, finishReason) {
+    this.text = text;
+    this.prompt = prompt;
+    this.tokensGenerated = tokensGenerated;
+    this.finishReason = finishReason;
+  }
+}
 class Chat extends LMTask {
   history = [];
   systemPrompt = null;
@@ -300,6 +664,152 @@ class Chat extends LMTask {
     return messages;
   }
 }
+class Completion extends LMTask {
+  constructor(engine, preset) {
+    super(engine, preset);
+  }
+  /**
+   * Create and load a `Completion` task for the given model.
+   *
+   * @param modelId - Friendly model id from the registry (e.g. `"qwen2.5-1.5b-int4"`).
+   * @param options - Optional creation options (progress callback, engine override).
+   */
+  static async create(modelId, options = {}) {
+    const { engine, preset } = await LMTask.createEngine(modelId, options);
+    return new Completion(engine, preset);
+  }
+  /**
+   * Generate a continuation for the given prompt.
+   *
+   * @param prompt - Raw text fed to the model.
+   * @param options - Generation options.
+   * @returns A {@link CompletionResult} with the generated continuation.
+   */
+  async predict(prompt, options = {}) {
+    const text = await this.engine.complete(prompt, options);
+    return new CompletionResult(text, prompt, 0, "stop");
+  }
+  /**
+   * Stream a continuation for the given prompt as an async iterable of token
+   * chunks.
+   *
+   * @param prompt - Raw text fed to the model.
+   * @param options - Generation options including an optional `signal`.
+   */
+  async *stream(prompt, options = {}) {
+    for await (const chunk of this.engine.streamCompletion(prompt, options)) {
+      yield chunk;
+    }
+  }
+}
+let webllmCachePromise = null;
+async function loadWebLLMCacheHelpers() {
+  if (!webllmCachePromise) {
+    webllmCachePromise = import("@mlc-ai/web-llm").then((m) => ({
+      hasModelInCache: m.hasModelInCache,
+      deleteModelInCache: m.deleteModelInCache
+    }));
+  }
+  return webllmCachePromise;
+}
+async function defaultEstimate() {
+  if (typeof navigator === "undefined" || !navigator.storage?.estimate) {
+    return { usage: 0, quota: 0 };
+  }
+  const estimate = await navigator.storage.estimate();
+  return {
+    usage: estimate.usage ?? 0,
+    quota: estimate.quota ?? 0
+  };
+}
+class ModelCache {
+  hasModelHook;
+  deleteModelHook;
+  estimateHook;
+  constructor(options = {}) {
+    this.hasModelHook = options.hasModel;
+    this.deleteModelHook = options.deleteModel;
+    this.estimateHook = options.estimate ?? defaultEstimate;
+  }
+  /**
+   * Whether the model's weights are present in the browser cache.
+   *
+   * @param modelId - Friendly id from the registry.
+   * @throws UnknownModelError if `modelId` is not in the registry.
+   */
+  async has(modelId) {
+    const backendId = resolveModelPreset(modelId).webllmId;
+    const fn = this.hasModelHook ?? (await loadWebLLMCacheHelpers()).hasModelInCache;
+    return fn(backendId);
+  }
+  /**
+   * Delete a single model's weights from the browser cache. No-op when the
+   * model is not cached.
+   *
+   * @param modelId - Friendly id from the registry.
+   * @throws UnknownModelError if `modelId` is not in the registry.
+   */
+  async delete(modelId) {
+    const backendId = resolveModelPreset(modelId).webllmId;
+    const fn = this.deleteModelHook ?? (await loadWebLLMCacheHelpers()).deleteModelInCache;
+    await fn(backendId);
+  }
+  /**
+   * List the registry models that are currently cached.
+   *
+   * Iterates `MODEL_PRESETS` and probes each one. Only returns models known
+   * to the SDK — models cached by external WebLLM calls outside our registry
+   * are not included.
+   *
+   * @returns Empty list when nothing is cached.
+   */
+  async list() {
+    const fn = this.hasModelHook ?? (await loadWebLLMCacheHelpers()).hasModelInCache;
+    const probes = await Promise.all(
+      Object.values(MODEL_PRESETS).map(async (preset) => {
+        const cached = await fn(preset.webllmId);
+        if (!cached) return null;
+        const entry = {
+          id: preset.id,
+          backendId: preset.webllmId,
+          family: preset.family,
+          parameters: preset.parameters
+        };
+        return entry;
+      })
+    );
+    return probes.filter((p) => p !== null);
+  }
+  /**
+   * Delete every registry model from the cache. Useful for logout flows or
+   * "reset" buttons. Models cached outside the registry are not touched.
+   */
+  async clear() {
+    const fn = this.deleteModelHook ?? (await loadWebLLMCacheHelpers()).deleteModelInCache;
+    await Promise.all(Object.values(MODEL_PRESETS).map((p) => fn(p.webllmId)));
+  }
+  /**
+   * Aggregate storage stats from the browser. Returned numbers cover the
+   * entire origin (Cache API + IndexedDB + Service Workers + OPFS), not
+   * just our model cache — use it for "you have X of Y available" hints.
+   */
+  async estimateUsage() {
+    return this.estimateHook();
+  }
+  /**
+   * Throw a descriptive error if the given id is not in the registry.
+   * Exposed for code paths that want to validate before calling other
+   * methods (those already throw on their own).
+   *
+   * @throws UnknownModelError
+   */
+  static assertKnown(modelId) {
+    if (!(modelId in MODEL_PRESETS)) {
+      const available = Object.keys(MODEL_PRESETS).join(", ");
+      throw new UnknownModelError(`Unknown model "${modelId}". Available models: ${available}.`);
+    }
+  }
+}
 async function collectStream(stream) {
   let acc = "";
   for await (const chunk of stream) {
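
Putting the new pieces together: Completion mirrors Chat for raw prompts, and ModelCache manages cached weights. A usage sketch; the model id comes from the doc comment above, and the list() entry shape follows the code:

const completion = await Completion.create("qwen2.5-1.5b-int4");
const result = await completion.predict("The capital of France is");
console.log(result.text); // continuation returned by the model

const cache = new ModelCache();
console.log(await cache.list()); // e.g. [{ id, backendId, family, parameters }]
const { usage, quota } = await cache.estimateUsage(); // origin-wide storage numbers
await cache.delete("qwen2.5-1.5b-int4"); // evict this model's weights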
@@ -313,22 +823,27 @@ async function* tap(stream, onChunk) {
     yield chunk;
   }
 }
-const VERSION = "0.1.0";
+const VERSION = "0.2.0";
 export {
   BackendNotAvailableError,
   Chat,
   ChatReply,
+  Completion,
+  CompletionResult,
   GenerationAbortedError,
   LMTask,
   LocalmWebError,
   MODEL_PRESETS,
+  ModelCache,
   ModelLoadError,
   ModelNotLoadedError,
   QuotaExceededError,
   UnknownModelError,
   VERSION,
   WebGPUUnavailableError,
+  WorkerEngine,
   collectStream,
+  createInferenceWorker,
   listSupportedModels,
   resolveModelPreset,
   tap