@simulatte/doppler 0.1.7 → 0.1.8

This diff shows the content changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (88)
  1. package/CHANGELOG.md +19 -0
  2. package/package.json +21 -36
  3. package/src/browser/browser-converter.js +5 -0
  4. package/src/client/doppler-registry.json +1 -17
  5. package/src/config/kernel-path-loader.d.ts +5 -0
  6. package/src/config/kernel-path-loader.js +13 -0
  7. package/src/config/kernels/registry.json +74 -0
  8. package/src/config/loader.js +3 -0
  9. package/src/config/merge-contract-check.js +7 -0
  10. package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
  11. package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
  12. package/src/config/presets/kernel-paths/registry.json +14 -0
  13. package/src/config/presets/models/gemma2.json +2 -1
  14. package/src/config/presets/models/gemma3.json +2 -0
  15. package/src/config/presets/models/qwen3.json +4 -3
  16. package/src/config/presets/models/qwen3_5.json +16 -0
  17. package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
  18. package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
  19. package/src/config/schema/conversion.schema.d.ts +1 -0
  20. package/src/config/schema/manifest.schema.d.ts +1 -1
  21. package/src/config/schema/manifest.schema.js +1 -1
  22. package/src/config/schema/storage.schema.js +1 -1
  23. package/src/converter/conversion-plan.js +10 -2
  24. package/src/converter/core.js +2 -0
  25. package/src/converter/manifest-inference.js +12 -22
  26. package/src/converter/parsers/transformer.js +4 -0
  27. package/src/converter/quantization-info.js +5 -1
  28. package/src/converter/quantizer.js +19 -12
  29. package/src/converter/rope-config.js +8 -6
  30. package/src/converter/tokenizer-utils.d.ts +1 -0
  31. package/src/converter/tokenizer-utils.js +4 -1
  32. package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
  33. package/src/distribution/shard-delivery.js +6 -1
  34. package/src/formats/rdrr/parsing.d.ts +4 -0
  35. package/src/formats/rdrr/parsing.js +14 -1
  36. package/src/gpu/kernels/index.d.ts +8 -0
  37. package/src/gpu/kernels/index.js +6 -0
  38. package/src/gpu/kernels/matmul-selection.js +47 -4
  39. package/src/gpu/kernels/matmul.d.ts +2 -0
  40. package/src/gpu/kernels/matmul.js +1 -1
  41. package/src/gpu/kernels/rmsnorm.js +9 -2
  42. package/src/gpu/kernels/split_qg.d.ts +50 -0
  43. package/src/gpu/kernels/split_qg.js +46 -0
  44. package/src/gpu/kernels/split_qg.wgsl +58 -0
  45. package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
  46. package/src/gpu/weight-buffer.d.ts +1 -1
  47. package/src/gpu/weight-buffer.js +1 -1
  48. package/src/inference/browser-harness.d.ts +2 -0
  49. package/src/inference/browser-harness.js +20 -1
  50. package/src/inference/pipelines/diffusion/helpers.js +3 -0
  51. package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
  52. package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
  53. package/src/inference/pipelines/text/attention/output-projection.js +8 -0
  54. package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
  55. package/src/inference/pipelines/text/attention/projections.js +41 -11
  56. package/src/inference/pipelines/text/attention/record.js +15 -6
  57. package/src/inference/pipelines/text/attention/run.js +50 -6
  58. package/src/inference/pipelines/text/config.js +14 -0
  59. package/src/inference/pipelines/text/execution-plan.js +5 -4
  60. package/src/inference/pipelines/text/generator-runtime.js +5 -0
  61. package/src/inference/pipelines/text/generator-steps.d.ts +6 -0
  62. package/src/inference/pipelines/text/generator-steps.js +43 -15
  63. package/src/inference/pipelines/text/generator.js +50 -17
  64. package/src/inference/pipelines/text/init.d.ts +13 -0
  65. package/src/inference/pipelines/text/init.js +16 -5
  66. package/src/inference/pipelines/text/layer.js +1 -0
  67. package/src/inference/pipelines/text/linear-attention.d.ts +5 -0
  68. package/src/inference/pipelines/text/linear-attention.js +33 -3
  69. package/src/inference/pipelines/text/logits/gpu.js +2 -2
  70. package/src/inference/pipelines/text/logits/index.d.ts +6 -1
  71. package/src/inference/pipelines/text/logits/index.js +3 -1
  72. package/src/inference/pipelines/text/model-load.js +3 -0
  73. package/src/inference/pipelines/text/sampling.js +52 -6
  74. package/src/inference/test-harness.js +2 -2
  75. package/src/loader/final-weights-loader.js +2 -0
  76. package/src/loader/shard-cache.js +3 -2
  77. package/src/loader/tensors/tensor-loader.js +6 -1
  78. package/src/rules/inference/dtype.rules.json +5 -0
  79. package/src/rules/inference/kernel-path.rules.json +2 -2
  80. package/src/rules/kernels/split-qg.rules.json +6 -0
  81. package/src/rules/rule-registry.js +2 -0
  82. package/src/storage/downloader.js +2 -1
  83. package/src/storage/shard-manager.js +4 -3
  84. package/src/tooling/conversion-config-materializer.js +3 -5
  85. package/src/tooling/node-converter.js +3 -0
  86. package/src/tooling/node-source-runtime.js +36 -0
  87. package/src/types/model.d.ts +5 -0
  88. package/tools/doppler-cli.js +6 -1
@@ -122,6 +122,20 @@ function resolveTokenText(tokenizer, tokenIds, fallbackText = '?', renderTokenTe
   return fallbackText;
 }
 
+export function shouldRetryWithFinitenessFallback(error) {
+  if (error?.name === 'FinitenessError') {
+    return true;
+  }
+  const message = typeof error?.message === 'string'
+    ? error.message
+    : (typeof error === 'string' ? error : '');
+  if (!message.startsWith('[Sampling]')) {
+    return false;
+  }
+  return message.includes('no finite candidate logits after masking the pad token')
+    || message.includes('Softmax produced no finite candidate probabilities');
+}
+
 export class PipelineGenerator {
 
   #state;
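The helper above centralizes retry classification: a typed `FinitenessError` and the string-tagged `[Sampling]` errors introduced in sampling.js (further down in this diff) both route to the same F32 fallback. A quick sketch of the three cases it distinguishes; the import path is illustrative:

```js
// Sketch: both error shapes trigger the fallback; an unrelated [Sampling]
// message does not. Import path is illustrative.
import { shouldRetryWithFinitenessFallback } from './generator.js';

const typed = Object.assign(new Error('NaN in logits'), { name: 'FinitenessError' });
const tagged = new Error(
  '[Sampling] Logits has no finite candidate logits after masking the pad token.'
);
const unrelated = new Error('[Sampling] temperature must be a number');

shouldRetryWithFinitenessFallback(typed);     // true  (name match)
shouldRetryWithFinitenessFallback(tagged);    // true  ([Sampling] tag + known message)
shouldRetryWithFinitenessFallback(unrelated); // false (tag without a known message)
```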
@@ -351,7 +365,7 @@ export class PipelineGenerator {
     try {
       prefillLogits = await this._prefill(inputIds, opts);
     } catch (error) {
-      if (error.name === 'FinitenessError') {
+      if (shouldRetryWithFinitenessFallback(error)) {
         log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefill. Retrying with F32 precision.`);
         prefillLogits = await this._retryWithFinitenessFallback(
           opts,
@@ -395,13 +409,34 @@ export class PipelineGenerator {
       log.debug('Pipeline', `After rep penalty top-5: ${topAfterPenalty.map(t => `"${t.text}"(${(t.prob * 100).toFixed(1)}%)`).join(', ')}`);
     }
 
-    const firstToken = sample(prefillLogits, {
-      temperature: opts.temperature,
-      topP: opts.topP,
-      topK: opts.topK,
-      padTokenId,
-      seed: opts.seed,
-    });
+    let firstToken;
+    try {
+      firstToken = sample(prefillLogits, {
+        temperature: opts.temperature,
+        topP: opts.topP,
+        topK: opts.topK,
+        padTokenId,
+        seed: opts.seed,
+      });
+    } catch (error) {
+      if (!shouldRetryWithFinitenessFallback(error)) {
+        throw error;
+      }
+      log.warn('Pipeline', 'FinitenessGuard caught non-finite prefill logits at sampling. Retrying with F32 precision.');
+      prefillLogits = await this._retryWithFinitenessFallback(
+        opts,
+        'prefill-sample',
+        () => this._prefill(inputIds, opts)
+      );
+      applyRepetitionPenalty(prefillLogits, generatedIds, opts.repetitionPenalty);
+      firstToken = sample(prefillLogits, {
+        temperature: opts.temperature,
+        topP: opts.topP,
+        topK: opts.topK,
+        padTokenId,
+        seed: opts.seed,
+      });
+    }
 
     if (opts.debug) {
       const firstTokenText = resolveTokenText(this.#state.tokenizer, [firstToken], `[${firstToken}]`, (tokens) => this.#state.tokenizer?.decode?.(tokens, true, false));
@@ -479,7 +514,7 @@ export class PipelineGenerator {
     try {
       prefillResult = await this._prefillToHidden(inputIds, opts);
     } catch (error) {
-      if (error.name === 'FinitenessError') {
+      if (shouldRetryWithFinitenessFallback(error)) {
         log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefillKVOnly. Retrying with F32 precision.`);
         prefillResult = await this._retryWithFinitenessFallback(
           opts,
@@ -544,7 +579,7 @@ export class PipelineGenerator {
     try {
       prefillResult = await this._prefillToHidden(inputIds, opts);
     } catch (error) {
-      if (error.name === 'FinitenessError') {
+      if (shouldRetryWithFinitenessFallback(error)) {
         log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefillWithEmbedding. Retrying with F32 precision.`);
         prefillResult = await this._retryWithFinitenessFallback(
           opts,
@@ -833,7 +868,7 @@ export class PipelineGenerator {
     try {
       nextToken = await this._decodeStep(generatedIds, opts);
     } catch (singleTokenError) {
-      if (singleTokenError.name === 'FinitenessError') {
+      if (shouldRetryWithFinitenessFallback(singleTokenError)) {
         log.warn('Pipeline', `FinitenessGuard caught NaN/Inf at batch step ${tokensGenerated}. Truncating KV cache and retrying token with F32 precision.`);
         nextToken = await this._retryDecodeStepWithFinitenessWindow(
           generatedIds,
@@ -858,7 +893,7 @@ export class PipelineGenerator {
     try {
       nextToken = await this._decodeStep(generatedIds, opts);
     } catch (error) {
-      if (error.name === 'FinitenessError') {
+      if (shouldRetryWithFinitenessFallback(error)) {
         log.warn('Pipeline', `FinitenessGuard caught NaN/Inf at step ${tokensGenerated}. Truncating KV cache and retrying token with F32 precision.`);
         nextToken = await this._retryDecodeStepWithFinitenessWindow(
           generatedIds,
@@ -918,11 +953,9 @@ export class PipelineGenerator {
       throw new Error('Embed buffer not found or not a supported buffer type');
     }
     const embedBuffer = isWeightBuffer(embedBufferRaw) ? embedBufferRaw.buffer : embedBufferRaw;
-    const embedDtype = isWeightBuffer(embedBufferRaw)
-      ? getWeightDtype(embedBufferRaw)
-      : isCpuWeightBuffer(embedBufferRaw)
-        ? embedBufferRaw.dtype
-        : null;
+    const embedDtype = isCpuWeightBuffer(embedBufferRaw)
+      ? embedBufferRaw.dtype
+      : getWeightDtype(embedBufferRaw);
     if (opts.debug) {
       const embedSize = embedBuffer instanceof GPUBuffer ? embedBuffer.size : 'N/A';
       log.debug('Pipeline', `Embed buffer: type=${embedBuffer?.constructor?.name}, size=${embedSize}, dtype=${embedDtype}`);
@@ -190,6 +190,12 @@ export interface WeightLoadResult {
   layerRouterWeights: Map<number, RouterWeights>;
 }
 
+export interface ResolvedQ4KConfig {
+  useFusedQ4K: boolean;
+  q4kLayout: 'row' | 'col' | null;
+  keepF32Weights: boolean;
+}
+
 /** Options for loadWeights */
 export interface LoadWeightsOptions {
   storageContext?: PipelineStorageContext;
@@ -211,6 +217,13 @@ export function loadWeights(
   options?: LoadWeightsOptions
 ): Promise<WeightLoadResult>;
 
+export function resolveQ4KConfig(
+  manifest: Manifest,
+  kernelPath?: KernelPathSchema | null,
+  kernelPathSource?: KernelPathSource,
+  keepF32Weights?: boolean
+): ResolvedQ4KConfig;
+
 /**
  * Apply Gemma chat template to a prompt.
  */
@@ -11,7 +11,7 @@ import { getDopplerLoader } from '../../../loader/doppler-loader.js';
 import { log, setGPUDevice, trace as debugTrace } from '../../../debug/index.js';
 import { getRuntimeConfig } from '../../../config/runtime.js';
 import { PAGED_LAYOUT_SEQ_LEN_THRESHOLD } from '../../../config/schema/index.js';
-import { isKernelPathFusedQ4K } from '../../../config/kernel-path-loader.js';
+import { isKernelPathFusedQ4K, kernelPathRequiresF32MatmulWeights } from '../../../config/kernel-path-loader.js';
 import { createWeightBuffer, getWeightDtype, isWeightBuffer } from '../../../gpu/weight-buffer.js';
 import { selectRuleValue } from '../../../rules/rule-registry.js';
 import {
@@ -128,7 +128,7 @@ function createRemoteStorageContext(baseUrl, manifest) {
 }
 
 
-function resolveQ4KConfig(
+export function resolveQ4KConfig(
   manifest,
   kernelPath,
   kernelPathSource = 'none',
@@ -150,18 +150,23 @@ function resolveQ4KConfig(
     );
   }
   let useFused = kernelPath ? isKernelPathFusedQ4K(kernelPath) : hasSubgroups;
+  const kernelPathKeepsF32Weights = kernelPathRequiresF32MatmulWeights(kernelPath);
   if (q4kLayout === 'col') {
     useFused = false;
   }
+  const resolvedKeepF32Weights = keepF32Weights || kernelPathKeepsF32Weights;
 
   const pathLabel = kernelPath?.id ?? 'auto';
   const layoutLabel = q4kLayout ?? 'none';
-  debugTrace.loader(`Q4K config: fused=${useFused}, kernelPath=${pathLabel}, source=${kernelPathSource}, layout=${layoutLabel}, subgroups=${hasSubgroups}`);
+  debugTrace.loader(
+    `Q4K config: fused=${useFused}, kernelPath=${pathLabel}, source=${kernelPathSource}, ` +
+    `layout=${layoutLabel}, keepF32Weights=${resolvedKeepF32Weights}, subgroups=${hasSubgroups}`
+  );
 
   return {
     useFusedQ4K: useFused,
     q4kLayout,
-    keepF32Weights,
+    keepF32Weights: resolvedKeepF32Weights,
   };
 }
 
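The important behavioral change in `resolveQ4KConfig` is that `keepF32Weights` is no longer a pass-through: the caller's flag is OR-ed with a per-kernel-path requirement. A self-contained sketch of that composition; the preset field name (`requiresF32MatmulWeights`) is an assumption, the real check lives in `config/kernel-path-loader.js`:

```js
// Hedged sketch of the keepF32Weights composition above. The preset field
// name is an assumption; only the OR logic mirrors the diff.
function kernelPathRequiresF32MatmulWeights(kernelPath) {
  return kernelPath?.requiresF32MatmulWeights === true;
}

function resolveKeepF32Weights(callerFlag, kernelPath) {
  return callerFlag || kernelPathRequiresF32MatmulWeights(kernelPath);
}

// A caller that did not ask for F32 weights still gets them when the kernel
// path (e.g. an "-f32w-" preset) demands F32 matmul inputs:
resolveKeepF32Weights(false, { requiresF32MatmulWeights: true }); // true
resolveKeepF32Weights(false, null);                               // false
```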
@@ -502,6 +507,12 @@ export function createKVCache(modelConfig, useGPU, debug = false, runtimeConfig)
     cacheLayout = 'paged';
     layoutSource = 'threshold';
   }
+  if (forceContiguousKVCache && cacheLayout === 'paged') {
+    throw new Error(
+      'Paged KV cache layout is not supported for models with full-attention layers. ' +
+      'Set runtime.inference.kvcache.layout to "contiguous" instead.'
+    );
+  }
   if (debug && cacheLayout !== runtimeKV.layout) {
     log.debug('Pipeline', `KV cache layout override: ${runtimeKV.layout} -> ${cacheLayout} (${layoutSource})`);
   }
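This guard makes an explicit paged-layout request fail fast instead of being silently overridden. The error message spells out the escape hatch; the config shape below is inferred from its `runtime.inference.kvcache.layout` path and is otherwise an assumption:

```js
// Sketch of the runtime override the new error message points at; the
// nesting is inferred from "runtime.inference.kvcache.layout".
const runtimeConfig = {
  inference: {
    kvcache: {
      layout: 'contiguous', // 'paged' now throws for full-attention layer patterns
    },
  },
};
```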
@@ -599,7 +610,7 @@ export function createKVCache(modelConfig, useGPU, debug = false, runtimeConfig)
 
   if (debug) {
     if (forceContiguousKVCache && modelConfig.layerTypes) {
-      log.debug('Pipeline', 'Layer pattern includes full-attention layers; forcing contiguous KV cache.');
+      log.debug('Pipeline', 'Layer pattern includes full-attention layers; paged layout blocked, contiguous enforced.');
     }
     const isSliding = kvCache instanceof SlidingWindowKVCache;
     log.debug('Pipeline', `KV cache: type=${kvCache?.constructor?.name || 'unknown'}, kvDtype=${kvCache.kvDtype}, layout=${kvCache.layout}, maxSeqLen=${kvCache.maxSeqLen}, windowSize=${isSliding ? kvCache.windowSize : null}`);
@@ -276,6 +276,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
       : (ropeFreqsSin),
     kvCache: ((kvCache)),
     stats: context.stats,
+    debugProbes: context.debugProbes,
     linearRuntime: context.linearAttentionRuntime ?? null,
   };
 
@@ -84,6 +84,11 @@ export declare function inferLinearNormMode(
   }
 ): LinearNormMode | null;
 
+export declare function applyLinearNormWeightOffset(
+  values: Float32Array,
+  rmsNormWeightOffset: boolean
+): Float32Array;
+
 export declare function resetLinearAttentionRuntime(
   runtime: LinearAttentionRuntime | null | undefined
 ): LinearAttentionRuntime;
@@ -5,6 +5,8 @@ import { log } from '../../../debug/index.js';
 import { decodeReadback } from './debug-utils/index.js';
 import { runLinearAttentionCoreGPU } from '../../../gpu/kernels/linear-attention-core.js';
 import { runProbes } from './probes.js';
+import { QK_K, Q4K_BLOCK_BYTES } from '../../../config/schema/index.js';
+import { dequantizeQ4KM } from '../../../converter/quantizer.js';
 
 const LINEAR_RUNTIME_SCHEMA_VERSION = 1;
 const QK_L2NORM_EPS = 1e-6;
@@ -34,6 +36,15 @@ function bytesFromDtype(dtype) {
   return 4;
 }
 
+export function applyLinearNormWeightOffset(values, rmsNormWeightOffset) {
+  if (!(values instanceof Float32Array)) {
+    throw new Error('applyLinearNormWeightOffset requires Float32Array input.');
+  }
+  // Qwen linear-attention output norm uses direct weights even when surrounding
+  // transformer RMSNorm sites use the Gemma-style (1 + weight) formula.
+  return values;
+}
+
 function cloneLayerRuntimeState(layerState) {
   return {
     layerIdx: layerState.layerIdx,
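`applyLinearNormWeightOffset` is deliberately an identity today: it validates its input and documents that Qwen's linear-attention output norm consumes weights directly. For contrast, a Gemma-style site would fold the `(1 + weight)` offset into the values, roughly like this (illustrative only, not part of the diff):

```js
// Illustrative contrast: the Gemma-style (1 + weight) fold that
// applyLinearNormWeightOffset intentionally does NOT perform.
function applyGemmaStyleOffset(values) {
  const out = new Float32Array(values.length);
  for (let i = 0; i < values.length; i++) {
    out[i] = 1 + values[i];
  }
  return out;
}
```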
@@ -283,9 +294,27 @@ async function readWeightAsF32(weight, expectedElements, label) {
   if (!elementCount && isWeightBuffer(weight) && Array.isArray(weight.shape) && weight.shape.length > 0) {
     elementCount = weight.shape.reduce((total, dim) => total * Math.max(1, Math.trunc(Number(dim) || 0)), 1);
   }
+  const isQ4K = sourceDtype === 'q4k' || sourceDtype === 'q4_k_m' || sourceDtype === 'q4_k';
   if (!elementCount) {
-    const inferredBytes = sourceDtype === 'f16' || sourceDtype === 'bf16' ? 2 : 4;
-    elementCount = Math.trunc(sourceBuffer.size / inferredBytes);
+    if (isQ4K) {
+      elementCount = Math.trunc(sourceBuffer.size / Q4K_BLOCK_BYTES) * QK_K;
+    } else {
+      const inferredBytes = sourceDtype === 'f16' || sourceDtype === 'bf16' ? 2 : 4;
+      elementCount = Math.trunc(sourceBuffer.size / inferredBytes);
+    }
+  }
+
+  if (isQ4K) {
+    const numBlocks = Math.ceil(elementCount / QK_K);
+    const q4kBytes = numBlocks * Q4K_BLOCK_BYTES;
+    const raw = await readBuffer(sourceBuffer, q4kBytes);
+    const decoded = dequantizeQ4KM(new Uint8Array(raw), numBlocks, [elementCount]);
+    if (expectedElements != null && decoded.length !== expectedElements) {
+      throw new Error(
+        `Weight "${label}" Q4K decoded length ${decoded.length}, expected ${expectedElements}.`
+      );
+    }
+    return decoded;
   }
 
   if (!sourceDtype) {
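The Q4K sizing above leans on Q4_K's fixed super-block geometry. Assuming the GGML-standard constants (256 weights per super-block, 144 bytes per encoded block), the arithmetic runs in both directions:

```js
// Sizing sketch, assuming GGML's Q4_K constants (QK_K = 256 weights per
// super-block; 144 bytes per block: two f16 scales + 12 bytes of 6-bit
// sub-scales + 128 bytes of packed 4-bit quants).
const QK_K = 256;
const Q4K_BLOCK_BYTES = 144;

// Buffer size -> element count (the `!elementCount` branch):
const bufferBytes = 147456;
const inferredElements = Math.trunc(bufferBytes / Q4K_BLOCK_BYTES) * QK_K; // 262144

// Element count -> bytes to read back before dequantizing:
const q4kBytes = Math.ceil(inferredElements / QK_K) * Q4K_BLOCK_BYTES; // 147456
```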
@@ -454,6 +483,7 @@ async function createLayerRuntimeState(
     expectedNormElements,
     `L${layerIdx}.linear_attn.norm.weight`
   );
+  const runtimeNorm = applyLinearNormWeightOffset(norm, config.rmsNormWeightOffset === true);
 
   const aNegExp = new Float32Array(aLog.length);
   for (let i = 0; i < aLog.length; i++) {
@@ -490,7 +520,7 @@
     convWeight,
     dtBias,
     aNegExp,
-    normWeight: norm,
+    normWeight: runtimeNorm,
     convState,
     recurrentState,
     convWeightGPU: null,
@@ -304,7 +304,7 @@ export async function computeLogitsGPU(
 
   const logitsTensor = await runMatmul(normedTensor, lmHeadBuffer, numTokens, matmulVocabSize, hiddenSize, {
     transposeB: 'auto',
-    role: forceStableF32Logits ? undefined : 'lm_head',
+    role: 'lm_head',
     kernelPath: config.kernelPath ?? null,
   });
 
@@ -391,7 +391,7 @@ export async function recordLogitsGPU(
   // Record matmul (no submit)
   const logitsTensor = await recordMatmul(recorder, normedTensor, lmHeadBuffer, numTokens, matmulVocabSize, hiddenSize, {
     transposeB: 'auto',
-    role: forceStableF32Logits ? undefined : 'lm_head',
+    role: 'lm_head',
     kernelPath: config.kernelPath ?? null,
   });
 
@@ -25,6 +25,10 @@ export { computeLogitsGPU, recordLogitsGPU, computeChunkedLogitsGPU, resolveCpuW
 // Re-export utilities
 export { extractLastPositionLogits, finalizeLogits } from './utils.js';
 
+export interface ComputeLogitsOptions {
+  lastPositionOnly?: boolean;
+}
+
 /**
  * Compute logits from hidden states.
  *
@@ -53,5 +57,6 @@ export function computeLogits(
   debugFlags?: LogitsDebugFlags,
   getNormWeightBuffer?: (weight: GPUBuffer | Float32Array | ArrayBuffer, label: string) => GPUBuffer,
   debugCheckBuffer?: (buffer: GPUBuffer, label: string, numTokens: number, expectedDim?: number) => Promise<void>,
-  debugProbes?: ProbeConfigSchema[] | null
+  debugProbes?: ProbeConfigSchema[] | null,
+  options?: ComputeLogitsOptions
 ): Promise<Float32Array>;
@@ -253,6 +253,7 @@ export async function computeLogits(
 
   const lastPositionOnly = options?.lastPositionOnly === true && numTokens > 1;
   const matmulRows = lastPositionOnly ? 1 : numTokens;
+  const matmulPhaseOverride = lastPositionOnly ? 'prefill' : null;
   let matmulInputTensor = normedTensor;
   let matmulInputOwned = false;
   if (lastPositionOnly) {
@@ -270,7 +271,8 @@ export async function computeLogits(
   // HuggingFace models store lm_head as [vocabSize, hiddenSize], so transposeB=true
   const logitsTensor = await runMatmul(matmulInputTensor, lmHeadBuffer, matmulRows, matmulVocabSize, hiddenSize, {
     transposeB: 'auto',
-    role: (forceStableF32Logits || lastPositionOnly) ? undefined : 'lm_head',
+    role: 'lm_head',
+    phaseOverride: matmulPhaseOverride,
     kernelPath: config.kernelPath ?? null,
   });
   await runProbes('logits', logitsTensor.buffer, {
@@ -234,6 +234,9 @@ function buildManifestDecodeLoopRuntimePatch(manifest) {
 
 export function applyModelBatchingRuntimeDefaults(runtimeConfig, manifest, modelConfig) {
   void modelConfig;
+  if (manifest?.inference?.schema === 'doppler.execution/v0') {
+    return runtimeConfig;
+  }
   const batching = runtimeConfig?.inference?.batching;
   const generation = runtimeConfig?.inference?.generation;
   const runtimeBatchingAtDefaults = isRuntimeBatchingAtGlobalDefaults(batching);
@@ -58,6 +58,30 @@ export function softmax(logits) {
   return exps;
 }
 
+function countFiniteCandidates(logits, padTokenId) {
+  let finiteCandidateCount = 0;
+  for (let i = 0; i < logits.length; i++) {
+    if (padTokenId != null && i === padTokenId) {
+      continue;
+    }
+    if (Number.isFinite(logits[i])) {
+      finiteCandidateCount += 1;
+    }
+  }
+  return finiteCandidateCount;
+}
+
+function assertFiniteSamplingCandidates(logits, padTokenId, label) {
+  const finiteCandidateCount = countFiniteCandidates(logits, padTokenId);
+  if (finiteCandidateCount > 0) {
+    return;
+  }
+  throw new Error(
+    `[Sampling] ${label} has no finite candidate logits after masking the pad token. ` +
+    'Upstream decode likely produced NaN/Inf or an all-masked distribution.'
+  );
+}
+
 
 export function sample(logits, opts) {
   const { temperature, topP, topK, decode, debug = false, padTokenId, seed } = opts;
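These module-private guards convert silent NaN propagation into the tagged `[Sampling]` errors that `shouldRetryWithFinitenessFallback` matches on. A quick sketch of the failure mode they catch (assuming access to the helper, e.g. from a test inside the module):

```js
// Sketch: all-NaN logits (with the pad token masked to -Infinity) now throw
// a tagged "[Sampling]" error instead of silently sampling token 0.
const logits = new Float32Array([NaN, NaN, -Infinity, NaN]);
const padTokenId = 2;

try {
  assertFiniteSamplingCandidates(logits, padTokenId, 'Logits');
} catch (error) {
  // "[Sampling] Logits has no finite candidate logits after masking the pad
  // token. Upstream decode likely produced NaN/Inf or an all-masked distribution."
  console.warn(error.message);
}
```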
@@ -66,16 +90,28 @@ export function sample(logits, opts) {
     logits[padTokenId] = -Infinity;
   }
 
+  assertFiniteSamplingCandidates(logits, padTokenId, 'Logits');
+
   // Greedy (argmax) when temperature = 0
   if (temperature === 0) {
-    let maxIdx = 0;
-    let maxVal = logits[0];
-    for (let i = 1; i < logits.length; i++) {
-      if (logits[i] > maxVal) {
-        maxVal = logits[i];
+    let maxIdx = -1;
+    let maxVal = -Infinity;
+    for (let i = 0; i < logits.length; i++) {
+      const value = logits[i];
+      if (!Number.isFinite(value)) {
+        continue;
+      }
+      if (value > maxVal) {
+        maxVal = value;
         maxIdx = i;
       }
     }
+    if (maxIdx < 0) {
+      throw new Error(
+        '[Sampling] Greedy sampling could not find a finite candidate logit. ' +
+        'Upstream decode likely produced NaN/Inf.'
+      );
+    }
     if (debug) {
       const text = decode?.([maxIdx]) ?? '?';
       trace.sample(`Greedy: id=${maxIdx} "${text}" logit=${maxVal.toFixed(4)}`);
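The rewritten greedy path matters because of how `NaN` comparisons behave: with the old code, `maxVal` initialized to `logits[0]` meant a leading `NaN` made every `logits[i] > maxVal` comparison false, so argmax silently returned token 0. A compact demonstration of the finite-aware scan:

```js
// Demonstration of the finite-aware argmax used above. -1 stands in for the
// "no finite candidate" case (the real code throws instead).
function finiteArgmax(logits) {
  let maxIdx = -1;
  let maxVal = -Infinity;
  for (let i = 0; i < logits.length; i++) {
    const value = logits[i];
    if (!Number.isFinite(value)) continue; // NaN/±Infinity never win
    if (value > maxVal) {
      maxVal = value;
      maxIdx = i;
    }
  }
  return maxIdx;
}

finiteArgmax([NaN, 1.5, 3.2, NaN]); // 2 (old code would have returned 0)
finiteArgmax([NaN, NaN]);           // -1
```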
@@ -96,7 +132,17 @@ export function sample(logits, opts) {
 
   let candidates = [];
   for (let i = 0; i < probs.length; i++) {
-    candidates.push({ token: i, prob: probs[i] });
+    const probability = probs[i];
+    if (!Number.isFinite(probability) || probability <= 0) {
+      continue;
+    }
+    candidates.push({ token: i, prob: probability });
+  }
+  if (candidates.length === 0) {
+    throw new Error(
+      '[Sampling] Softmax produced no finite candidate probabilities. ' +
+      'Upstream decode likely produced NaN/Inf logits.'
+    );
   }
   candidates.sort((a, b) => b.prob - a.prob);
 
@@ -1,7 +1,7 @@
 
 
 import { initDevice, getDevice, getKernelCapabilities } from '../gpu/device.js';
-import { parseManifest } from '../formats/rdrr/index.js';
+import { parseManifest, getExpectedShardHash } from '../formats/rdrr/index.js';
 import { createPipeline } from './pipelines/text.js';
 import { log as debugLog } from '../debug/index.js';
 import { getRuntimeConfig, setRuntimeConfig } from '../config/runtime.js';
@@ -168,7 +168,7 @@ export function createHttpShardLoader(baseUrl, manifest, log) {
     distributionConfig,
     algorithm,
     requiredEncoding,
-    expectedHash: shard.hash ?? null,
+    expectedHash: getExpectedShardHash(shard, algorithm) || null,
     expectedSize: Number.isFinite(shard.size) ? Math.floor(shard.size) : null,
     expectedManifestVersionSet: manifestVersionSet,
     writeToStore: false,
@@ -36,6 +36,8 @@ function isLikelyFinalNormName(name) {
   return (
     lower === 'norm.weight' ||
     lower.includes('model.norm.weight') ||
+    lower.includes('language_model.norm.weight') ||
+    lower.includes('model.language_model.norm.weight') ||
     lower.includes('embedding_norm.weight') ||
     lower.includes('model.embedding_norm.weight') ||
     lower.includes('final_layernorm.weight') ||
@@ -5,6 +5,7 @@ import {
   computeHash,
   getStorageBackendType,
 } from '../storage/shard-manager.js';
+import { getExpectedShardHash } from '../formats/rdrr/index.js';
 import { formatBytes } from '../storage/quota.js';
 import { log, trace as debugTrace } from '../debug/index.js';
 import { getRuntimeConfig } from '../config/runtime.js';
@@ -484,11 +485,11 @@ export class ShardCache {
     // Verify hash if enabled
     if (this.#verifyHashes && this.#manifest) {
       const shardInfo = this.#manifest.shards?.[shardIndex];
-      const expectedHash = shardInfo?.hash;
+      const algorithm = shardInfo?.hashAlgorithm ?? this.#manifest.hashAlgorithm;
+      const expectedHash = getExpectedShardHash(shardInfo, algorithm);
       if (!expectedHash) {
         throw new Error(`Shard ${shardIndex} missing hash in manifest.`);
       }
-      const algorithm = shardInfo?.hashAlgorithm ?? this.#manifest.hashAlgorithm;
       if (!algorithm) {
         throw new Error(`Manifest missing hashAlgorithm for shard ${shardIndex}.`);
       }
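All five `shard.hash` call sites in this release now funnel through `getExpectedShardHash(shardInfo, algorithm)`, which suggests manifests may carry per-algorithm hashes alongside the legacy flat field. The helper's body is not shown in this diff (it lives in `src/formats/rdrr/parsing.js`, changed +14 -1 above); a plausible sketch, clearly an assumption:

```js
// Hypothetical sketch only; the real implementation and manifest fields may
// differ. Prefers a per-algorithm hash map, falling back to the flat field.
function getExpectedShardHash(shardInfo, algorithm) {
  if (!shardInfo) return null;
  return shardInfo.hashes?.[algorithm] ?? shardInfo.hash ?? null;
}
```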
@@ -309,8 +309,9 @@ export async function loadBF16(shardData, location, name, config) {
   const numElements = location.size / 2;
   const caps = config.gpuCapabilities || getKernelCapabilities();
   const isMatmulWeight = shouldDequantizeToF16(location);
+  const keepF32Weights = config.keepF32Weights === true;
 
-  if (caps?.hasF16 && isMatmulWeight) {
+  if (caps?.hasF16 && isMatmulWeight && !keepF32Weights) {
     const f16Tensor = await runBF16ToF16(srcBuffer, [numElements], name);
     resultBuffer = f16Tensor.buffer;
     releaseOwnedGpuBuffer(srcBuffer, ownsSrcBuffer);
@@ -327,6 +328,10 @@
     };
   }
 
+  if (isMatmulWeight && keepF32Weights) {
+    debugTrace.loader(`Keeping BF16 matmul weight in f32: ${name} (keepF32Weights=true)`);
+  }
+
   const dstBuffer = await convertBF16ToF32GPU(srcBuffer, numElements, name);
   resultBuffer = dstBuffer;
   releaseOwnedGpuBuffer(srcBuffer, ownsSrcBuffer);
@@ -59,6 +59,11 @@
     { "match": { "useF16": true }, "value": "f16" },
     { "match": {}, "value": { "context": "fallback" } }
   ],
+  "attentionProjectionOutputDtype": [
+    { "match": { "forceF32": true }, "value": "f32" },
+    { "match": { "useF16": true }, "value": "f16" },
+    { "match": {}, "value": { "context": "fallback" } }
+  ],
   "bytesPerElement": [
     { "match": { "dtype": "f16" }, "value": 2 },
     { "match": {}, "value": 4 }
@@ -46,7 +46,7 @@
       "hasSubgroups": false,
       "kernelPathRef": "lfm2-q4k-dequant-f32a-online"
     },
-    "value": "gemma3-q4k-dequant-f32a-nosubgroups"
+    "value": "lfm2-q4k-dequant-f32a-nosubgroups"
   },
   {
     "match": {
@@ -77,7 +77,7 @@
   },
   {
     "match": { "kernelPathId": "lfm2-q4k-dequant-f32a-online" },
-    "value": "gemma3-q4k-dequant-f32a-nosubgroups"
+    "value": "lfm2-q4k-dequant-f32a-nosubgroups"
   },
   {
     "match": { "kernelPathId": "gemma2-f16-f16a" },
@@ -0,0 +1,6 @@
+{
+  "variant": [
+    { "match": { "outputDtype": "f16" }, "value": "f16" },
+    { "match": {}, "value": "default" }
+  ]
+}
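The new rule file follows the same first-match convention as the other kernel rule sets: each entry's `match` keys must all equal the context values, and the empty `{}` match acts as the fallback. A self-contained sketch of those assumed semantics:

```js
// Sketch of the assumed first-match rule semantics: every key in `match`
// must equal the context value; an empty match object matches anything.
function selectFirstMatch(rules, context) {
  for (const { match, value } of rules) {
    if (Object.entries(match).every(([key, v]) => context[key] === v)) {
      return value;
    }
  }
  return undefined;
}

const variantRules = [
  { match: { outputDtype: 'f16' }, value: 'f16' },
  { match: {}, value: 'default' },
];
selectFirstMatch(variantRules, { outputDtype: 'f16' }); // 'f16'
selectFirstMatch(variantRules, { outputDtype: 'f32' }); // 'default'
```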
@@ -50,6 +50,7 @@ const sampleRules = await loadJson('./kernels/sample.rules.json', import.meta.url, 'Failed to load rules');
 const scaleRules = await loadJson('./kernels/scale.rules.json', import.meta.url, 'Failed to load rules');
 const siluRules = await loadJson('./kernels/silu.rules.json', import.meta.url, 'Failed to load rules');
 const splitQkvRules = await loadJson('./kernels/split-qkv.rules.json', import.meta.url, 'Failed to load rules');
+const splitQgRules = await loadJson('./kernels/split-qg.rules.json', import.meta.url, 'Failed to load rules');
 const softmaxRules = await loadJson('./kernels/softmax.rules.json', import.meta.url, 'Failed to load rules');
 const upsample2dRules = await loadJson('./kernels/upsample2d.rules.json', import.meta.url, 'Failed to load rules');
 const configRules = await loadJson('./inference/config.rules.json', import.meta.url, 'Failed to load rules');
@@ -124,6 +125,7 @@ const RULE_SETS = {
     scale: scaleRules,
     silu: siluRules,
     splitQkv: splitQkvRules,
+    splitQg: splitQgRules,
     softmax: softmaxRules,
     upsample2d: upsample2dRules,
   },
@@ -2,6 +2,7 @@
 
 import {
   parseManifest,
+  getExpectedShardHash,
   getManifestUrl,
 } from '../formats/rdrr/index.js';
 
@@ -726,7 +727,7 @@ export async function downloadModel(
   if (!algorithm) {
     throw new Error('Manifest missing hashAlgorithm for download verification.');
   }
-  const expectedHash = shardInfo.hash;
+  const expectedHash = getExpectedShardHash(shardInfo, algorithm);
   if (!expectedHash) {
     throw new Error(`Shard ${shardIndex} is missing hash in manifest`);
   }
@@ -1,5 +1,6 @@ import {
 import {
   getManifest,
+  getExpectedShardHash,
   getShardInfo,
   getShardCount,
   generateShardFilename,
@@ -280,7 +281,7 @@ export async function writeShard(shardIndex, data, options = { verify: true }) {
   const manifest = getManifest();
   const algorithm = requireManifestHashAlgorithm(manifest, 'shard write');
   const hash = await computeHash(bytes, algorithm);
-  const expectedHash = shardInfo.hash;
+  const expectedHash = getExpectedShardHash(shardInfo, algorithm);
   if (!expectedHash) {
     await backend.deleteFile(shardInfo.filename);
     throw new Error(`Shard ${shardIndex} is missing hash in manifest`);
@@ -369,7 +370,7 @@ export async function loadShard(shardIndex, options = { verify: false }) {
   const manifest = getManifest();
   const algorithm = requireManifestHashAlgorithm(manifest, 'shard load');
   const hash = await computeHash(buffer, algorithm);
-  const expectedHash = shardInfo.hash;
+  const expectedHash = getExpectedShardHash(shardInfo, algorithm);
   if (!expectedHash) {
     throw new Error(`Shard ${shardIndex} is missing hash in manifest`);
   }
@@ -531,7 +532,7 @@ export async function verifyIntegrity(options = {}) {
     const buffer = await loadShard(i, { verify: false });
     const hash = await computeHash(buffer, algorithm);
     const shardInfo = getShardInfo(i);
-    const expectedHash = shardInfo?.hash;
+    const expectedHash = getExpectedShardHash(shardInfo, algorithm);
     if (!expectedHash) {
       corruptShards.push(i);
       continue;