@simulatte/doppler 0.1.8 → 0.1.9

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116) hide show
  1. package/CHANGELOG.md +14 -1
  2. package/README.md +25 -6
  3. package/package.json +5 -3
  4. package/src/client/doppler-api.browser.js +6 -0
  5. package/src/client/doppler-api.d.ts +3 -0
  6. package/src/client/doppler-api.js +11 -2
  7. package/src/client/doppler-registry.js +3 -5
  8. package/src/client/doppler-registry.json +16 -0
  9. package/src/config/kernels/kernel-ref-digests.js +23 -21
  10. package/src/config/kernels/moe/mixtral.paths.json +46 -0
  11. package/src/config/loader.js +6 -0
  12. package/src/config/platforms/loader.js +3 -1
  13. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-nosubgroups.json +16 -16
  14. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-online.json +8 -8
  15. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32a-small-attn.json +61 -0
  16. package/src/config/presets/kernel-paths/registry.json +7 -0
  17. package/src/config/presets/models/gemma3.json +2 -1
  18. package/src/config/presets/models/gemma4.json +61 -0
  19. package/src/config/presets/models/granite-docling.json +70 -0
  20. package/src/config/presets/models/lfm2.json +6 -1
  21. package/src/config/presets/models/qwen3_vl.json +40 -0
  22. package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +2 -1
  23. package/src/config/presets/runtime/experiments/verify/lfm2-verify.json +46 -0
  24. package/src/config/presets/runtime/experiments/verify/translategemma-verify.json +39 -0
  25. package/src/config/presets/runtime/modes/trace-layers.json +1 -0
  26. package/src/config/presets/runtime/tiers/gemma4-16gb.json +69 -0
  27. package/src/config/presets/runtime/tiers/gemma4-24gb.json +66 -0
  28. package/src/config/presets/runtime/tiers/gemma4-32gb.json +66 -0
  29. package/src/config/runtime.js +3 -0
  30. package/src/config/schema/debug.schema.d.ts +40 -0
  31. package/src/config/schema/debug.schema.js +28 -0
  32. package/src/config/schema/index.js +2 -0
  33. package/src/config/schema/inference-defaults.schema.js +1 -1
  34. package/src/config/schema/kernel-path.schema.d.ts +1 -0
  35. package/src/config/schema/memory-limits.schema.js +2 -2
  36. package/src/config/schema/storage.schema.js +1 -1
  37. package/src/converter/conversion-plan.js +1 -1
  38. package/src/converter/core.js +17 -8
  39. package/src/converter/quantizer.d.ts +5 -0
  40. package/src/converter/quantizer.js +15 -0
  41. package/src/distribution/shard-delivery.js +34 -0
  42. package/src/formats/rdrr/classification.js +32 -0
  43. package/src/gpu/kernel-runtime.js +4 -2
  44. package/src/gpu/kernels/attention.js +2 -1
  45. package/src/gpu/kernels/dequant_f16_out.wgsl +4 -2
  46. package/src/gpu/kernels/dequant_f16_out_vec4.wgsl +5 -2
  47. package/src/gpu/kernels/dequant_shared.wgsl +4 -2
  48. package/src/gpu/kernels/dequant_shared_vec4.wgsl +4 -2
  49. package/src/gpu/kernels/dequant_subgroup.wgsl +6 -2
  50. package/src/gpu/kernels/gated-short-conv.d.ts +63 -0
  51. package/src/gpu/kernels/gated-short-conv.js +284 -0
  52. package/src/gpu/kernels/linear-attention-core.js +37 -17
  53. package/src/gpu/kernels/matmul-selection.js +1 -0
  54. package/src/gpu/kernels/matmul.d.ts +3 -0
  55. package/src/gpu/kernels/matmul.js +70 -1
  56. package/src/gpu/kernels/matmul_gemv_subgroup.wgsl +77 -79
  57. package/src/gpu/kernels/sample.js +1 -3
  58. package/src/gpu/kernels/sample.wgsl +39 -9
  59. package/src/gpu/kernels/sample_f16.wgsl +38 -8
  60. package/src/gpu/kernels/shader-cache.js +9 -4
  61. package/src/inference/kv-cache/base.js +3 -10
  62. package/src/inference/pipelines/diffusion/pipeline.js +2 -1
  63. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +2 -1
  64. package/src/inference/pipelines/text/attention/projections.d.ts +3 -0
  65. package/src/inference/pipelines/text/attention/projections.js +13 -2
  66. package/src/inference/pipelines/text/attention/record.js +1 -0
  67. package/src/inference/pipelines/text/attention/run.js +9 -0
  68. package/src/inference/pipelines/text/config.d.ts +1 -0
  69. package/src/inference/pipelines/text/config.js +32 -4
  70. package/src/inference/pipelines/text/embed.js +26 -7
  71. package/src/inference/pipelines/text/execution-v0-runtime-builders.js +10 -3
  72. package/src/inference/pipelines/text/execution-v0.js +12 -1
  73. package/src/inference/pipelines/text/generator-helpers.js +1 -0
  74. package/src/inference/pipelines/text/generator-runtime.js +14 -0
  75. package/src/inference/pipelines/text/generator-steps.d.ts +9 -0
  76. package/src/inference/pipelines/text/generator-steps.js +46 -29
  77. package/src/inference/pipelines/text/generator.d.ts +5 -0
  78. package/src/inference/pipelines/text/generator.js +320 -166
  79. package/src/inference/pipelines/text/init.d.ts +2 -0
  80. package/src/inference/pipelines/text/init.js +19 -5
  81. package/src/inference/pipelines/text/layer.js +37 -8
  82. package/src/inference/pipelines/text/moe-gpu.js +21 -3
  83. package/src/inference/pipelines/text/moe-shape-validator.d.ts +9 -0
  84. package/src/inference/pipelines/text/moe-shape-validator.js +31 -11
  85. package/src/inference/pipelines/text/ops.js +123 -53
  86. package/src/inference/pipelines/text/probes.js +1 -0
  87. package/src/inference/pipelines/text/state.js +2 -0
  88. package/src/inference/pipelines/text.d.ts +5 -0
  89. package/src/inference/pipelines/text.js +59 -1
  90. package/src/inference/pipelines/vision/encoder.js +386 -0
  91. package/src/inference/pipelines/vision/image-preprocess.js +151 -0
  92. package/src/inference/pipelines/vision/index.js +173 -0
  93. package/src/inference/pipelines/vision/ops.js +78 -0
  94. package/src/inference/pipelines/vision/patch-embed.js +151 -0
  95. package/src/inference/test-harness.js +9 -7
  96. package/src/loader/doppler-loader.d.ts +3 -0
  97. package/src/loader/doppler-loader.js +20 -3
  98. package/src/loader/experts/expert-cache.js +6 -2
  99. package/src/loader/experts/expert-loader.js +6 -2
  100. package/src/loader/layer-loader.js +42 -3
  101. package/src/loader/manifest-config.js +3 -1
  102. package/src/loader/tensors/tensor-loader.d.ts +3 -0
  103. package/src/loader/tensors/tensor-loader.js +124 -3
  104. package/src/rules/kernels/moe.rules.mixtral.json +75 -0
  105. package/src/rules/kernels/softmax.rules.json +2 -0
  106. package/src/rules/rule-registry.d.ts +1 -0
  107. package/src/rules/rule-registry.js +2 -0
  108. package/src/storage/quickstart-downloader.d.ts +3 -0
  109. package/src/storage/quickstart-downloader.js +27 -30
  110. package/src/tooling/node-converter.js +25 -7
  111. package/src/tooling/node-source-runtime.js +29 -5
  112. package/src/tooling/node-webgpu.js +24 -7
  113. package/src/utils/hf-resolve-url.d.ts +16 -0
  114. package/src/utils/hf-resolve-url.js +17 -0
  115. package/src/version.js +1 -1
  116. package/src/tooling/node-convert.d.ts +0 -54
@@ -1,7 +1,7 @@
1
1
 
2
2
 
3
3
  import { getDevice, setTrackSubmits } from '../../../gpu/device.js';
4
- import { releaseBuffer, readBuffer, readBufferSlice } from '../../../memory/buffer-pool.js';
4
+ import { releaseBuffer, readBuffer, readBufferSlice, uploadData } from '../../../memory/buffer-pool.js';
5
5
  import { isGPUSamplingAvailable } from '../../../gpu/kernels/sample.js';
6
6
  import { markWarmed as markKernelCacheWarmed } from '../../../gpu/kernel-selection-cache.js';
7
7
  import { resetSubmitStats, logSubmitStats } from '../../../gpu/submit-tracker.js';
@@ -210,6 +210,14 @@ export class PipelineGenerator {
210
210
  return resolveStepOptions(this.#state, options);
211
211
  }
212
212
 
213
+ _resetDecodeRuntimeState() {
214
+ this.#state.decodeStepCount = 0;
215
+ this.#state.disableRecordedLogits = false;
216
+ this.#state.disableFusedDecode = false;
217
+ resetActiveExecutionPlan(this.#state);
218
+ this.#state.decodeRing?.reset();
219
+ }
220
+
213
221
  _getDecodeHelpers(debugCheckBuffer) {
214
222
  return {
215
223
  buildLayerContext: (recorder, isDecodeMode, debugLayers, executionPlan) =>
@@ -235,102 +243,71 @@ export class PipelineGenerator {
235
243
  );
236
244
  }
237
245
 
238
- _beginFinitenessFallback(opts, reasonLabel) {
239
- const originalPlan = resolveActiveExecutionPlan(this.#state);
240
- const original = {
241
- activePlanId: this.#state.executionPlanState?.activePlanId ?? 'primary',
242
- seed: opts.seed,
243
- };
244
-
245
- const fallbackPlan = activateFallbackExecutionPlan(this.#state);
246
- if (!fallbackPlan) {
247
- throw new Error('[Pipeline] Finiteness fallback plan is unavailable for this model/runtime configuration.');
248
- }
249
- log.warn(
250
- 'Pipeline',
251
- `FinitenessGuard fallback (${reasonLabel}): ` +
252
- `${originalPlan.kernelPathId ?? 'none'} -> ${fallbackPlan.kernelPathId ?? 'none'}`
253
- );
254
-
255
- this.#state.decodeBuffers?.ensureBuffers({
256
- hiddenSize: this.#state.modelConfig.hiddenSize,
257
- intermediateSize: this.#state.modelConfig.intermediateSize,
258
- activationDtype: fallbackPlan.activationDtype,
259
- enablePingPong: true,
260
- });
261
-
262
- if (opts.seed == null) {
263
- const fallbackSeedBase = (this.#state.decodeStepCount + this.#state.currentSeqLen + 1) >>> 0;
264
- opts.seed = (fallbackSeedBase * 2654435761) >>> 0;
265
- }
266
- opts.executionPlan = rebaseExecutionSessionPlan(this.#state, opts.executionPlan);
267
-
268
- return original;
246
+ _resolvePromptTokenIds(prompt, useChatTemplate, contextLabel) {
247
+ const processedPrompt = resolvePromptInput(this.#state, prompt, useChatTemplate, contextLabel);
248
+ const inputIds = this.#state.tokenizer.encode(processedPrompt);
249
+ this._assertTokenIdsInRange(inputIds, `${contextLabel}.encode`);
250
+ return inputIds;
269
251
  }
270
252
 
271
- _endFinitenessFallback(opts, original) {
272
- opts.seed = original.seed;
273
- setActiveExecutionPlan(this.#state, original.activePlanId);
274
- opts.executionPlan = rebaseExecutionSessionPlan(this.#state, opts.executionPlan);
275
- const nextActivationDtype = this._getEffectiveActivationDtype();
276
- this.#state.decodeBuffers?.ensureBuffers({
277
- hiddenSize: this.#state.modelConfig.hiddenSize,
278
- intermediateSize: this.#state.modelConfig.intermediateSize,
279
- activationDtype: nextActivationDtype,
280
- enablePingPong: true,
253
+ _sampleNextTokenFromLogits(logits, generatedIds, opts) {
254
+ const sampledLogits = Float32Array.from(logits);
255
+ applyRepetitionPenalty(sampledLogits, generatedIds, opts.repetitionPenalty);
256
+ const padTokenId = this.#state.tokenizer?.getSpecialTokens?.()?.pad;
257
+ return sample(sampledLogits, {
258
+ temperature: opts.temperature,
259
+ topP: opts.topP,
260
+ topK: opts.topK,
261
+ padTokenId,
262
+ seed: opts.seed,
281
263
  });
282
264
  }
283
265
 
284
- async _retryWithFinitenessFallback(opts, reasonLabel, retryFn) {
285
- if (this._hasFinitenessFallbackWindow()) {
286
- return retryFn();
287
- }
288
- this.#state.kvCache?.truncate(this.#state.currentSeqLen);
289
- const original = this._beginFinitenessFallback(opts, reasonLabel);
290
- try {
291
- return await retryFn();
292
- } finally {
293
- this._endFinitenessFallback(opts, original);
266
+ async _prefillPromptToLogits(prompt, opts, contextLabel) {
267
+ const inputIds = this._resolvePromptTokenIds(prompt, opts.useChatTemplate, contextLabel);
268
+ if (opts.debug) {
269
+ log.debug('Pipeline', `${contextLabel}: ${inputIds.length} tokens`);
294
270
  }
295
- }
296
271
 
297
- async _retryDecodeStepWithFinitenessWindow(generatedIds, opts, reasonLabel) {
298
- const windowTokens = this._resolveDeferredRoundingWindowTokens();
299
- if (windowTokens <= 1) {
300
- return this._retryWithFinitenessFallback(
272
+ let logits;
273
+ try {
274
+ logits = await this._prefill(inputIds, opts);
275
+ } catch (error) {
276
+ if (!shouldRetryWithFinitenessFallback(error)) {
277
+ throw error;
278
+ }
279
+ log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during ${contextLabel}. Retrying with F32 precision.`);
280
+ logits = await this._retryWithFinitenessFallback(
301
281
  opts,
302
- reasonLabel,
303
- () => this._decodeStep(generatedIds, opts)
282
+ contextLabel,
283
+ () => this._prefill(inputIds, opts)
304
284
  );
305
285
  }
306
286
 
307
- this.#state.kvCache?.truncate(this.#state.currentSeqLen);
308
- this._openFinitenessFallbackWindow(opts, reasonLabel, windowTokens);
309
- try {
310
- return await this._decodeStep(generatedIds, opts);
311
- } catch (error) {
312
- this._closeFinitenessFallbackWindow(opts);
313
- throw error;
314
- }
287
+ return { inputIds, logits };
315
288
  }
316
289
 
317
- // ==========================================================================
318
- // Generation Public API
319
- // ==========================================================================
290
+ async _decodeStepToLogits(currentIds, opts) {
291
+ const debugCheckBuffer = this.#state.debug
292
+ ? (buffer, label, numTokens, expectedDim) =>
293
+ debugCheckBufferHelper(this.#state, buffer, label, numTokens, expectedDim)
294
+ : undefined;
295
+ return decodeStepLogits(this.#state, currentIds, opts, this._getDecodeHelpers(debugCheckBuffer));
296
+ }
320
297
 
298
+ async _decodeNextTokenViaLogits(currentIds, opts) {
299
+ const stepResult = await this._decodeStepToLogits(currentIds, opts);
300
+ return this._sampleNextTokenFromLogits(stepResult.logits, currentIds, opts);
301
+ }
321
302
 
322
- async *generate(prompt, options = {}) {
303
+ async *_generateTokensInternal(prompt, options = {}, mode = 'text') {
323
304
  if (!this.#state.isLoaded) throw new Error('Model not loaded');
324
305
  if (this.#state.isGenerating) throw new Error('Generation already in progress');
325
306
 
326
307
  validateCallTimeOptions(options);
327
308
 
328
309
  this.#state.isGenerating = true;
329
- this.#state.decodeStepCount = 0;
330
- this.#state.disableRecordedLogits = false;
331
- this.#state.disableFusedDecode = false;
332
- resetActiveExecutionPlan(this.#state);
333
- this.#state.decodeRing?.reset();
310
+ this._resetDecodeRuntimeState();
334
311
  this.#state.stats.gpuTimePrefillMs = undefined;
335
312
  this.#state.stats.gpuTimeDecodeMs = undefined;
336
313
  this.#state.stats.decodeRecordMs = 0;
@@ -345,14 +322,23 @@ export class PipelineGenerator {
345
322
  log.debug('Pipeline', `ChatTemplate: options=${options.useChatTemplate}, final=${opts.useChatTemplate}`);
346
323
  }
347
324
 
348
- try {
349
- const processedPrompt = resolvePromptInput(this.#state, prompt, opts.useChatTemplate, 'generate');
350
- if (opts.debug && opts.useChatTemplate) {
351
- log.debug('Pipeline', `Applied ${this.#state.modelConfig.chatTemplateType} chat template`);
325
+ const emitToken = async function* (generator, tokenId, textDecoder) {
326
+ if (mode === 'token') {
327
+ yield tokenId;
328
+ if (options.onToken) options.onToken(tokenId, '');
329
+ return;
352
330
  }
331
+ const tokenText = textDecoder(tokenId);
332
+ yield tokenText;
333
+ if (options.onToken) options.onToken(tokenId, tokenText);
334
+ };
353
335
 
354
- const inputIds = this.#state.tokenizer.encode(processedPrompt);
355
- this._assertTokenIdsInRange(inputIds, 'generate.encode');
336
+ try {
337
+ const prefillStart = performance.now();
338
+ const { inputIds, logits: initialPrefillLogits } = await this._prefillPromptToLogits(prompt, opts, 'generate');
339
+ let prefillLogits = initialPrefillLogits;
340
+ this.#state.stats.prefillTimeMs = performance.now() - prefillStart;
341
+ this._assertTokenIdsInRange(inputIds, 'generate.prefillTokens');
356
342
  const generatedIds = [...inputIds];
357
343
  this.#state.stats.prefillTokens = inputIds.length;
358
344
 
@@ -360,24 +346,6 @@ export class PipelineGenerator {
360
346
  log.debug('Pipeline', `Input: ${inputIds.length} tokens`);
361
347
  }
362
348
 
363
- const prefillStart = performance.now();
364
- let prefillLogits;
365
- try {
366
- prefillLogits = await this._prefill(inputIds, opts);
367
- } catch (error) {
368
- if (shouldRetryWithFinitenessFallback(error)) {
369
- log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefill. Retrying with F32 precision.`);
370
- prefillLogits = await this._retryWithFinitenessFallback(
371
- opts,
372
- 'prefill',
373
- () => this._prefill(inputIds, opts)
374
- );
375
- } else {
376
- throw error;
377
- }
378
- }
379
- this.#state.stats.prefillTimeMs = performance.now() - prefillStart;
380
-
381
349
  const intentBundleConfig = this.#state.runtimeConfig.shared.intentBundle;
382
350
  const intentBundle = intentBundleConfig?.bundle;
383
351
  const expectedTopK = intentBundle?.payload?.expectedTopK
@@ -389,20 +357,17 @@ export class PipelineGenerator {
389
357
  const actualTopK = getTopK(
390
358
  prefillLogits,
391
359
  expectedTopK.length,
392
- (tokens) => resolveTokenText(this.#state.tokenizer, tokens),
393
- ).map((token) => token.token);
360
+ (tokens) => resolveTokenText(this.#state.tokenizer, tokens),
361
+ ).map((token) => token.token);
394
362
  const driftResult = enforceLogitDrift(expectedTopK, actualTopK, maxDriftThreshold);
395
363
  if (!driftResult.ok) {
396
364
  throw new Error(`Intent bundle drift check failed: ${driftResult.reason}`);
397
365
  }
398
366
  }
399
367
 
400
- applyRepetitionPenalty(prefillLogits, generatedIds, opts.repetitionPenalty);
401
- const padTokenId = this.#state.tokenizer?.getSpecialTokens?.()?.pad;
402
-
403
368
  if (opts.debug) {
404
369
  const topAfterPenalty = getTopK(
405
- prefillLogits,
370
+ Float32Array.from(prefillLogits),
406
371
  5,
407
372
  (tokens) => resolveTokenText(this.#state.tokenizer, tokens)
408
373
  );
@@ -411,13 +376,7 @@ export class PipelineGenerator {
411
376
 
412
377
  let firstToken;
413
378
  try {
414
- firstToken = sample(prefillLogits, {
415
- temperature: opts.temperature,
416
- topP: opts.topP,
417
- topK: opts.topK,
418
- padTokenId,
419
- seed: opts.seed,
420
- });
379
+ firstToken = this._sampleNextTokenFromLogits(prefillLogits, generatedIds, opts);
421
380
  } catch (error) {
422
381
  if (!shouldRetryWithFinitenessFallback(error)) {
423
382
  throw error;
@@ -428,14 +387,7 @@ export class PipelineGenerator {
428
387
  'prefill-sample',
429
388
  () => this._prefill(inputIds, opts)
430
389
  );
431
- applyRepetitionPenalty(prefillLogits, generatedIds, opts.repetitionPenalty);
432
- firstToken = sample(prefillLogits, {
433
- temperature: opts.temperature,
434
- topP: opts.topP,
435
- topK: opts.topK,
436
- padTokenId,
437
- seed: opts.seed,
438
- });
390
+ firstToken = this._sampleNextTokenFromLogits(prefillLogits, generatedIds, opts);
439
391
  }
440
392
 
441
393
  if (opts.debug) {
@@ -454,9 +406,7 @@ export class PipelineGenerator {
454
406
  (tokens) => this.#state.tokenizer?.decode?.(tokens, false, false)
455
407
  );
456
408
 
457
- const firstText = decodeToken(firstToken);
458
- yield firstText;
459
- if (options.onToken) options.onToken(firstToken, firstText);
409
+ yield* emitToken(this, firstToken, decodeToken);
460
410
 
461
411
  yield* this._runDecodeLoop(generatedIds, opts, options, {
462
412
  stopTokenIds: this.#state.modelConfig.stopTokenIds,
@@ -464,6 +414,7 @@ export class PipelineGenerator {
464
414
  stopSequenceStart: inputIds.length,
465
415
  decodeToken,
466
416
  logBatchPath: opts.debug,
417
+ emitMode: mode,
467
418
  });
468
419
  const tokensGenerated = this.#state.stats.decodeTokens;
469
420
  this.#state.stats.totalTimeMs = performance.now() - startTime;
@@ -495,17 +446,203 @@ export class PipelineGenerator {
495
446
  }
496
447
  }
497
448
 
449
+ _beginFinitenessFallback(opts, reasonLabel) {
450
+ const originalPlan = resolveActiveExecutionPlan(this.#state);
451
+ const original = {
452
+ activePlanId: this.#state.executionPlanState?.activePlanId ?? 'primary',
453
+ seed: opts.seed,
454
+ };
455
+
456
+ const fallbackPlan = activateFallbackExecutionPlan(this.#state);
457
+ if (!fallbackPlan) {
458
+ throw new Error('[Pipeline] Finiteness fallback plan is unavailable for this model/runtime configuration.');
459
+ }
460
+ log.warn(
461
+ 'Pipeline',
462
+ `FinitenessGuard fallback (${reasonLabel}): ` +
463
+ `${originalPlan.kernelPathId ?? 'none'} -> ${fallbackPlan.kernelPathId ?? 'none'}`
464
+ );
498
465
 
499
- async prefillKVOnly(prompt, options = {}) {
466
+ this.#state.decodeBuffers?.ensureBuffers({
467
+ hiddenSize: this.#state.modelConfig.hiddenSize,
468
+ intermediateSize: this.#state.modelConfig.intermediateSize,
469
+ activationDtype: fallbackPlan.activationDtype,
470
+ enablePingPong: true,
471
+ });
472
+
473
+ if (opts.seed == null) {
474
+ const fallbackSeedBase = (this.#state.decodeStepCount + this.#state.currentSeqLen + 1) >>> 0;
475
+ opts.seed = (fallbackSeedBase * 2654435761) >>> 0;
476
+ }
477
+ opts.executionPlan = rebaseExecutionSessionPlan(this.#state, opts.executionPlan);
478
+
479
+ return original;
480
+ }
481
+
482
+ _endFinitenessFallback(opts, original) {
483
+ opts.seed = original.seed;
484
+ setActiveExecutionPlan(this.#state, original.activePlanId);
485
+ opts.executionPlan = rebaseExecutionSessionPlan(this.#state, opts.executionPlan);
486
+ const nextActivationDtype = this._getEffectiveActivationDtype();
487
+ this.#state.decodeBuffers?.ensureBuffers({
488
+ hiddenSize: this.#state.modelConfig.hiddenSize,
489
+ intermediateSize: this.#state.modelConfig.intermediateSize,
490
+ activationDtype: nextActivationDtype,
491
+ enablePingPong: true,
492
+ });
493
+ }
494
+
495
+ async _retryWithFinitenessFallback(opts, reasonLabel, retryFn) {
496
+ if (this._hasFinitenessFallbackWindow()) {
497
+ return retryFn();
498
+ }
499
+ this.#state.kvCache?.truncate(this.#state.currentSeqLen);
500
+ const original = this._beginFinitenessFallback(opts, reasonLabel);
501
+ try {
502
+ return await retryFn();
503
+ } finally {
504
+ this._endFinitenessFallback(opts, original);
505
+ }
506
+ }
507
+
508
+ async _retryDecodeStepWithFinitenessWindow(generatedIds, opts, reasonLabel) {
509
+ const windowTokens = this._resolveDeferredRoundingWindowTokens();
510
+ if (windowTokens <= 1) {
511
+ return this._retryWithFinitenessFallback(
512
+ opts,
513
+ reasonLabel,
514
+ () => this._decodeStep(generatedIds, opts)
515
+ );
516
+ }
517
+
518
+ this.#state.kvCache?.truncate(this.#state.currentSeqLen);
519
+ this._openFinitenessFallbackWindow(opts, reasonLabel, windowTokens);
520
+ try {
521
+ return await this._decodeStep(generatedIds, opts);
522
+ } catch (error) {
523
+ this._closeFinitenessFallbackWindow(opts);
524
+ throw error;
525
+ }
526
+ }
527
+
528
+ // ==========================================================================
529
+ // Generation Public API
530
+ // ==========================================================================
531
+
532
+
533
+ async *generate(prompt, options = {}) {
534
+ yield* this._generateTokensInternal(prompt, options, 'text');
535
+ }
536
+
537
+ async *generateTokens(prompt, options = {}) {
538
+ yield* this._generateTokensInternal(prompt, options, 'token');
539
+ }
540
+
541
+ async generateTokenIds(prompt, options = {}) {
500
542
  if (!this.#state.isLoaded) throw new Error('Model not loaded');
501
- resetActiveExecutionPlan(this.#state);
543
+ if (this.#state.isGenerating) throw new Error('Generation already in progress');
544
+
545
+ validateCallTimeOptions(options);
546
+
547
+ this.#state.isGenerating = true;
548
+ this._resetDecodeRuntimeState();
502
549
  this.#state.stats.gpuTimePrefillMs = undefined;
503
- const opts = resolvePrefillOptions(this.#state, options);
550
+ this.#state.stats.gpuTimeDecodeMs = undefined;
551
+ this.#state.stats.decodeRecordMs = 0;
552
+ this.#state.stats.decodeSubmitWaitMs = 0;
553
+ this.#state.stats.decodeReadbackWaitMs = 0;
554
+ this.#state.stats.ttftMs = 0;
555
+ const startTime = performance.now();
556
+ const opts = resolveGenerateOptions(this.#state, options);
504
557
 
505
- const processedPrompt = resolvePromptInput(this.#state, prompt, opts.useChatTemplate, 'prefillKVOnly');
558
+ try {
559
+ const prefillStart = performance.now();
560
+ const { inputIds, logits: initialPrefillLogits } = await this._prefillPromptToLogits(prompt, opts, 'generateTokenIds');
561
+ let prefillLogits = initialPrefillLogits;
562
+ this.#state.stats.prefillTimeMs = performance.now() - prefillStart;
563
+ this._assertTokenIdsInRange(inputIds, 'generateTokenIds.prefillTokens');
564
+ const generatedIds = [...inputIds];
565
+ this.#state.stats.prefillTokens = inputIds.length;
506
566
 
507
- const inputIds = this.#state.tokenizer.encode(processedPrompt);
508
- this._assertTokenIdsInRange(inputIds, 'prefillKVOnly.encode');
567
+ let firstToken;
568
+ try {
569
+ firstToken = this._sampleNextTokenFromLogits(prefillLogits, generatedIds, opts);
570
+ } catch (error) {
571
+ if (!shouldRetryWithFinitenessFallback(error)) {
572
+ throw error;
573
+ }
574
+ prefillLogits = await this._retryWithFinitenessFallback(
575
+ opts,
576
+ 'prefill-sample',
577
+ () => this._prefill(inputIds, opts)
578
+ );
579
+ firstToken = this._sampleNextTokenFromLogits(prefillLogits, generatedIds, opts);
580
+ }
581
+
582
+ generatedIds.push(firstToken);
583
+ const tokenIds = [firstToken];
584
+ this.#state.stats.ttftMs = performance.now() - startTime;
585
+
586
+ const stopTokenIds = this.#state.modelConfig.stopTokenIds;
587
+ const eosToken = this.#state.tokenizer.getSpecialTokens?.()?.eos;
588
+ const stopSequenceStart = inputIds.length;
589
+ markKernelCacheWarmed();
590
+ const decodeStart = performance.now();
591
+
592
+ while (tokenIds.length < opts.maxTokens) {
593
+ if (options.signal?.aborted) break;
594
+ let nextToken;
595
+ try {
596
+ nextToken = await this._decodeNextTokenViaLogits(generatedIds, opts);
597
+ } catch (error) {
598
+ if (shouldRetryWithFinitenessFallback(error)) {
599
+ nextToken = await this._retryDecodeStepWithFinitenessWindow(
600
+ generatedIds,
601
+ opts,
602
+ `decode-step-${tokenIds.length}`
603
+ );
604
+ } else {
605
+ throw error;
606
+ }
607
+ }
608
+ generatedIds.push(nextToken);
609
+ tokenIds.push(nextToken);
610
+ this._consumeFinitenessFallbackToken(opts);
611
+ if (isStopToken(nextToken, stopTokenIds, eosToken)) {
612
+ break;
613
+ }
614
+ if (opts.stopSequences.length > 0) {
615
+ const fullText = this.#state.tokenizer.decode(generatedIds.slice(stopSequenceStart), false);
616
+ if (opts.stopSequences.some((seq) => fullText.endsWith(seq))) break;
617
+ }
618
+ }
619
+
620
+ this.#state.stats.decodeTimeMs = performance.now() - decodeStart;
621
+ this.#state.stats.tokensGenerated = tokenIds.length;
622
+ this.#state.stats.decodeTokens = tokenIds.length;
623
+ this.#state.stats.totalTimeMs = performance.now() - startTime;
624
+
625
+ return {
626
+ tokenIds,
627
+ stats: this.#state.stats,
628
+ };
629
+ } finally {
630
+ this._closeFinitenessFallbackWindow(opts);
631
+ resetActiveExecutionPlan(this.#state);
632
+ this.#state.isGenerating = false;
633
+ }
634
+ }
635
+
636
+
637
+ async prefillKVOnly(prompt, options = {}) {
638
+ if (!this.#state.isLoaded) throw new Error('Model not loaded');
639
+ if (this.#state.isGenerating && options.__internalGenerate !== true) {
640
+ throw new Error('Generation already in progress');
641
+ }
642
+ this._resetDecodeRuntimeState();
643
+ this.#state.stats.gpuTimePrefillMs = undefined;
644
+ const opts = resolvePrefillOptions(this.#state, options);
645
+ const inputIds = this._resolvePromptTokenIds(prompt, opts.useChatTemplate, 'prefillKVOnly');
509
646
  if (opts.debug) {
510
647
  log.debug('Pipeline', `PrefillKVOnly: ${inputIds.length} tokens`);
511
648
  }
@@ -563,14 +700,13 @@ export class PipelineGenerator {
563
700
 
564
701
  async prefillWithEmbedding(prompt, options = {}) {
565
702
  if (!this.#state.isLoaded) throw new Error('Model not loaded');
566
- resetActiveExecutionPlan(this.#state);
703
+ if (this.#state.isGenerating && options.__internalGenerate !== true) {
704
+ throw new Error('Generation already in progress');
705
+ }
706
+ this._resetDecodeRuntimeState();
567
707
  this.#state.stats.gpuTimePrefillMs = undefined;
568
708
  const opts = resolvePrefillEmbeddingOptions(this.#state, options);
569
-
570
- const processedPrompt = resolvePromptInput(this.#state, prompt, opts.useChatTemplate, 'prefillWithEmbedding');
571
-
572
- const inputIds = this.#state.tokenizer.encode(processedPrompt);
573
- this._assertTokenIdsInRange(inputIds, 'prefillWithEmbedding.encode');
709
+ const inputIds = this._resolvePromptTokenIds(prompt, opts.useChatTemplate, 'prefillWithEmbedding');
574
710
  if (opts.debug) {
575
711
  log.debug('Pipeline', `PrefillWithEmbedding: ${inputIds.length} tokens (mode=${opts.embeddingMode})`);
576
712
  }
@@ -658,19 +794,13 @@ export class PipelineGenerator {
658
794
 
659
795
  async prefillWithLogits(prompt, options = {}) {
660
796
  if (!this.#state.isLoaded) throw new Error('Model not loaded');
661
- resetActiveExecutionPlan(this.#state);
797
+ if (this.#state.isGenerating && options.__internalGenerate !== true) {
798
+ throw new Error('Generation already in progress');
799
+ }
800
+ this._resetDecodeRuntimeState();
662
801
  this.#state.stats.gpuTimePrefillMs = undefined;
663
802
  const opts = resolvePrefillOptions(this.#state, options);
664
-
665
- const processedPrompt = resolvePromptInput(this.#state, prompt, opts.useChatTemplate, 'prefillWithLogits');
666
-
667
- const inputIds = this.#state.tokenizer.encode(processedPrompt);
668
- this._assertTokenIdsInRange(inputIds, 'prefillWithLogits.encode');
669
- if (opts.debug) {
670
- log.debug('Pipeline', `PrefillWithLogits: ${inputIds.length} tokens`);
671
- }
672
-
673
- const logits = await this._prefill(inputIds, opts);
803
+ const { inputIds, logits } = await this._prefillPromptToLogits(prompt, opts, 'prefillWithLogits');
674
804
 
675
805
  const snapshot = this.#state.kvCache?.clone();
676
806
  if (!snapshot) {
@@ -792,6 +922,7 @@ export class PipelineGenerator {
792
922
  stopSequenceStart,
793
923
  decodeToken,
794
924
  logBatchPath = false,
925
+ emitMode = 'text',
795
926
  } = runtime;
796
927
 
797
928
  let tokensGenerated = 1;
@@ -821,6 +952,9 @@ export class PipelineGenerator {
821
952
  }
822
953
  const readbackInterval = executionPlan.readbackInterval;
823
954
  const intervalBatches = readbackInterval == null ? 1 : readbackInterval;
955
+ const padTokenId = this.#state.tokenizer?.getSpecialTokens?.()?.pad;
956
+
957
+ const decodeSingleTokenViaLogits = async () => this._decodeNextTokenViaLogits(generatedIds, opts);
824
958
 
825
959
  if (logBatchPath && useBatchPath) {
826
960
  log.debug(
@@ -846,10 +980,16 @@ export class PipelineGenerator {
846
980
  for (const tokenId of batchResult.tokens) {
847
981
  generatedIds.push(tokenId);
848
982
  tokensGenerated++;
849
- const tokenText = decodeToken(tokenId);
850
- yield tokenText;
851
- if (options.onToken) options.onToken(tokenId, tokenText);
852
- batchTokens.push({ id: tokenId, text: tokenText });
983
+ if (emitMode === 'token') {
984
+ yield tokenId;
985
+ if (options.onToken) options.onToken(tokenId, '');
986
+ batchTokens.push({ id: tokenId, text: '' });
987
+ } else {
988
+ const tokenText = decodeToken(tokenId);
989
+ yield tokenText;
990
+ if (options.onToken) options.onToken(tokenId, tokenText);
991
+ batchTokens.push({ id: tokenId, text: tokenText });
992
+ }
853
993
  if (batchTokens.length === executionPlan.batchSize) {
854
994
  if (options.onBatch) options.onBatch(batchTokens);
855
995
  batchTokens = [];
@@ -866,7 +1006,7 @@ export class PipelineGenerator {
866
1006
  useBatchPath = false;
867
1007
  let nextToken;
868
1008
  try {
869
- nextToken = await this._decodeStep(generatedIds, opts);
1009
+ nextToken = await decodeSingleTokenViaLogits();
870
1010
  } catch (singleTokenError) {
871
1011
  if (shouldRetryWithFinitenessFallback(singleTokenError)) {
872
1012
  log.warn('Pipeline', `FinitenessGuard caught NaN/Inf at batch step ${tokensGenerated}. Truncating KV cache and retrying token with F32 precision.`);
@@ -881,9 +1021,14 @@ export class PipelineGenerator {
881
1021
  }
882
1022
  generatedIds.push(nextToken);
883
1023
  tokensGenerated++;
884
- const tokenText = decodeToken(nextToken);
885
- yield tokenText;
886
- if (options.onToken) options.onToken(nextToken, tokenText);
1024
+ if (emitMode === 'token') {
1025
+ yield nextToken;
1026
+ if (options.onToken) options.onToken(nextToken, '');
1027
+ } else {
1028
+ const tokenText = decodeToken(nextToken);
1029
+ yield tokenText;
1030
+ if (options.onToken) options.onToken(nextToken, tokenText);
1031
+ }
887
1032
  this._consumeFinitenessFallbackToken(opts);
888
1033
  if (isStopToken(nextToken, stopTokenIds, eosToken)) break;
889
1034
  }
@@ -891,7 +1036,7 @@ export class PipelineGenerator {
891
1036
  const tokenStart = performance.now();
892
1037
  let nextToken;
893
1038
  try {
894
- nextToken = await this._decodeStep(generatedIds, opts);
1039
+ nextToken = await decodeSingleTokenViaLogits();
895
1040
  } catch (error) {
896
1041
  if (shouldRetryWithFinitenessFallback(error)) {
897
1042
  log.warn('Pipeline', `FinitenessGuard caught NaN/Inf at step ${tokensGenerated}. Truncating KV cache and retrying token with F32 precision.`);
@@ -907,9 +1052,14 @@ export class PipelineGenerator {
907
1052
  const tokenTime = performance.now() - tokenStart;
908
1053
  generatedIds.push(nextToken);
909
1054
  tokensGenerated++;
910
- const tokenText = decodeToken(nextToken);
911
- yield tokenText;
912
- if (options.onToken) options.onToken(nextToken, tokenText);
1055
+ const tokenText = emitMode === 'token' ? '' : decodeToken(nextToken);
1056
+ if (emitMode === 'token') {
1057
+ yield nextToken;
1058
+ if (options.onToken) options.onToken(nextToken, '');
1059
+ } else {
1060
+ yield tokenText;
1061
+ if (options.onToken) options.onToken(nextToken, tokenText);
1062
+ }
913
1063
  this._consumeFinitenessFallbackToken(opts);
914
1064
 
915
1065
  if (opts.debug || opts.benchmark) {
@@ -947,6 +1097,13 @@ export class PipelineGenerator {
947
1097
  if (startPos === 0 && hasLinearAttentionLayers(config.layerTypes)) {
948
1098
  this.#state.linearAttentionRuntime = resetLinearAttentionRuntime(this.#state.linearAttentionRuntime);
949
1099
  }
1100
+ if (startPos === 0) {
1101
+ for (const [, convState] of this.#state.convLayerStates) {
1102
+ if (convState.convStateGPU && convState.hiddenSize && convState.kernelSize) {
1103
+ uploadData(convState.convStateGPU, new Float32Array(convState.hiddenSize * (convState.kernelSize - 1)));
1104
+ }
1105
+ }
1106
+ }
950
1107
 
951
1108
  const embedBufferRaw = this.#state.weights.get('embed');
952
1109
  if (!(embedBufferRaw instanceof GPUBuffer) && !isWeightBuffer(embedBufferRaw) && !isCpuWeightBuffer(embedBufferRaw) && !(embedBufferRaw instanceof Float32Array)) {
@@ -1296,18 +1453,15 @@ export class PipelineGenerator {
1296
1453
 
1297
1454
  async decodeStepLogits(currentIds, options = {}) {
1298
1455
  if (!this.#state.isLoaded) throw new Error('Model not loaded');
1299
- if (this.#state.isGenerating) throw new Error('Generation already in progress');
1456
+ if (this.#state.isGenerating && options.__internalGenerate !== true) {
1457
+ throw new Error('Generation already in progress');
1458
+ }
1300
1459
  resetActiveExecutionPlan(this.#state);
1301
1460
 
1302
1461
  validateCallTimeOptions(options);
1303
1462
 
1304
1463
  const opts = this._resolveStepOptions(options);
1305
- const debugCheckBuffer = this.#state.debug
1306
- ? (buffer, label, numTokens, expectedDim) =>
1307
- debugCheckBufferHelper(this.#state, buffer, label, numTokens, expectedDim)
1308
- : undefined;
1309
-
1310
- return decodeStepLogits(this.#state, currentIds, opts, this._getDecodeHelpers(debugCheckBuffer));
1464
+ return this._decodeStepToLogits(currentIds, opts);
1311
1465
  }
1312
1466
 
1313
1467
  async advanceWithToken(tokenId, options = {}) {