@agorapete/wllama 3.5.1-q2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. package/.gitmodules +3 -0
  2. package/.prettierignore +38 -0
  3. package/AGENTS.md +1 -0
  4. package/CMakeLists.txt +131 -0
  5. package/LICENCE +21 -0
  6. package/README-dev.md +178 -0
  7. package/README.md +225 -0
  8. package/README_banner.png +0 -0
  9. package/assets/screenshot_0.png +0 -0
  10. package/cpp/generate_glue_prototype.js +115 -0
  11. package/cpp/glue.hpp +664 -0
  12. package/cpp/test_glue.cpp +80 -0
  13. package/cpp/wllama-context.h +1172 -0
  14. package/cpp/wllama-fs.h +148 -0
  15. package/cpp/wllama.cpp +187 -0
  16. package/cpp/wllama.h +6 -0
  17. package/esm/cache-manager.d.ts +130 -0
  18. package/esm/debug.d.ts +28 -0
  19. package/esm/glue/glue.d.ts +22 -0
  20. package/esm/glue/messages.d.ts +146 -0
  21. package/esm/huggingface.d.ts +31 -0
  22. package/esm/index.cjs +3406 -0
  23. package/esm/index.d.ts +8 -0
  24. package/esm/index.js +3387 -0
  25. package/esm/index.min.js +1 -0
  26. package/esm/index.min.js.map +1 -0
  27. package/esm/model-manager.d.ts +136 -0
  28. package/esm/storage/cos.d.ts +36 -0
  29. package/esm/storage/index.d.ts +33 -0
  30. package/esm/storage/opfs.d.ts +12 -0
  31. package/esm/types/oai-compat.d.ts +278 -0
  32. package/esm/types/types.d.ts +112 -0
  33. package/esm/utils.d.ts +119 -0
  34. package/esm/wasm/source-map.d.ts +1 -0
  35. package/esm/wasm/wllama.wasm +0 -0
  36. package/esm/wasm-from-cdn.d.ts +8 -0
  37. package/esm/wllama.d.ts +397 -0
  38. package/esm/worker.d.ts +92 -0
  39. package/esm/workers-code/generated.d.ts +4 -0
  40. package/guides/intro-v2.md +132 -0
  41. package/guides/intro-v3.1.md +40 -0
  42. package/guides/intro-v3.md +230 -0
  43. package/index.ts +1 -0
  44. package/package.json +71 -0
  45. package/scripts/bisect_test.sh +33 -0
  46. package/scripts/build_hf_space.sh +26 -0
  47. package/scripts/build_source_map.js +269 -0
  48. package/scripts/build_wasm.sh +19 -0
  49. package/scripts/build_worker.sh +38 -0
  50. package/scripts/check_debug_build.js +30 -0
  51. package/scripts/check_package_size.js +25 -0
  52. package/scripts/docker-compose.yml +76 -0
  53. package/scripts/generate_wasm_from_cdn.js +24 -0
  54. package/scripts/http_server.js +44 -0
  55. package/scripts/post_build.sh +32 -0
  56. package/src/cache-manager.ts +358 -0
  57. package/src/debug.ts +111 -0
  58. package/src/glue/glue.ts +291 -0
  59. package/src/glue/messages.ts +773 -0
  60. package/src/huggingface.ts +151 -0
  61. package/src/index.ts +8 -0
  62. package/src/mjs.test.ts +44 -0
  63. package/src/model-manager.test.ts +200 -0
  64. package/src/model-manager.ts +359 -0
  65. package/src/storage/cos.test.ts +83 -0
  66. package/src/storage/cos.ts +171 -0
  67. package/src/storage/index.ts +40 -0
  68. package/src/storage/opfs.ts +119 -0
  69. package/src/types/oai-compat.ts +342 -0
  70. package/src/types/types.ts +133 -0
  71. package/src/utils.test.ts +231 -0
  72. package/src/utils.ts +403 -0
  73. package/src/wasm/source-map.ts +7 -0
  74. package/src/wasm/wllama.js +1 -0
  75. package/src/wasm/wllama.wasm +0 -0
  76. package/src/wasm-from-cdn.ts +13 -0
  77. package/src/wllama.test.ts +392 -0
  78. package/src/wllama.ts +1138 -0
  79. package/src/wllama.wgpu.test.ts +62 -0
  80. package/src/worker.ts +443 -0
  81. package/src/workers-code/generated.ts +11 -0
  82. package/src/workers-code/llama-cpp.js +511 -0
  83. package/src/workers-code/opfs-utils.js +150 -0
  84. package/tsconfig.build.json +34 -0
  85. package/tsup.config.ts +23 -0
  86. package/vitest.config.ts +61 -0
package/src/wllama.ts ADDED
@@ -0,0 +1,1138 @@
1
+ import { ProxyToWorker, type WllamaWorkerResources } from './worker';
2
+ import {
3
+ absoluteUrl,
4
+ canUseAsyncFileRead,
5
+ cbToAsyncIter,
6
+ checkEnvironmentCompatible,
7
+ isFirefox,
8
+ isString,
9
+ isSupportJSPI,
10
+ isSupportMultiThread,
11
+ isSupportWebGPU,
12
+ MMPROJ_FILE_NAME,
13
+ needCompat,
14
+ prepareBlobs,
15
+ } from './utils';
16
+ import CacheManager, { type DownloadOptions } from './cache-manager';
17
+ import { ModelManager, Model, type ModelSource } from './model-manager';
18
+ import type {
19
+ GlueMsgCompletionRes,
20
+ GlueMsgEmbeddingRes,
21
+ GlueMsgRerankRes,
22
+ GlueMsgGetResultRes,
23
+ GlueMsgLoadRes,
24
+ GlueMsgTestBackendOpsRes,
25
+ } from './glue/messages';
26
+ import { LIBLLAMA_VERSION } from './workers-code/generated';
27
+ import type {
28
+ LoadedContextInfo,
29
+ LoadModelParams,
30
+ StreamParams,
31
+ } from './types/types';
32
+ import type {
33
+ ChatCompletionChunk,
34
+ ChatCompletionParams,
35
+ ChatCompletionResponse,
36
+ ChatCompletionUserMessage,
37
+ CreateEmbeddingResponse,
38
+ EmbeddingCreateParams,
39
+ RawCompletionChunk,
40
+ RawCompletionParams,
41
+ RawCompletionResponse,
42
+ RerankParams,
43
+ RerankResponse,
44
+ } from './types/oai-compat';
45
+ import { LogLevel } from './types/types';
46
+ import { getHFModelSource, type HuggingFaceParams } from './huggingface';
47
+ import { WasmCompatFromCDN } from './wasm-from-cdn';
48
+
49
/**
 * Logging sink accepted by Wllama.
 *
 * Every member has the same signature as the `console` method of the same
 * name, so the global `console` object is itself a valid logger.
 */
export interface WllamaLogger {
  debug: Console['debug'];
  log: Console['log'];
  warn: Console['warn'];
  error: Console['error'];
}
55
+
56
+ // TODO: bring back useCache
57
/**
 * Optional settings accepted by the `Wllama` constructor.
 * Every field may be omitted.
 */
export interface WllamaConfig {
  /** When true, all log messages coming from the native C++ code are suppressed. */
  suppressNativeLog?: boolean;

  /** Custom logger functions; `console` is used when this is not provided. */
  logger?: WllamaLogger;

  /**
   * Maximum number of files downloaded in parallel.
   *
   * Default: parallelDownloads = 3
   */
  parallelDownloads?: number;

  /**
   * Allow offline mode: when true, the model is loaded from cache if it is available.
   *
   * Default: allowOffline = false
   */
  allowOffline?: boolean;

  /** Custom cache manager (advanced usage only). */
  cacheManager?: CacheManager;

  /** Custom model manager (advanced usage only). */
  modelManager?: ModelManager;
}
87
+
88
/**
 * A single chat message: who is speaking and the text of the message.
 */
export interface WllamaChatMessage {
  /** Speaker of the message. */
  role: 'system' | 'user' | 'assistant';
  /** Message text. */
  content: string;
}
92
+
93
/**
 * Locations of the assets used by Wllama.
 * Only `default` is required; the two legacy keys are kept for backward compatibility.
 */
export interface AssetsPathConfig {
  /** Asset path; `loadModel()` refuses to run when this entry is missing. */
  default: string;
  /** @deprecated use "default" instead */
  'single-thread/wllama.wasm'?: string;
  /** @deprecated use "default" instead */
  'multi-thread/wllama.wasm'?: string;
}
98
+
99
/**
 * Hyper-parameters and key/value metadata of the loaded model,
 * as exposed by `Wllama.getModelMetadata()`.
 */
export interface ModelMetadata {
  /** Numeric hyper-parameters reported when the model is loaded. */
  hparams: {
    nVocab: number;
    nCtxTrain: number;
    nEmbd: number;
    nLayer: number;
  };
  /** Metadata entries keyed by name (for example `tokenizer.chat_template`). */
  meta: Record<string, string>;
}
108
+
109
/**
 * Ready-made logger: a copy of `console` whose `debug` is a no-op,
 * so debug messages are dropped while every other method behaves as usual.
 */
export const LoggerWithoutDebug = {
  ...console,
  debug: (): void => {
    // intentionally empty: debug output is suppressed
  },
};
116
+
117
/**
 * Category tag attached to a `WllamaError`.
 * `unknown_error` is the default used when no category is given.
 */
export type WllamaErrorType =
  | 'unknown_error'
  | 'model_not_loaded'
  | 'download_error'
  | 'load_error'
  | 'inference_error'
  | 'kv_cache_full';
124
+
125
/**
 * Error raised by Wllama, carrying a machine-readable category in `type`.
 */
export class WllamaError extends Error {
  /**
   * @param message Human-readable description.
   * @param type Error category; defaults to `'unknown_error'`.
   */
  constructor(
    message: string,
    public type: WllamaErrorType = 'unknown_error'
  ) {
    super(message);
  }
}
132
+
133
/**
 * Thrown when the user aborts the current operation.
 * Its `name` is 'AbortError', the same name the Fetch API uses for aborted requests.
 */
export class WllamaAbortError extends Error {
  constructor() {
    super('Operation aborted');
    this.name = 'AbortError';
  }
}
143
+
144
/**
 * Thrown when the WASM runtime itself fails (stack overflow, OOM, etc.).
 * The stack trace captured inside the WASM runtime replaces the JavaScript
 * `stack`, so it is available on the error object for debugging.
 */
export class WllamaRuntimeError extends Error {
  override stack: string;

  /**
   * @param message Description of the runtime failure.
   * @param stack Stack trace to expose through the `stack` property.
   */
  constructor(message: string, stack: string) {
    super(message);
    this.name = 'RuntimeError';
    this.stack = stack;
  }
}
156
+
157
/**
 * Compatibility resources for Wllama.
 *
 * By default these point to the latest builds on a CDN, which requires
 * internet access to download. To use local assets or your own CDN, follow
 * the instructions of the @wllama/wllama-compat package.
 */
export interface WllamaCompat {
  /** Worker script: either a string or an object carrying the worker code inline. */
  worker: string | { code: string };
  /** WASM resource, given as a string. */
  wasm: string;
}
165
+
166
+ export class Wllama {
167
+ // The CacheManager and ModelManager are singleton, can be accessed by user
168
+ public cacheManager: CacheManager;
169
+ public modelManager: ModelManager;
170
+
171
+ private compat: WllamaCompat | null = null;
172
+
173
+ private proxy: ProxyToWorker = null as any;
174
+ private config: WllamaConfig;
175
+ private pathConfig: AssetsPathConfig;
176
+ private useMultiThread: boolean = false;
177
+ private nbThreads: number = 1;
178
+ private useEmbeddings: boolean = false;
179
+ private useRerank: boolean = false;
180
+ // available when loaded
181
+ private loadedContextInfo: LoadedContextInfo = null as any;
182
+ private seed: number | undefined = undefined;
183
+ private bosToken: number = -1;
184
+ private eosToken: number = -1;
185
+ private eotToken: number = -1;
186
+ private eogTokens: Set<number> = new Set();
187
+ private addBosToken: boolean = false;
188
+ private addEosToken: boolean = false;
189
+ private mediaMarker?: string;
190
+ private chatTemplate?: string;
191
+ private metadata?: ModelMetadata;
192
+ private hasEncoder: boolean = false;
193
+ private decoderStartToken: number = -1;
194
+
195
+ // note: we overlay instead of using llama-server default_template_kwargs, because we cannot transfer complex data structure via GLUE
196
+ // the overlay allows mixed data types and nested structures for kwargs
197
+ private chatTemplateKwargs: Record<string, any> = {};
198
+
199
/**
 * Create a Wllama instance.
 *
 * Checks the environment first, then stores the configuration, sets up the
 * cache/model managers (custom ones win over freshly created defaults) and
 * enables the default compat resources.
 *
 * @param pathConfig Asset locations; required.
 * @param wllamaConfig Optional settings (logger, download options, custom managers).
 * @throws WllamaError when `pathConfig` is not provided.
 */
constructor(pathConfig: AssetsPathConfig, wllamaConfig: WllamaConfig = {}) {
  checkEnvironmentCompatible();
  if (!pathConfig) {
    throw new WllamaError('AssetsPathConfig is required');
  }

  this.pathConfig = pathConfig;
  this.config = wllamaConfig;

  const { cacheManager, modelManager, logger, parallelDownloads, allowOffline } =
    wllamaConfig;
  this.cacheManager = cacheManager ?? new CacheManager();
  // The default model manager shares the cache manager selected above.
  this.modelManager =
    modelManager ??
    new ModelManager({
      cacheManager: this.cacheManager,
      logger: logger ?? console,
      parallelDownloads,
      allowOffline,
    });

  this.setCompat('default');
}
215
+
216
/**
 * Resolve the logger in use: the one supplied in the config, or `console` otherwise.
 */
private logger() {
  const { logger: configured } = this.config;
  return configured ?? console;
}
219
+
220
/**
 * Guard used by APIs that need a loaded model.
 *
 * @throws WllamaError of type 'model_not_loaded' when `isModelLoaded()` is false.
 */
private checkModelLoaded() {
  if (this.isModelLoaded()) {
    return;
  }
  throw new WllamaError('loadModel() is not yet called', 'model_not_loaded');
}
228
+
229
/**
 * Report the libllama version this package was built with, e.g. "b6327-4d74393".
 *
 * @returns the version string embedded at build time.
 */
static getLibllamaVersion(): string {
  const version: string = LIBLLAMA_VERSION;
  return version;
}
237
+
238
/**
 * Set compatibility options for Wllama.
 *
 * @param compat `null` disables compatibility; `'default'` selects the default compat resources from CDN; otherwise the given resources are used.
 * @param mode `'safari'` (default): on Firefox the compat resources are always cleared, whatever `compat` is. `'firefox_safari'`: the Firefox check is skipped, so compat is **also** applied on Firefox. Per the original documentation this significantly degrades performance but allows using WebGPU on Firefox.
 */
setCompat(
  compat: WllamaCompat | null | 'default',
  mode: 'safari' | 'firefox_safari' = 'safari'
) {
  // The browser check only runs in 'safari' mode.
  const disabledOnThisBrowser = mode === 'safari' && isFirefox();
  if (disabledOnThisBrowser) {
    this.compat = null;
    return;
  }

  if (compat === 'default') {
    this.compat = WasmCompatFromCDN;
  } else {
    this.compat = compat;
  }
}
255
+
256
/**
 * Tell whether a model has been loaded via `loadModel()`.
 *
 * @returns true only when both the worker proxy and the model metadata are present.
 */
isModelLoaded(): boolean {
  return Boolean(this.proxy) && Boolean(this.metadata);
}
262
+
263
/**
 * Get the token ID of the BOS (begin of sentence) token.
 *
 * NOTE: only meaningful after `loadModel` has been called.
 *
 * @returns the BOS token ID, or -1 if no model has been loaded yet.
 */
getBOS(): number {
  const { bosToken } = this;
  return bosToken;
}
273
+
274
/**
 * Get the token ID of the EOS (end of sentence) token.
 *
 * NOTE: only meaningful after `loadModel` has been called.
 *
 * @returns the EOS token ID, or -1 if no model has been loaded yet.
 */
getEOS(): number {
  const { eosToken } = this;
  return eosToken;
}
284
+
285
/**
 * Get the token ID of the EOT (end of turn) token.
 *
 * NOTE: only meaningful after `loadModel` has been called.
 *
 * @returns the EOT token ID, or -1 if no model has been loaded yet.
 */
getEOT(): number {
  const { eotToken } = this;
  return eotToken;
}
295
+
296
/**
 * Check whether a token marks end-of-generation (EOS, EOT, or any other
 * end-of-generation token reported for the loaded model).
 *
 * @param token the token ID to check
 * @returns true if the token is EOS, EOT, or part of the end-of-generation set
 */
isTokenEOG(token: number): boolean {
  if (token === this.eosToken || token === this.eotToken) {
    return true;
  }
  return this.eogTokens.has(token);
}
309
+
310
/**
 * Get the token ID the decoder uses to start generating the output sequence.
 * Only usable for encoder-decoder architectures: the encoder uses the normal
 * BOS token while the decoder uses this one.
 *
 * NOTE: only meaningful after `loadModel` has been called.
 *
 * @returns the decoder start token ID, or -1 if no model has been loaded yet.
 */
getDecoderStartToken(): number {
  const { decoderStartToken } = this;
  return decoderStartToken;
}
320
+
321
/**
 * Get the hyper-parameters and metadata of the loaded model.
 *
 * NOTE: can only be used after `loadModel` has been called.
 *
 * @returns ModelMetadata
 * @throws WllamaError ('model_not_loaded') when no model is loaded.
 */
getModelMetadata(): ModelMetadata {
  this.checkModelLoaded();
  // checkModelLoaded() guarantees that metadata is set at this point.
  return this.metadata as ModelMetadata;
}
332
+
333
/**
 * Tell whether the multi-thread build is currently in use.
 *
 * NOTE: can only be used after `loadModel` has been called.
 *
 * @returns true if multi-thread is used.
 * @throws WllamaError ('model_not_loaded') when no model is loaded.
 */
isMultithread(): boolean {
  this.checkModelLoaded();
  const { useMultiThread } = this;
  return useMultiThread;
}
344
+
345
/**
 * Get the number of threads used in the current context.
 *
 * NOTE: can only be used after `loadModel` has been called.
 *
 * @returns the configured thread count in multi-thread mode, otherwise 1.
 * @throws WllamaError ('model_not_loaded') when no model is loaded.
 */
getNumThreads(): number {
  this.checkModelLoaded();
  if (this.useMultiThread) {
    return this.nbThreads;
  }
  return 1;
}
356
+
357
/**
 * Tell whether the current model uses an encoder-decoder architecture.
 *
 * NOTE: can only be used after `loadModel` has been called.
 *
 * @returns true if the loaded model has an encoder.
 * @throws WllamaError ('model_not_loaded') when no model is loaded.
 */
isEncoderDecoderArchitecture(): boolean {
  this.checkModelLoaded();
  const { hasEncoder } = this;
  return hasEncoder;
}
368
+
369
/**
 * Must a BOS token be added to the tokenized sequence?
 *
 * NOTE: can only be used after `loadModel` has been called.
 *
 * @returns true if a BOS token must be added to the sequence
 * @throws WllamaError ('model_not_loaded') when no model is loaded.
 */
mustAddBosToken(): boolean {
  this.checkModelLoaded();
  const { addBosToken } = this;
  return addBosToken;
}
380
+
381
/**
 * Must an EOS token be added to the tokenized sequence?
 *
 * NOTE: can only be used after `loadModel` has been called.
 *
 * @returns true if an EOS token must be added to the sequence
 * @throws WllamaError ('model_not_loaded') when no model is loaded.
 */
mustAddEosToken(): boolean {
  this.checkModelLoaded();
  const { addEosToken } = this;
  return addEosToken;
}
392
+
393
/**
 * Get the jinja chat template that comes with the model. It is only available
 * if the original model (before conversion to gguf) had the template in
 * `tokenizer_config.json`.
 *
 * NOTE: can only be used after `loadModel` has been called.
 *
 * @returns the jinja template, or null if the gguf carries no template
 * @throws WllamaError ('model_not_loaded') when no model is loaded.
 */
getChatTemplate(): string | null {
  this.checkModelLoaded();
  const template = this.chatTemplate;
  return template == null ? null : template;
}
404
+
405
/**
 * Check whether WebGPU is supported by the current environment.
 * Delegates to the `isSupportWebGPU` helper imported from `./utils`.
 *
 * @returns true if WebGPU is supported
 */
isSupportWebGPU(): boolean {
  const supported = isSupportWebGPU();
  return supported;
}
412
+
413
/**
 * Load a model from a URL or from a `ModelSource`.
 * - With `useCache` enabled (the default), the model manager returns the model if it already has it and downloads it otherwise.
 * - With `useCache: false`, the model is always downloaded.
 *
 * The resulting blobs are then passed to `loadModel()` together with `params`.
 *
 * @param modelSourceOrURL a URL string or a `ModelSource` object
 * @param params load parameters, download options and the optional `useCache` flag
 */
async loadModelFromUrl(
  modelSourceOrURL: ModelSource | string,
  params: LoadModelParams & DownloadOptions & { useCache?: boolean } = {}
): Promise<void> {
  // A bare string is shorthand for a source that only has a url.
  let source: ModelSource;
  if (isString(modelSourceOrURL)) {
    source = { url: modelSourceOrURL } as ModelSource;
  } else {
    source = modelSourceOrURL as ModelSource;
  }

  const shouldUseCache = params.useCache ?? true;
  const model = await (shouldUseCache
    ? this.modelManager.getModelOrDownload(source, params)
    : this.modelManager.downloadModel(source, params));

  const blobs = await model.open();
  await this.loadModel(blobs, params);
}
434
+
435
/**
 * Load a model from Hugging Face.
 *
 * Resolves `hfOptions` into a model source with `getHFModelSource()`, then
 * hands it to `loadModelFromUrl()`.
 *
 * @param hfOptions Hugging Face parameters identifying the model
 * @param params load parameters, download options and the optional `useCache` flag
 */
async loadModelFromHF(
  hfOptions: HuggingFaceParams,
  params: LoadModelParams & DownloadOptions & { useCache?: boolean } = {}
) {
  const resolvedSource = await getHFModelSource(hfOptions);
  return await this.loadModelFromUrl(resolvedSource, params);
}
448
+
449
/**
 * Load a model from a list of Blobs or from a `Model` object.
 *
 * Multiple blobs may be passed when the model is split into several shards.
 *
 * Steps: validate the input, decide the threading mode, create the worker
 * proxy, initialize the module with the model files, start the runtime, send
 * the `load` action, then cache the returned context info (tokens, metadata,
 * chat template, ...) on this instance.
 *
 * @param ggufBlobsOrModel Either a list of Blobs (local files) or a Model object (from ModelManager)
 * @param params LoadModelParams
 * @throws WllamaError ('load_error') when the input is empty, contains an empty blob, `pathConfig.default` is missing, or a module is already initialized on this instance.
 * @throws WllamaError when the runtime start call does not report success.
 */
async loadModel(
  ggufBlobsOrModel: Blob[] | Model,
  params: LoadModelParams = {}
): Promise<void> {
  const blobs: Blob[] =
    ggufBlobsOrModel instanceof Model
      ? await ggufBlobsOrModel.open()
      : [...(ggufBlobsOrModel as Blob[])]; // copy array
  // Reject an empty list up front: `some()` is false on [], and we must not
  // spawn a worker for an input that cannot possibly load.
  if (blobs.length === 0) {
    throw new WllamaError(
      'Input model must contain at least one Blob or File',
      'load_error'
    );
  }
  if (blobs.some((b) => b.size === 0)) {
    throw new WllamaError(
      'Input model (or splits) must be non-empty Blob or File',
      'load_error'
    );
  }
  if (!this.pathConfig['default']) {
    throw new WllamaError(
      '"default" is missing from pathConfig',
      'load_error'
    );
  }

  if (this.proxy) {
    throw new WllamaError('Module is already initialized', 'load_error');
  }

  // detect if we can use multi-thread and webgpu
  const supportMultiThread = await isSupportMultiThread();
  // Default thread count: half of the reported hardware concurrency (rounded down).
  const hwConcurrency = Math.floor((navigator.hardwareConcurrency || 1) / 2);
  const nbThreads = params.n_threads ?? hwConcurrency;
  this.nbThreads = nbThreads;
  this.useMultiThread = supportMultiThread && nbThreads > 1;

  // initialize the worker
  const workerResources = this.getWorkerResources();
  this.proxy = new ProxyToWorker(
    workerResources,
    this.useMultiThread ? nbThreads : 0, // 0 means disable pthread
    this.config.suppressNativeLog ?? false,
    this.logger()
  );
  let logLevel = params.log_level ?? LogLevel.INFO;
  if (this.config.suppressNativeLog) {
    // Out-of-range level used to silence native logging.
    logLevel = 9999 as any;
  }

  const modelFiles = await prepareBlobs(blobs);
  await this.proxy.moduleInit(modelFiles.all);

  // run it
  this.logger().debug('Calling wllamaStart...');
  const startResult: any = await this.proxy.wllamaStart();
  if (!startResult?.success) {
    // Serialize the result: interpolating the object directly would only
    // print "[object Object]".
    throw new WllamaError(
      `Error while calling start function, result = ${JSON.stringify(startResult)}`
    );
  }

  // load the model
  this.logger().debug('Loading model...');
  // NOTE(review): mmproj_path is absolute ("/models/...") while model_paths
  // are relative ("models/...") — confirm this difference is intentional.
  const loadResult: GlueMsgLoadRes = await this.proxy.wllamaAction('load', {
    _name: 'load_req',
    log_level: logLevel,
    // if async read is not supported, use mmap; refer to README-dev.md for more details
    use_mmap: !canUseAsyncFileRead(workerResources.compat),
    use_mlock: false,
    n_gpu_layers: params.n_gpu_layers ?? 99999,
    n_ctx: params.n_ctx ?? 1024,
    n_threads: this.useMultiThread ? nbThreads : 1,
    n_ctx_auto: false, // not supported for now
    mmproj_path: modelFiles.mmproj
      ? `/models/${MMPROJ_FILE_NAME}`
      : undefined,
    model_paths: modelFiles.llm.map((f) => `models/${f.name}`),
    embeddings: params.embeddings,
    offload_kqv: params.offload_kqv,
    n_batch: params.n_batch,
    pooling_type: params.pooling_type as string,
    rope_scaling_type: params.rope_scaling_type as string,
    rope_freq_base: params.rope_freq_base,
    rope_freq_scale: params.rope_freq_scale,
    yarn_ext_factor: params.yarn_ext_factor,
    yarn_attn_factor: params.yarn_attn_factor,
    yarn_beta_fast: params.yarn_beta_fast,
    yarn_beta_slow: params.yarn_beta_slow,
    yarn_orig_ctx: params.yarn_orig_ctx,
    cache_type_k: params.cache_type_k as string,
    cache_type_v: params.cache_type_v as string,
    n_parallel: 1, // only support single sequence for now
    kv_unified: false, // TODO: support kv unified cache
    flash_attn: params.flash_attn,
    swa_full: params.swa_full,
    chat_template: params.chat_template,
    jinja: params.jinja,
    reasoning: params.reasoning,
    image_min_tokens: params.image_min_tokens,
    image_max_tokens: params.image_max_tokens,
    warmup: params.warmup,
    no_kv_offload: params.no_kv_offload,
    mmproj_offload: params.mmproj_offload,
    cont_batching: params.cont_batching,
    n_keep: params.n_keep,
    ctx_shift: params.ctx_shift,
    cache_idle_slots: params.cache_idle_slots,
    n_cache_reuse: params.n_cache_reuse,
    lora_paths: params.lora_adapters?.map((a) => a.path),
    lora_scales: params.lora_adapters?.map((a) => a.scale ?? 1.0),
    lora_init_without_apply: params.lora_init_without_apply,
    spec_draft_model: params.spec_draft_model,
    spec_draft_ngl: params.spec_draft_ngl,
    spec_draft_n_max: params.spec_draft_n_max,
    spec_draft_n_min: params.spec_draft_n_min,
    spec_draft_p_min: params.spec_draft_p_min,
    spec_draft_threads: params.spec_draft_threads,
    spec_draft_threads_batch: params.spec_draft_threads_batch,
    // kv_overrides is sent as two parallel arrays (keys and values).
    kv_overrides_keys: params.kv_overrides
      ? Object.keys(params.kv_overrides)
      : undefined,
    kv_overrides_vals: params.kv_overrides
      ? Object.values(params.kv_overrides)
      : undefined,
    reasoning_budget_tokens: params.reasoning_budget_tokens,
    reasoning_budget_message: params.reasoning_budget_message,
    reasoning_format: params.reasoning_format,
    skip_chat_parsing: params.skip_chat_parsing,
    prefill_assistant: params.prefill_assistant,
  });

  // Metadata arrives as parallel key/value arrays; fold it into a dictionary.
  const loadedCtxInfo: LoadedContextInfo & GlueMsgLoadRes = {
    ...loadResult,
    metadata: {},
  };
  for (let i = 0; i < loadResult.metadata_key.length; i++) {
    loadedCtxInfo.metadata[loadResult.metadata_key[i]] =
      loadResult.metadata_val[i];
  }

  // Cache everything the accessor methods expose.
  this.seed = params.seed;
  this.bosToken = loadedCtxInfo.token_bos;
  this.eosToken = loadedCtxInfo.token_eos;
  this.eotToken = loadedCtxInfo.token_eot;
  this.useEmbeddings = !!params.embeddings;
  this.useRerank = params.pooling_type === 'rank';
  this.metadata = {
    hparams: {
      nVocab: loadedCtxInfo.n_vocab,
      nCtxTrain: loadedCtxInfo.n_ctx_train,
      nEmbd: loadedCtxInfo.n_embd,
      nLayer: loadedCtxInfo.n_layer,
    },
    meta: loadedCtxInfo.metadata,
  };
  this.hasEncoder = !!loadedCtxInfo.has_encoder;
  this.decoderStartToken = loadedCtxInfo.token_decoder_start;
  this.addBosToken = loadedCtxInfo.add_bos_token;
  this.addEosToken = loadedCtxInfo.add_eos_token;
  this.chatTemplate = loadedCtxInfo.metadata['tokenizer.chat_template'];
  this.loadedContextInfo = loadedCtxInfo;
  this.eogTokens = new Set(loadedCtxInfo.list_tokens_eog);
  this.mediaMarker = loadedCtxInfo.media_marker;
  this.chatTemplateKwargs = params.default_template_kwargs ?? {};
  this.logger().debug({ loadedCtxInfo });
}
616
+
617
/**
 * Get the context info recorded when the model was loaded.
 *
 * @returns a shallow copy, so callers cannot replace fields on the internal object.
 * @throws WllamaError ('model_not_loaded') when no model is loaded.
 * @throws WllamaError when the context info is missing.
 */
getLoadedContextInfo(): LoadedContextInfo {
  this.checkModelLoaded();
  const info = this.loadedContextInfo;
  if (!info) {
    throw new WllamaError('Loaded context info is not available');
  }
  // copy object
  return Object.assign({}, info);
}
625
+
626
+ //////////////////////////////////////////////
627
+ // High level API
628
+
629
/**
 * Calculate embeddings for the given input.
 *
 * The options are serialized to JSON and sent to the worker as an
 * `embedding` action; the response is then collected in non-streaming mode.
 *
 * @param options OAI-compatible embedding creation options
 * @returns OAI-compatible embedding response
 * @throws WllamaError when no model is loaded, when embeddings were not enabled at load time, or ('inference_error') when the worker does not report success.
 */
async createEmbedding(
  options: EmbeddingCreateParams
): Promise<CreateEmbeddingResponse> {
  this.checkModelLoaded();

  const embeddingsEnabled = this.useEmbeddings;
  if (!embeddingsEnabled) {
    throw new WllamaError(
      'Embeddings is not enabled. Please set it via LoadModelParams.embeddings'
    );
  }

  const request = {
    _name: 'embd_req',
    data_json: JSON.stringify(options),
    files: [], // TODO: support file input
  };
  const { success } = await this.proxy.wllamaAction<GlueMsgEmbeddingRes>(
    'embedding',
    request
  );
  if (!success) {
    throw new WllamaError(
      'Model failed to start inference',
      'inference_error'
    );
  }

  return await this.getResponse(options as any, false);
}
664
+
665
/**
 * Rerank a list of documents against a query.
 *
 * Requires the model to be loaded with embeddings: true and pooling_type: 'rank'.
 * Documents are scored one at a time, in order; the scores are then sorted
 * from highest to lowest and truncated to `top_n` entries.
 *
 * @param options Reranking options (query, documents, top_n); `top_n` defaults to the number of documents
 * @returns Reranking response with relevance scores sorted highest first
 * @throws WllamaError when no model is loaded, when rerank is not enabled, or ('inference_error') when the worker does not report success for a document.
 */
async createRerank(options: RerankParams): Promise<RerankResponse> {
  this.checkModelLoaded();

  const rerankEnabled = this.useEmbeddings && this.useRerank;
  if (!rerankEnabled) {
    throw new WllamaError(
      'Rerank is not enabled. Please set it via LoadModelParams: embeddings = true and pooling_type = rank'
    );
  }

  const { query, documents } = options;
  const top_n = options.top_n ?? documents.length;
  const scored: Array<{ index: number; score: number }> = [];
  let tokenCount = 0;

  for (let index = 0; index < documents.length; index++) {
    const payload = JSON.stringify({ query, document: documents[index] });
    const started = await this.proxy.wllamaAction<GlueMsgRerankRes>('rerank', {
      _name: 'rrnk_req',
      data_json: payload,
    });
    if (!started.success) {
      throw new WllamaError(
        'Model failed to start reranking',
        'inference_error'
      );
    }

    const outcome = await this.getRerankResult();
    tokenCount += outcome.tokens_evaluated;
    scored.push({ index, score: outcome.score });
  }

  // Highest score first.
  scored.sort((left, right) => right.score - left.score);
  const results = scored.slice(0, top_n).map((entry) => ({
    index: entry.index,
    relevance_score: entry.score,
  }));

  return {
    model: this.getModelMetadata().meta['general.name'] ?? '',
    object: 'list',
    usage: { prompt_tokens: tokenCount, total_tokens: tokenCount },
    results,
  };
}
716
+
717
/**
 * Make a chat completion for the given chat messages.
 *
 * Default template kwargs supplied at load time are overlaid first; kwargs
 * passed in `options.chat_template_kwargs` take precedence over them.
 *
 * @param options OAI-compatible chat completion options
 * @returns OAI-compatible chat completion response (stream=false), void when done (stream=true + onData), or an async iterator of completion chunks (stream=true, no onData)
 */
async createChatCompletion(
  options: ChatCompletionParams & { stream?: false }
): Promise<ChatCompletionResponse>;
async createChatCompletion(
  options: ChatCompletionParams & StreamParams<ChatCompletionChunk>
): Promise<void>;
async createChatCompletion(
  options: ChatCompletionParams & { stream: true }
): Promise<AsyncIterable<ChatCompletionChunk>>;
async createChatCompletion(
  options: ChatCompletionParams
): Promise<
  ChatCompletionResponse | void | AsyncIterable<ChatCompletionChunk>
> {
  // first, overlay chatTemplateKwargs (only when there is something to overlay)
  const hasDefaultKwargs = Object.keys(this.chatTemplateKwargs).length > 0;
  if (hasDefaultKwargs) {
    const mergedKwargs = {
      ...this.chatTemplateKwargs,
      ...(options.chat_template_kwargs ?? {}),
    };
    options = { ...options, chat_template_kwargs: mergedKwargs };
  }

  // then, dispatch to the matching implementation
  if (!options.stream) {
    return await this.createCompletionImpl({ ...options, stream: false });
  }
  if ((options as any).onData) {
    // callback-style streaming: resolves to void once finished
    await this.createCompletionImpl(options);
    return;
  }
  return await this.createCompletionGenerator(options);
}
756
+
757
/**
 * Make a (raw) completion for the given text.
 *
 * @param options OAI-compatible completion options
 * @returns OAI-compatible completion response (stream=false), void when done (stream=true + onData), or async iterator (stream=true, no onData)
 */
async createCompletion(
  options: RawCompletionParams & { stream?: false }
): Promise<RawCompletionResponse>;
async createCompletion(
  options: RawCompletionParams & StreamParams<RawCompletionChunk>
): Promise<void>;
async createCompletion(
  options: RawCompletionParams & { stream: true }
): Promise<AsyncIterable<RawCompletionChunk>>;
async createCompletion(
  options: RawCompletionParams
): Promise<RawCompletionResponse | void | AsyncIterable<RawCompletionChunk>> {
  if (!options.stream) {
    return await this.createCompletionImpl({ ...options, stream: false });
  }
  if ((options as any).onData) {
    // callback-style streaming: resolves to void once finished
    await this.createCompletionImpl(options);
    return;
  }
  return await this.createCompletionGenerator(options);
}
782
+
783
/**
 * Shared implementation behind `createCompletion` and `createChatCompletion`.
 *
 * A request counts as chat when it has `messages`; chat requests go through
 * `prepareMultimodalInput()` first, which yields the params to send and the
 * files to attach. The seed given at load time (if any) overrides a `seed`
 * field in the options.
 *
 * @param options completion options (raw or chat)
 * @returns whatever `getResponse()` produces for these options
 * @throws WllamaError when no model is loaded, or ('inference_error') when the worker does not report success.
 */
private async createCompletionImpl<TOpt, TChunk>(
  options: TOpt
): Promise<TChunk> {
  this.checkModelLoaded();

  const rawOptions = options as any;
  const isStream = Boolean(rawOptions.stream);
  const isChat = Boolean(rawOptions.messages);

  const customOpt: any = {};
  if (this.seed !== undefined) {
    customOpt.seed = this.seed;
  }

  let requestOptions: TOpt = options;
  let attachments: ArrayBuffer[] = [];
  if (isChat) {
    const prepared = this.prepareMultimodalInput(
      options as any as ChatCompletionParams
    );
    requestOptions = prepared.params as any;
    attachments = prepared.files;
  }

  const request = {
    _name: 'cmpl_req',
    is_chat: isChat,
    data_json: JSON.stringify({ ...requestOptions, ...customOpt }),
    files: attachments.map((buffer) => new Uint8Array(buffer)),
  };
  const { success } = await this.proxy.wllamaAction<GlueMsgCompletionRes>(
    'completion',
    request
  );
  if (!success) {
    throw new WllamaError(
      'Model failed to start inference',
      'inference_error'
    );
  }

  return await this.getResponse(
    requestOptions as StreamParams<TChunk> & { abortSignal?: AbortSignal },
    isStream
  );
}
827
+
828
/**
 * Same as `createCompletion`, but delivers the chunks through an async iterator.
 * Only called when stream=true and no onData is provided.
 *
 * An `onData` callback is injected into the options and bridged to the
 * iterator via `cbToAsyncIter`: each chunk is forwarded, completion is
 * signalled with `done = true`, and a failure is forwarded as the error argument.
 *
 * @param options completion options (raw or chat) with stream enabled
 * @returns an async iterable of chunks
 */
private async createCompletionGenerator<TOpt, TChunk>(
  options: TOpt
): Promise<AsyncIterable<TChunk>> {
  const makeIterable = cbToAsyncIter(
    (emit: (val?: TChunk, done?: boolean, err?: Error) => void) => {
      const forwardChunk = (chunk: TChunk) => emit(chunk);
      this.createCompletionImpl<TOpt, TChunk>({
        ...options,
        onData: forwardChunk,
      })
        .then(() => emit(undefined, true))
        .catch((err) => emit(undefined, false, err));
    }
  );
  return makeIterable();
}
849
+
850
/**
 * Report whether the loaded model accepts the given input modality.
 *
 * Requires a loaded model (`checkModelLoaded` is called first).
 *
 * @param modality Either `'image'` or `'audio'`.
 * @returns `true` when the corresponding flag on the loaded context info is truthy.
 * @throws WllamaError (`unknown_error`) for any other modality value.
 */
supportInputModality(modality: 'image' | 'audio'): boolean {
  this.checkModelLoaded();
  switch (modality) {
    case 'image':
      return Boolean(this.loadedContextInfo.has_image_input);
    case 'audio':
      return Boolean(this.loadedContextInfo.has_audio_input);
    default:
      // Reachable only when a caller bypasses the type checker.
      throw new WllamaError(
        'Unsupported modality: ' + modality,
        'unknown_error'
      );
  }
}
868
+
869
/**
 * Unload the model and free all memory.
 *
 * Safe to call when no model has been loaded: the worker exit call is
 * skipped if there is no proxy, and the proxy reference is cleared either way.
 */
async exit(): Promise<void> {
  const activeProxy = this.proxy;
  if (activeProxy != null) {
    await activeProxy.wllamaExit();
  }
  this.proxy = null as any;
}
878
+
879
/**
 * [FOR DEBUGGING ONLY] Run ggml backend ops tests without loading any model.
 *
 * Creates a temporary worker proxy, initializes the wasm module, executes
 * `test-backend-ops` with the given args, then shuts the temporary proxy down
 * (also when an error occurs).
 *
 * For more info, please refer to guides/debug.md
 *
 * @param args Arguments forwarded to test-backend-ops (e.g. ["-o", "ADD"])
 * @returns retcode (0 = all tests passed) and success flag
 * @throws WllamaError when `pathConfig` has no "default" entry, when
 *   multi-threading is unavailable, or when the start call does not succeed.
 */
async testBackendOps(
  args: string[] = []
): Promise<{ retcode: number; success: boolean }> {
  if (!this.pathConfig['default']) {
    throw new WllamaError(
      '"default" is missing from pathConfig',
      'load_error'
    );
  }

  if (!(await isSupportMultiThread())) {
    throw new WllamaError(
      'Multi-threading is required to run backend ops tests, but it is not supported in the current environment.'
    );
  }

  // NOTE(review): the "single-thread" remark below sits next to a hard
  // multi-threading requirement above — confirm what the `0` argument means.
  const tmpProxy = new ProxyToWorker(
    this.getWorkerResources(),
    0, // single-thread; no model needed
    this.config.suppressNativeLog ?? false,
    this.logger()
  );

  try {
    await tmpProxy.moduleInit([]);

    const startResult: any = await tmpProxy.wllamaStart();
    if (!startResult.success) {
      // Serialize explicitly: interpolating the object directly would only
      // print "[object Object]" and hide the actual failure details.
      throw new WllamaError(
        `Error while calling start function, result = ${JSON.stringify(startResult)}`
      );
    }

    const result = await tmpProxy.wllamaAction<GlueMsgTestBackendOpsRes>(
      'test_backend_ops',
      // argv[0] is the program name, followed by the caller's arguments
      { _name: 'tbop_req', args: ['test-backend-ops', ...args] }
    );

    return { retcode: result.retcode, success: result.success };
  } finally {
    // Always tear down the temporary worker, whether or not the run succeeded.
    await tmpProxy.wllamaExit();
  }
}
932
+
933
+ //////////////////////////////////////////////
934
+ // Low level API
935
+
936
+ // TODO: add back
937
+
938
/**
 * Fetch debug information from the worker.
 *
 * Requires a loaded model (`checkModelLoaded` is called first).
 *
 * @returns Whatever the worker's debug call resolves to.
 */
async _getDebugInfo(): Promise<any> {
  this.checkModelLoaded();
  const debugInfo = await this.proxy.wllamaDebug();
  return debugInfo;
}
945
+
946
+ //////////////////////////////////////////////
947
+ // Utils
948
+
949
/**
 * Parse a JSON string received from the model.
 *
 * @param data_json Raw JSON text.
 * @returns The parsed value.
 * @throws WllamaError (`inference_error`) when the text is not valid JSON;
 *   the raw text is logged first.
 */
private jsonDecode(data_json: string) {
  let parsed: any;
  try {
    parsed = JSON.parse(data_json);
  } catch (parseError) {
    this.logger().error('Failed to parse JSON:', data_json);
    throw new WllamaError('Failed to parse model output', 'inference_error');
  }
  return parsed;
}
957
+
958
/**
 * Split chat parameters into a text-only request plus its binary attachments.
 *
 * Messages whose `content` is an array are copied with every non-text part
 * replaced by a text part holding the media marker; the `data` of each
 * replaced part is appended to `files` in encounter order. Messages with
 * non-array content are passed through unchanged.
 *
 * @param params Chat completion parameters.
 * @returns A copy of `params` with rewritten `messages`, and the extracted files.
 * @throws WllamaError (`inference_error`) when a non-text part is found but
 *   no media marker is set.
 */
private prepareMultimodalInput(params: ChatCompletionParams): {
  params: ChatCompletionParams;
  files: ArrayBuffer[];
} {
  const rewritten: typeof params.messages = [];
  const attachments: ArrayBuffer[] = [];

  for (const message of params.messages) {
    if (!Array.isArray(message.content)) {
      // plain (non-typed) content needs no transformation
      rewritten.push(message);
      continue;
    }

    const parts: typeof message.content = [];
    for (const part of message.content) {
      if (part.type !== 'text') {
        // media part: keep its payload aside and leave a marker in its place
        if (!this.mediaMarker) {
          throw new WllamaError(
            'Media marker is undefined',
            'inference_error'
          );
        }
        attachments.push(part.data);
        parts.push({
          type: 'text',
          text: this.mediaMarker,
        });
        continue;
      }
      // text parts are forwarded untouched
      parts.push(part);
    }

    rewritten.push({
      ...message,
      content: parts,
    } as ChatCompletionUserMessage);
  }

  return {
    params: {
      ...params,
      messages: rewritten,
    },
    files: attachments,
  };
}
1004
+
1005
/**
 * Poll the worker for the reranking result.
 *
 * Repeats `get_result` requests until a non-empty JSON payload arrives or the
 * worker reports that nothing more is coming.
 *
 * @returns The decoded payload of the first non-empty, non-error reply.
 * @throws WllamaError (`inference_error`) when the reply is flagged as an
 *   error, or when polling ends without any payload.
 */
private async getRerankResult(): Promise<{
  score: number;
  tokens_evaluated: number;
}> {
  let reply: GlueMsgGetResultRes;
  do {
    reply = await this.proxy.wllamaAction<GlueMsgGetResultRes>(
      'get_result',
      { _name: 'gres_req' }
    );

    const payload = reply.data_json;
    if (payload && payload.length > 0) {
      // The payload is decoded either way; the error flag decides its meaning.
      const decoded = this.jsonDecode(payload);
      if (reply.is_error) {
        throw new WllamaError(
          decoded.message || 'Unknown reranking error',
          'inference_error'
        );
      }
      return decoded;
    }
  } while (reply.has_more);

  throw new WllamaError('No reranking result received', 'inference_error');
}
1032
+
1033
/**
 * Drain completion results from the worker.
 *
 * Repeats `get_result` requests until the worker reports no more data. In
 * streaming mode every decoded chunk is passed to `options.onData`.
 *
 * @param options Carries the optional `onData` handler and `abortSignal`.
 * @param isStream Whether decoded payloads should be dispatched chunk by chunk.
 * @returns The last decoded payload (in streaming mode: the last dispatched
 *   chunk), or `null` if nothing was received.
 * @throws WllamaAbortError when the abort signal is set at the start of an iteration.
 * @throws WllamaError (`inference_error`) when a reply is flagged as an error.
 */
private async getResponse(
  options: StreamParams<any> & { abortSignal?: AbortSignal },
  isStream: boolean
) {
  let lastPayload: any = null;

  for (;;) {
    // Abort is checked before each poll.
    if (options.abortSignal?.aborted) {
      throw new WllamaAbortError();
    }

    const reply = await this.proxy.wllamaAction<GlueMsgGetResultRes>(
      'get_result',
      {
        _name: 'gres_req',
      }
    );
    const raw = reply.data_json;

    // Empty reply: poll again if more is expected, otherwise we are done.
    if (!raw || raw.length === 0) {
      if (reply.has_more) {
        continue;
      }
      break;
    }

    // this is the "is_begin = true" chunk on server side, we can ignore it
    if (raw === 'null') {
      continue;
    }

    const decoded = this.jsonDecode(raw);
    if (reply.is_error) {
      this.logger().error('Model returned an error:', decoded);
      throw new WllamaError(
        decoded.message || 'Unknown inference error',
        'inference_error'
      );
    }
    lastPayload = decoded;

    if (isStream) {
      // A reply may carry one chunk or an array of chunks.
      const batch = Array.isArray(decoded) ? decoded : [decoded];
      for (const item of batch) {
        options.onData?.(item);
        lastPayload = item;
      }
    }

    if (!reply.has_more) {
      break;
    }
  }

  return lastPayload;
}
1091
+
1092
/**
 * Build the resource descriptor handed to the worker.
 *
 * Starts from the "default" wasm path with compat disabled. When
 * `needCompat()` is true and compat resources are configured, the descriptor
 * is switched to the compat wasm/worker pair. Warnings are logged for the
 * compat and Firefox situations handled below.
 *
 * @returns The worker resources to use.
 */
getWorkerResources(): WllamaWorkerResources {
  const resources: WllamaWorkerResources = {
    wasmPath: absoluteUrl(this.pathConfig['default']),
    compat: false,
  };

  if (needCompat()) {
    if (this.compat) {
      const usesCdnDefaults =
        this.compat.worker === WasmCompatFromCDN.worker &&
        this.compat.wasm === WasmCompatFromCDN.wasm;
      if (usesCdnDefaults) {
        this.logger().warn(
          'Compatibility mode is activated, using resources from CDN. To use local resources, please refer to @wllama/wllama-compat package.'
        );
        this.logger().warn(
          'IMPORTANT: Performance will be significantly degraded in compatibility mode.'
        );
      }

      resources.wasmPath = absoluteUrl(this.compat.wasm);
      resources.jsPath = this.compat.worker;
      resources.compat = true;
    } else {
      // compat is needed but no compat resources were configured
      this.logger().warn(
        'Not using compat mode' +
          (isFirefox()
            ? ' (expected on Firefox - WebGPU will be disabled)'
            : '')
      );
    }
  }

  // The remaining warnings only concern Firefox.
  if (!isFirefox()) {
    return resources;
  }

  if (resources.compat) {
    this.logger().warn(
      'Using compat mode on Firefox, performance will be significantly degraded; Consider enabling "javascript.options.wasm_js_promise_integration" in "about:config".'
    );
  } else if (!isSupportJSPI()) {
    this.logger().warn(
      'WebGPU is disabled on Firefox due to missing JSPI support. Please consider enabling compat mode, or enabling "javascript.options.wasm_js_promise_integration" in "about:config".'
    );
  }

  return resources;
}
1138
+ }