web-llm-runner 0.1.14 → 0.1.18

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/lib/index.js CHANGED
@@ -36688,6 +36688,7 @@ class ONNXEngine {
  modelId = null;
  appConfig;
  initProgressCallback;
+ repoId = null;
  // APIs
  chat;
  completions;
@@ -36719,6 +36720,7 @@ class ONNXEngine {
  const { findModelRecord } = await Promise.resolve().then(function () { return support; });
  const record = findModelRecord(id, this.appConfig);
  repoId = record.onnx_id || id;
+ this.repoId = repoId;
  }
  catch (e) {
  log.warn(`Model record not found for ${id}, using raw ID for ONNX.`);
@@ -36739,8 +36741,9 @@ class ONNXEngine {
  }
  try {
  // For T5 models, text2text-generation is the standard task in transformers.js
- const task = repoId.toLowerCase().includes("t5") ? "text2text-generation" : "text-generation";
- this.generator = await pipeline(task, repoId, {
+ const currentRepoId = this.repoId || id;
+ const task = currentRepoId.toLowerCase().includes("t5") ? "text2text-generation" : "text-generation";
+ this.generator = await pipeline(task, currentRepoId, {
  progress_callback: (p) => {
  if (this.initProgressCallback && (p.status === 'progress' || p.status === 'downloading')) {
  const pctValue = (typeof p.progress === 'number') ? p.progress : 0;
@@ -36812,28 +36815,144 @@ class ONNXEngine {
  };
  }
  async *asyncGenerateStreaming(prompt, request) {
- // Current simple implementation yields only a single chunk.
- // In future iterations, we can integrate the Transformers.js TextStreamer
- const result = await this.generateNonStreaming(prompt, request);
- const content = result.choices[0].message.content;
+ if (!this.generator)
+ throw new Error("ONNX model not loaded.");
+ const model = this.modelId;
+ const created = Math.floor(Date.now() / 1000);
+ const id = crypto.randomUUID();
+ const queue = [];
+ let isDone = false;
+ let fullTextSoFar = "";
+ // Run generation in the background
+ (this.repoId || "").toLowerCase().includes("t5") ? "text2text-generation" : "text-generation";
+ this.generator(prompt, {
+ max_new_tokens: request.max_tokens || 256,
+ temperature: request.temperature || 0.7,
+ top_p: request.top_p || 1.0,
+ do_sample: (request.temperature ?? 1.0) > 0,
+ repetition_penalty: request.repetition_penalty || 1.1,
+ callback_function: (beams) => {
+ const decoded = this.generator.tokenizer.decode(beams[0].output_token_ids, { skip_special_tokens: true });
+ const delta = decoded.slice(fullTextSoFar.length);
+ if (delta) {
+ queue.push(delta);
+ fullTextSoFar = decoded;
+ }
+ },
+ }).finally(() => {
+ isDone = true;
+ });
+ while (!isDone || queue.length > 0) {
+ if (queue.length > 0) {
+ const content = queue.shift();
+ yield {
+ id,
+ choices: [{
+ delta: { content },
+ finish_reason: null,
+ index: 0,
+ }],
+ model,
+ object: 'chat.completion.chunk',
+ created,
+ };
+ }
+ else {
+ await new Promise(r => setTimeout(r, 10));
+ }
+ }
  yield {
- id: result.id,
+ id,
  choices: [{
- delta: { role: 'assistant', content: content },
+ delta: {},
  finish_reason: 'stop',
  index: 0,
- logprobs: null
  }],
- model: result.model,
+ model,
  object: 'chat.completion.chunk',
- created: result.created
+ created,
  };
  }
- async completion(_request) {
- throw new Error("Generic completion not yet implemented in ONNXEngine fallback.");
+ async completion(request) {
+ if (!this.generator)
+ throw new Error("ONNX model not loaded.");
+ const prompt = typeof request.prompt === 'string' ? request.prompt : (Array.isArray(request.prompt) ? request.prompt[0] : "");
+ if (request.stream) {
+ return this.asyncGenerateStreamingCompletion(prompt, request);
+ }
+ else {
+ const result = await this.generator(prompt, {
+ max_new_tokens: request.max_tokens || 256,
+ temperature: request.temperature || 0.7,
+ top_p: request.top_p || 1.0,
+ do_sample: (request.temperature ?? 1.0) > 0,
+ repetition_penalty: request.repetition_penalty || 1.1,
+ });
+ const fullText = result[0].generated_text;
+ const text = fullText.startsWith(prompt) ? fullText.slice(prompt.length) : fullText;
+ return {
+ id: crypto.randomUUID(),
+ choices: [{
+ text,
+ finish_reason: 'stop',
+ index: 0,
+ logprobs: null
+ }],
+ model: this.modelId,
+ object: 'text_completion',
+ created: Math.floor(Date.now() / 1000),
+ usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }
+ };
+ }
  }
- async embedding(_request) {
- throw new Error("Embeddings not yet implemented in ONNXEngine fallback.");
+ async *asyncGenerateStreamingCompletion(prompt, request) {
+ const id = crypto.randomUUID();
+ const created = Math.floor(Date.now() / 1000);
+ const model = this.modelId;
+ const queue = [];
+ let isDone = false;
+ let fullTextSoFar = "";
+ this.generator(prompt, {
+ max_new_tokens: request.max_tokens || 256,
+ temperature: request.temperature || 0.7,
+ callback_function: (beams) => {
+ const decoded = this.generator.tokenizer.decode(beams[0].output_token_ids, { skip_special_tokens: true });
+ const delta = decoded.slice(fullTextSoFar.length);
+ if (delta) {
+ queue.push(delta);
+ fullTextSoFar = decoded;
+ }
+ },
+ }).finally(() => { isDone = true; });
+ while (!isDone || queue.length > 0) {
+ if (queue.length > 0) {
+ yield {
+ id,
+ choices: [{ text: queue.shift(), finish_reason: null, index: 0 }],
+ model,
+ object: 'text_completion',
+ created,
+ };
+ }
+ else {
+ await new Promise(r => setTimeout(r, 10));
+ }
+ }
+ }
+ async embedding(request) {
+ const input = Array.isArray(request.input) ? request.input : [request.input];
+ const extractor = await pipeline('feature-extraction', this.modelId);
+ const results = await Promise.all(input.map(text => extractor(text, { pooling: 'mean', normalize: true })));
+ return {
+ object: 'list',
+ data: results.map((res, i) => ({
+ object: 'embedding',
+ index: i,
+ embedding: Array.from(res.data)
+ })),
+ model: this.modelId,
+ usage: { prompt_tokens: 0, total_tokens: 0, extra: {} }
+ };
  }
  async runtimeStatsText() {
  return "Backend: ONNX Runtime (WASM/CPU Falback)";
@@ -38031,10 +38150,35 @@ class WebLLM {
  ];
  list = list.filter(m => approvedIds.includes(m.model_id));
  }
+ else {
+ // On Desktop, filter out those that are exclusively ONNX-id based (not for WebGPU)
+ list = list.filter(m => !m.onnx_id);
+ }
  return list.map((m) => m.model_id);
  }
  async local_model_available(model_id) {
- return await hasModelInCache(model_id);
+ const isMLCCached = await hasModelInCache(model_id);
+ if (isMLCCached)
+ return true;
+ // Check ONNX cache fallback
+ const record = prebuiltAppConfig.model_list.find(m => m.model_id === model_id);
+ if (record && record.onnx_id) {
+ return await this.hasONNXInCache(record.onnx_id);
+ }
+ return false;
+ }
+ async hasONNXInCache(onnx_id) {
+ if (typeof caches === 'undefined')
+ return false;
+ try {
+ const cache = await caches.open('transformers-cache');
+ const url = `https://huggingface.co/${onnx_id}/resolve/main/config.json`;
+ const match = await cache.match(url);
+ return !!match;
+ }
+ catch (e) {
+ return false;
+ }
  }
  async download_model(model_id, progressCallback) {
  // Initial feedback
@@ -38055,6 +38199,14 @@ class WebLLM {
  return this.downloadProgress[model_id] || "No progress available.";
  }
  async delete_model(model_id) {
+ const record = prebuiltAppConfig.model_list.find(m => m.model_id === model_id);
+ if (record && record.onnx_id) {
+ // For ONNX, we currently clear the whole transformers-cache for simplicity
+ // as individual file deletion is complex without a full manifest.
+ if (typeof caches !== 'undefined') {
+ await caches.delete('transformers-cache');
+ }
+ }
  await deleteModelAllInfoInCache(model_id);
  }
  // chat endpoints (Stateful)
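
The new ONNXEngine fallback surfaces OpenAI-style objects: `asyncGenerateStreaming` yields `chat.completion.chunk` deltas, `completion` returns (or streams) `text_completion` objects, and `embedding` returns a `list` of vectors. A minimal consumption sketch follows, assuming `engine` is an already-loaded ONNXEngine instance; how that instance is obtained is not shown in this diff, so the setup is hypothetical.

```js
// Sketch only: `engine` is assumed to be an initialized ONNXEngine with a model loaded.
async function demo(engine) {
  // Streamed chat-style generation: each chunk mirrors OpenAI's chat.completion.chunk shape;
  // the final chunk carries an empty delta and finish_reason 'stop'.
  let text = "";
  for await (const chunk of engine.asyncGenerateStreaming("Hello!", { max_tokens: 64, temperature: 0.7 })) {
    const delta = chunk.choices[0].delta.content;
    if (delta) text += delta;
  }
  console.log(text);

  // Non-streaming completion (object: 'text_completion'); the prompt prefix is stripped from the output.
  const completion = await engine.completion({ prompt: "Once upon a time", max_tokens: 32 });
  console.log(completion.choices[0].text);

  // Embeddings: one mean-pooled, normalized vector per input string.
  const emb = await engine.embedding({ input: ["hello world"] });
  console.log(emb.data[0].embedding.length);
}
```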