@simulatte/doppler 0.1.7 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +19 -0
- package/package.json +21 -36
- package/src/browser/browser-converter.js +5 -0
- package/src/client/doppler-registry.json +1 -17
- package/src/config/kernel-path-loader.d.ts +5 -0
- package/src/config/kernel-path-loader.js +13 -0
- package/src/config/kernels/registry.json +74 -0
- package/src/config/loader.js +3 -0
- package/src/config/merge-contract-check.js +7 -0
- package/src/config/presets/kernel-paths/gemma3-q4k-dequant-f32w-f32a-online.json +56 -0
- package/src/config/presets/kernel-paths/lfm2-q4k-dequant-f32a-nosubgroups.json +61 -0
- package/src/config/presets/kernel-paths/registry.json +14 -0
- package/src/config/presets/models/gemma2.json +2 -1
- package/src/config/presets/models/gemma3.json +2 -0
- package/src/config/presets/models/qwen3.json +4 -3
- package/src/config/presets/models/qwen3_5.json +16 -0
- package/src/config/presets/runtime/model/qwen3-5-layer-probe.json +52 -0
- package/src/config/presets/runtime/model/qwen3-5-linear-attn-debug.json +90 -0
- package/src/config/schema/conversion.schema.d.ts +1 -0
- package/src/config/schema/manifest.schema.d.ts +1 -1
- package/src/config/schema/manifest.schema.js +1 -1
- package/src/config/schema/storage.schema.js +1 -1
- package/src/converter/conversion-plan.js +10 -2
- package/src/converter/core.js +2 -0
- package/src/converter/manifest-inference.js +12 -22
- package/src/converter/parsers/transformer.js +4 -0
- package/src/converter/quantization-info.js +5 -1
- package/src/converter/quantizer.js +19 -12
- package/src/converter/rope-config.js +8 -6
- package/src/converter/tokenizer-utils.d.ts +1 -0
- package/src/converter/tokenizer-utils.js +4 -1
- package/src/debug/reference/hf_qwen35_linear_attn_debug.py +268 -0
- package/src/distribution/shard-delivery.js +6 -1
- package/src/formats/rdrr/parsing.d.ts +4 -0
- package/src/formats/rdrr/parsing.js +14 -1
- package/src/gpu/kernels/index.d.ts +8 -0
- package/src/gpu/kernels/index.js +6 -0
- package/src/gpu/kernels/matmul-selection.js +47 -4
- package/src/gpu/kernels/matmul.d.ts +2 -0
- package/src/gpu/kernels/matmul.js +1 -1
- package/src/gpu/kernels/rmsnorm.js +9 -2
- package/src/gpu/kernels/split_qg.d.ts +50 -0
- package/src/gpu/kernels/split_qg.js +46 -0
- package/src/gpu/kernels/split_qg.wgsl +58 -0
- package/src/gpu/kernels/split_qg_f16.wgsl +62 -0
- package/src/gpu/weight-buffer.d.ts +1 -1
- package/src/gpu/weight-buffer.js +1 -1
- package/src/inference/browser-harness.d.ts +2 -0
- package/src/inference/browser-harness.js +20 -1
- package/src/inference/pipelines/diffusion/helpers.js +3 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +8 -2
- package/src/inference/pipelines/text/attention/output-projection.d.ts +12 -0
- package/src/inference/pipelines/text/attention/output-projection.js +8 -0
- package/src/inference/pipelines/text/attention/projections.d.ts +10 -1
- package/src/inference/pipelines/text/attention/projections.js +41 -11
- package/src/inference/pipelines/text/attention/record.js +15 -6
- package/src/inference/pipelines/text/attention/run.js +50 -6
- package/src/inference/pipelines/text/config.js +14 -0
- package/src/inference/pipelines/text/execution-plan.js +5 -4
- package/src/inference/pipelines/text/generator-runtime.js +5 -0
- package/src/inference/pipelines/text/generator-steps.d.ts +6 -0
- package/src/inference/pipelines/text/generator-steps.js +43 -15
- package/src/inference/pipelines/text/generator.js +50 -17
- package/src/inference/pipelines/text/init.d.ts +13 -0
- package/src/inference/pipelines/text/init.js +16 -5
- package/src/inference/pipelines/text/layer.js +1 -0
- package/src/inference/pipelines/text/linear-attention.d.ts +5 -0
- package/src/inference/pipelines/text/linear-attention.js +33 -3
- package/src/inference/pipelines/text/logits/gpu.js +2 -2
- package/src/inference/pipelines/text/logits/index.d.ts +6 -1
- package/src/inference/pipelines/text/logits/index.js +3 -1
- package/src/inference/pipelines/text/model-load.js +3 -0
- package/src/inference/pipelines/text/sampling.js +52 -6
- package/src/inference/test-harness.js +2 -2
- package/src/loader/final-weights-loader.js +2 -0
- package/src/loader/shard-cache.js +3 -2
- package/src/loader/tensors/tensor-loader.js +6 -1
- package/src/rules/inference/dtype.rules.json +5 -0
- package/src/rules/inference/kernel-path.rules.json +2 -2
- package/src/rules/kernels/split-qg.rules.json +6 -0
- package/src/rules/rule-registry.js +2 -0
- package/src/storage/downloader.js +2 -1
- package/src/storage/shard-manager.js +4 -3
- package/src/tooling/conversion-config-materializer.js +3 -5
- package/src/tooling/node-converter.js +3 -0
- package/src/tooling/node-source-runtime.js +36 -0
- package/src/types/model.d.ts +5 -0
- package/tools/doppler-cli.js +6 -1
```diff
@@ -122,6 +122,20 @@ function resolveTokenText(tokenizer, tokenIds, fallbackText = '?', renderTokenTe
   return fallbackText;
 }
 
+export function shouldRetryWithFinitenessFallback(error) {
+  if (error?.name === 'FinitenessError') {
+    return true;
+  }
+  const message = typeof error?.message === 'string'
+    ? error.message
+    : (typeof error === 'string' ? error : '');
+  if (!message.startsWith('[Sampling]')) {
+    return false;
+  }
+  return message.includes('no finite candidate logits after masking the pad token')
+    || message.includes('Softmax produced no finite candidate probabilities');
+}
+
 export class PipelineGenerator {
 
   #state;
```
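The exported classifier above is the gate for every finiteness retry in the hunks that follow. A minimal sketch of how a caller composes it; `runPrefill` and `retryInF32` are hypothetical callbacks, and the import path simply mirrors this package's source layout:

```js
// Hypothetical caller; only shouldRetryWithFinitenessFallback is real.
import { shouldRetryWithFinitenessFallback } from './src/inference/pipelines/text/generator.js';

async function prefillWithFallback(runPrefill, retryInF32) {
  try {
    return await runPrefill();
  } catch (error) {
    if (!shouldRetryWithFinitenessFallback(error)) {
      throw error; // unrelated failures propagate unchanged
    }
    // Only a FinitenessError or one of the two "[Sampling] ..." messages lands here.
    return retryInF32();
  }
}
```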
```diff
@@ -351,7 +365,7 @@ export class PipelineGenerator {
     try {
       prefillLogits = await this._prefill(inputIds, opts);
     } catch (error) {
-      if (error
+      if (shouldRetryWithFinitenessFallback(error)) {
         log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefill. Retrying with F32 precision.`);
         prefillLogits = await this._retryWithFinitenessFallback(
           opts,
```
```diff
@@ -395,13 +409,34 @@
       log.debug('Pipeline', `After rep penalty top-5: ${topAfterPenalty.map(t => `"${t.text}"(${(t.prob * 100).toFixed(1)}%)`).join(', ')}`);
     }
 
-
-
-
-
-
-
-
+    let firstToken;
+    try {
+      firstToken = sample(prefillLogits, {
+        temperature: opts.temperature,
+        topP: opts.topP,
+        topK: opts.topK,
+        padTokenId,
+        seed: opts.seed,
+      });
+    } catch (error) {
+      if (!shouldRetryWithFinitenessFallback(error)) {
+        throw error;
+      }
+      log.warn('Pipeline', 'FinitenessGuard caught non-finite prefill logits at sampling. Retrying with F32 precision.');
+      prefillLogits = await this._retryWithFinitenessFallback(
+        opts,
+        'prefill-sample',
+        () => this._prefill(inputIds, opts)
+      );
+      applyRepetitionPenalty(prefillLogits, generatedIds, opts.repetitionPenalty);
+      firstToken = sample(prefillLogits, {
+        temperature: opts.temperature,
+        topP: opts.topP,
+        topK: opts.topK,
+        padTokenId,
+        seed: opts.seed,
+      });
+    }
 
     if (opts.debug) {
       const firstTokenText = resolveTokenText(this.#state.tokenizer, [firstToken], `[${firstToken}]`, (tokens) => this.#state.tokenizer?.decode?.(tokens, true, false));
```
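`_retryWithFinitenessFallback` itself is not part of this excerpt; the call sites pass `(opts, stageLabel, rerun)` and expect fresh logits back. A standalone sketch of one plausible shape, assuming a mutable `state.forceF32` flag that downstream kernels consult; both the flag and the helper body are assumptions, not the package's actual code:

```js
// Plausible retry helper (assumed, not shown in this diff).
async function retryWithFinitenessFallback(state, stage, rerun) {
  const previous = state.forceF32;   // state.forceF32 is a hypothetical field
  state.forceF32 = true;             // force f32 so f16 overflow cannot reproduce NaN/Inf
  try {
    return await rerun();            // e.g. () => this._prefill(inputIds, opts)
  } finally {
    state.forceF32 = previous;       // restore the faster precision for later steps
  }
}
```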
```diff
@@ -479,7 +514,7 @@ export class PipelineGenerator {
     try {
       prefillResult = await this._prefillToHidden(inputIds, opts);
     } catch (error) {
-      if (error
+      if (shouldRetryWithFinitenessFallback(error)) {
         log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefillKVOnly. Retrying with F32 precision.`);
         prefillResult = await this._retryWithFinitenessFallback(
           opts,
```
```diff
@@ -544,7 +579,7 @@ export class PipelineGenerator {
     try {
       prefillResult = await this._prefillToHidden(inputIds, opts);
     } catch (error) {
-      if (error
+      if (shouldRetryWithFinitenessFallback(error)) {
         log.warn('Pipeline', `FinitenessGuard caught NaN/Inf during prefillWithEmbedding. Retrying with F32 precision.`);
         prefillResult = await this._retryWithFinitenessFallback(
           opts,
```
```diff
@@ -833,7 +868,7 @@ export class PipelineGenerator {
     try {
       nextToken = await this._decodeStep(generatedIds, opts);
     } catch (singleTokenError) {
-      if (singleTokenError
+      if (shouldRetryWithFinitenessFallback(singleTokenError)) {
         log.warn('Pipeline', `FinitenessGuard caught NaN/Inf at batch step ${tokensGenerated}. Truncating KV cache and retrying token with F32 precision.`);
         nextToken = await this._retryDecodeStepWithFinitenessWindow(
           generatedIds,
```
```diff
@@ -858,7 +893,7 @@ export class PipelineGenerator {
     try {
       nextToken = await this._decodeStep(generatedIds, opts);
     } catch (error) {
-      if (error
+      if (shouldRetryWithFinitenessFallback(error)) {
         log.warn('Pipeline', `FinitenessGuard caught NaN/Inf at step ${tokensGenerated}. Truncating KV cache and retrying token with F32 precision.`);
         nextToken = await this._retryDecodeStepWithFinitenessWindow(
           generatedIds,
```
```diff
@@ -918,11 +953,9 @@
       throw new Error('Embed buffer not found or not a supported buffer type');
     }
     const embedBuffer = isWeightBuffer(embedBufferRaw) ? embedBufferRaw.buffer : embedBufferRaw;
-    const embedDtype =
-      ?
-      :
-        ? embedBufferRaw.dtype
-        : null;
+    const embedDtype = isCpuWeightBuffer(embedBufferRaw)
+      ? embedBufferRaw.dtype
+      : getWeightDtype(embedBufferRaw);
     if (opts.debug) {
       const embedSize = embedBuffer instanceof GPUBuffer ? embedBuffer.size : 'N/A';
       log.debug('Pipeline', `Embed buffer: type=${embedBuffer?.constructor?.name}, size=${embedSize}, dtype=${embedDtype}`);
```
```diff
@@ -190,6 +190,12 @@ export interface WeightLoadResult {
   layerRouterWeights: Map<number, RouterWeights>;
 }
 
+export interface ResolvedQ4KConfig {
+  useFusedQ4K: boolean;
+  q4kLayout: 'row' | 'col' | null;
+  keepF32Weights: boolean;
+}
+
 /** Options for loadWeights */
 export interface LoadWeightsOptions {
   storageContext?: PipelineStorageContext;
```
```diff
@@ -211,6 +217,13 @@ export function loadWeights(
   options?: LoadWeightsOptions
 ): Promise<WeightLoadResult>;
 
+export function resolveQ4KConfig(
+  manifest: Manifest,
+  kernelPath?: KernelPathSchema | null,
+  kernelPathSource?: KernelPathSource,
+  keepF32Weights?: boolean
+): ResolvedQ4KConfig;
+
 /**
  * Apply Gemma chat template to a prompt.
  */
```
```diff
@@ -11,7 +11,7 @@ import { getDopplerLoader } from '../../../loader/doppler-loader.js';
 import { log, setGPUDevice, trace as debugTrace } from '../../../debug/index.js';
 import { getRuntimeConfig } from '../../../config/runtime.js';
 import { PAGED_LAYOUT_SEQ_LEN_THRESHOLD } from '../../../config/schema/index.js';
-import { isKernelPathFusedQ4K } from '../../../config/kernel-path-loader.js';
+import { isKernelPathFusedQ4K, kernelPathRequiresF32MatmulWeights } from '../../../config/kernel-path-loader.js';
 import { createWeightBuffer, getWeightDtype, isWeightBuffer } from '../../../gpu/weight-buffer.js';
 import { selectRuleValue } from '../../../rules/rule-registry.js';
 import {
```
```diff
@@ -128,7 +128,7 @@ function createRemoteStorageContext(baseUrl, manifest) {
 }
 
 
-function resolveQ4KConfig(
+export function resolveQ4KConfig(
   manifest,
   kernelPath,
   kernelPathSource = 'none',
```
```diff
@@ -150,18 +150,23 @@ function resolveQ4KConfig(
     );
   }
   let useFused = kernelPath ? isKernelPathFusedQ4K(kernelPath) : hasSubgroups;
+  const kernelPathKeepsF32Weights = kernelPathRequiresF32MatmulWeights(kernelPath);
   if (q4kLayout === 'col') {
     useFused = false;
   }
+  const resolvedKeepF32Weights = keepF32Weights || kernelPathKeepsF32Weights;
 
   const pathLabel = kernelPath?.id ?? 'auto';
   const layoutLabel = q4kLayout ?? 'none';
-  debugTrace.loader(
+  debugTrace.loader(
+    `Q4K config: fused=${useFused}, kernelPath=${pathLabel}, source=${kernelPathSource}, ` +
+    `layout=${layoutLabel}, keepF32Weights=${resolvedKeepF32Weights}, subgroups=${hasSubgroups}`
+  );
 
   return {
     useFusedQ4K: useFused,
     q4kLayout,
-    keepF32Weights,
+    keepF32Weights: resolvedKeepF32Weights,
   };
 }
 
```
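Combined with the `.d.ts` declaration earlier, the now-exported resolver is called like this; `manifest` stands in for a parsed RDRR manifest and the argument values are placeholders:

```js
import { resolveQ4KConfig } from './src/inference/pipelines/text/init.js';

const manifest = {};     // a parsed RDRR manifest in real use
const kernelPath = null; // null lets the subgroup heuristic pick fused vs. dequant
const { useFusedQ4K, q4kLayout, keepF32Weights } =
  resolveQ4KConfig(manifest, kernelPath, 'none', false);
// keepF32Weights is now the OR of the caller's request and
// kernelPathRequiresF32MatmulWeights(kernelPath).
```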
```diff
@@ -502,6 +507,12 @@ export function createKVCache(modelConfig, useGPU, debug = false, runtimeConfig)
     cacheLayout = 'paged';
     layoutSource = 'threshold';
   }
+  if (forceContiguousKVCache && cacheLayout === 'paged') {
+    throw new Error(
+      'Paged KV cache layout is not supported for models with full-attention layers. ' +
+      'Set runtime.inference.kvcache.layout to "contiguous" instead.'
+    );
+  }
   if (debug && cacheLayout !== runtimeKV.layout) {
     log.debug('Pipeline', `KV cache layout override: ${runtimeKV.layout} -> ${cacheLayout} (${layoutSource})`);
   }
```
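The error text pins down the config key; a runtime config that satisfies the new guard would look like the following, where only `inference.kvcache.layout` is confirmed by the message and the surrounding shape is inferred:

```js
const runtimeConfig = {
  inference: {
    kvcache: {
      layout: 'contiguous', // required once layerTypes mixes in full-attention layers
    },
  },
};
```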
```diff
@@ -599,7 +610,7 @@ export function createKVCache(modelConfig, useGPU, debug = false, runtimeConfig)
 
   if (debug) {
     if (forceContiguousKVCache && modelConfig.layerTypes) {
-      log.debug('Pipeline', 'Layer pattern includes full-attention layers;
+      log.debug('Pipeline', 'Layer pattern includes full-attention layers; paged layout blocked, contiguous enforced.');
     }
     const isSliding = kvCache instanceof SlidingWindowKVCache;
     log.debug('Pipeline', `KV cache: type=${kvCache?.constructor?.name || 'unknown'}, kvDtype=${kvCache.kvDtype}, layout=${kvCache.layout}, maxSeqLen=${kvCache.maxSeqLen}, windowSize=${isSliding ? kvCache.windowSize : null}`);
```
```diff
@@ -276,6 +276,7 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
       : (ropeFreqsSin),
     kvCache: ((kvCache)),
     stats: context.stats,
+    debugProbes: context.debugProbes,
     linearRuntime: context.linearAttentionRuntime ?? null,
   };
 
```
```diff
@@ -84,6 +84,11 @@ export declare function inferLinearNormMode(
   }
 ): LinearNormMode | null;
 
+export declare function applyLinearNormWeightOffset(
+  values: Float32Array,
+  rmsNormWeightOffset: boolean
+): Float32Array;
+
 export declare function resetLinearAttentionRuntime(
   runtime: LinearAttentionRuntime | null | undefined
 ): LinearAttentionRuntime;
```
```diff
@@ -5,6 +5,8 @@ import { log } from '../../../debug/index.js';
 import { decodeReadback } from './debug-utils/index.js';
 import { runLinearAttentionCoreGPU } from '../../../gpu/kernels/linear-attention-core.js';
 import { runProbes } from './probes.js';
+import { QK_K, Q4K_BLOCK_BYTES } from '../../../config/schema/index.js';
+import { dequantizeQ4KM } from '../../../converter/quantizer.js';
 
 const LINEAR_RUNTIME_SCHEMA_VERSION = 1;
 const QK_L2NORM_EPS = 1e-6;
```
```diff
@@ -34,6 +36,15 @@ function bytesFromDtype(dtype) {
   return 4;
 }
 
+export function applyLinearNormWeightOffset(values, rmsNormWeightOffset) {
+  if (!(values instanceof Float32Array)) {
+    throw new Error('applyLinearNormWeightOffset requires Float32Array input.');
+  }
+  // Qwen linear-attention output norm uses direct weights even when surrounding
+  // transformer RMSNorm sites use the Gemma-style (1 + weight) formula.
+  return values;
+}
+
 function cloneLayerRuntimeState(layerState) {
   return {
     layerIdx: layerState.layerIdx,
```
```diff
@@ -283,9 +294,27 @@ async function readWeightAsF32(weight, expectedElements, label) {
   if (!elementCount && isWeightBuffer(weight) && Array.isArray(weight.shape) && weight.shape.length > 0) {
     elementCount = weight.shape.reduce((total, dim) => total * Math.max(1, Math.trunc(Number(dim) || 0)), 1);
   }
+  const isQ4K = sourceDtype === 'q4k' || sourceDtype === 'q4_k_m' || sourceDtype === 'q4_k';
   if (!elementCount) {
-
-
+    if (isQ4K) {
+      elementCount = Math.trunc(sourceBuffer.size / Q4K_BLOCK_BYTES) * QK_K;
+    } else {
+      const inferredBytes = sourceDtype === 'f16' || sourceDtype === 'bf16' ? 2 : 4;
+      elementCount = Math.trunc(sourceBuffer.size / inferredBytes);
+    }
+  }
+
+  if (isQ4K) {
+    const numBlocks = Math.ceil(elementCount / QK_K);
+    const q4kBytes = numBlocks * Q4K_BLOCK_BYTES;
+    const raw = await readBuffer(sourceBuffer, q4kBytes);
+    const decoded = dequantizeQ4KM(new Uint8Array(raw), numBlocks, [elementCount]);
+    if (expectedElements != null && decoded.length !== expectedElements) {
+      throw new Error(
+        `Weight "${label}" Q4K decoded length ${decoded.length}, expected ${expectedElements}.`
+      );
+    }
+    return decoded;
   }
 
   if (!sourceDtype) {
```
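The block arithmetic above follows GGML's Q4_K layout: `QK_K` elements per super-block serialized into `Q4K_BLOCK_BYTES`. A standalone mirror using the conventional values (256 elements, 144 bytes; the package imports its own constants from `config/schema`):

```js
const QK_K = 256;            // elements per Q4_K super-block (conventional GGML value)
const Q4K_BLOCK_BYTES = 144; // serialized size of one super-block

function q4kElementsForBytes(byteLength) {
  return Math.trunc(byteLength / Q4K_BLOCK_BYTES) * QK_K; // whole blocks only
}

function q4kBytesForElements(elementCount) {
  return Math.ceil(elementCount / QK_K) * Q4K_BLOCK_BYTES; // round up to a full block
}

console.log(q4kBytesForElements(2048)); // 8 blocks -> 1152 bytes
console.log(q4kElementsForBytes(1152)); // 8 blocks -> 2048 elements
```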
```diff
@@ -454,6 +483,7 @@ async function createLayerRuntimeState(
     expectedNormElements,
     `L${layerIdx}.linear_attn.norm.weight`
   );
+  const runtimeNorm = applyLinearNormWeightOffset(norm, config.rmsNormWeightOffset === true);
 
   const aNegExp = new Float32Array(aLog.length);
   for (let i = 0; i < aLog.length; i++) {
```
```diff
@@ -490,7 +520,7 @@ async function createLayerRuntimeState(
     convWeight,
     dtBias,
     aNegExp,
-    normWeight:
+    normWeight: runtimeNorm,
     convState,
     recurrentState,
     convWeightGPU: null,
```
```diff
@@ -304,7 +304,7 @@ export async function computeLogitsGPU(
 
   const logitsTensor = await runMatmul(normedTensor, lmHeadBuffer, numTokens, matmulVocabSize, hiddenSize, {
     transposeB: 'auto',
-    role:
+    role: 'lm_head',
     kernelPath: config.kernelPath ?? null,
   });
 
```
```diff
@@ -391,7 +391,7 @@ export async function recordLogitsGPU(
   // Record matmul (no submit)
   const logitsTensor = await recordMatmul(recorder, normedTensor, lmHeadBuffer, numTokens, matmulVocabSize, hiddenSize, {
     transposeB: 'auto',
-    role:
+    role: 'lm_head',
     kernelPath: config.kernelPath ?? null,
   });
 
```
```diff
@@ -25,6 +25,10 @@ export { computeLogitsGPU, recordLogitsGPU, computeChunkedLogitsGPU, resolveCpuW
 // Re-export utilities
 export { extractLastPositionLogits, finalizeLogits } from './utils.js';
 
+export interface ComputeLogitsOptions {
+  lastPositionOnly?: boolean;
+}
+
 /**
  * Compute logits from hidden states.
 *
```
```diff
@@ -53,5 +57,6 @@ export function computeLogits(
   debugFlags?: LogitsDebugFlags,
   getNormWeightBuffer?: (weight: GPUBuffer | Float32Array | ArrayBuffer, label: string) => GPUBuffer,
   debugCheckBuffer?: (buffer: GPUBuffer, label: string, numTokens: number, expectedDim?: number) => Promise<void>,
-  debugProbes?: ProbeConfigSchema[] | null
+  debugProbes?: ProbeConfigSchema[] | null,
+  options?: ComputeLogitsOptions
 ): Promise<Float32Array>;
```
```diff
@@ -253,6 +253,7 @@ export async function computeLogits(
 
   const lastPositionOnly = options?.lastPositionOnly === true && numTokens > 1;
   const matmulRows = lastPositionOnly ? 1 : numTokens;
+  const matmulPhaseOverride = lastPositionOnly ? 'prefill' : null;
   let matmulInputTensor = normedTensor;
   let matmulInputOwned = false;
   if (lastPositionOnly) {
```
```diff
@@ -270,7 +271,8 @@
   // HuggingFace models store lm_head as [vocabSize, hiddenSize], so transposeB=true
   const logitsTensor = await runMatmul(matmulInputTensor, lmHeadBuffer, matmulRows, matmulVocabSize, hiddenSize, {
     transposeB: 'auto',
-    role:
+    role: 'lm_head',
+    phaseOverride: matmulPhaseOverride,
     kernelPath: config.kernelPath ?? null,
   });
   await runProbes('logits', logitsTensor.buffer, {
```
```diff
@@ -234,6 +234,9 @@ function buildManifestDecodeLoopRuntimePatch(manifest) {
 
 export function applyModelBatchingRuntimeDefaults(runtimeConfig, manifest, modelConfig) {
   void modelConfig;
+  if (manifest?.inference?.schema === 'doppler.execution/v0') {
+    return runtimeConfig;
+  }
   const batching = runtimeConfig?.inference?.batching;
   const generation = runtimeConfig?.inference?.generation;
   const runtimeBatchingAtDefaults = isRuntimeBatchingAtGlobalDefaults(batching);
```
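A standalone mirror of the new early return, with the schema tag exactly as it appears in the hunk:

```js
function appliesBatchingDefaults(manifest) {
  // Execution-plan manifests keep their own runtime config untouched.
  return manifest?.inference?.schema !== 'doppler.execution/v0';
}

console.log(appliesBatchingDefaults({ inference: { schema: 'doppler.execution/v0' } })); // false
console.log(appliesBatchingDefaults({}));                                                // true
```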
```diff
@@ -58,6 +58,30 @@ export function softmax(logits) {
   return exps;
 }
 
+function countFiniteCandidates(logits, padTokenId) {
+  let finiteCandidateCount = 0;
+  for (let i = 0; i < logits.length; i++) {
+    if (padTokenId != null && i === padTokenId) {
+      continue;
+    }
+    if (Number.isFinite(logits[i])) {
+      finiteCandidateCount += 1;
+    }
+  }
+  return finiteCandidateCount;
+}
+
+function assertFiniteSamplingCandidates(logits, padTokenId, label) {
+  const finiteCandidateCount = countFiniteCandidates(logits, padTokenId);
+  if (finiteCandidateCount > 0) {
+    return;
+  }
+  throw new Error(
+    `[Sampling] ${label} has no finite candidate logits after masking the pad token. ` +
+    'Upstream decode likely produced NaN/Inf or an all-masked distribution.'
+  );
+}
+
 
 export function sample(logits, opts) {
   const { temperature, topP, topK, decode, debug = false, padTokenId, seed } = opts;
```
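Both helpers are module-private, so this standalone mirror just illustrates the contract: the pad token is excluded before counting, and a distribution whose only finite entry is the pad token still fails the guard.

```js
function hasFiniteCandidate(logits, padTokenId) {
  return logits.some((value, i) => i !== padTokenId && Number.isFinite(value));
}

console.log(hasFiniteCandidate([NaN, -Infinity, 0.5], 2)); // false -> sample() throws
console.log(hasFiniteCandidate([NaN, -Infinity, 0.5], 0)); // true  -> sampling proceeds
```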
```diff
@@ -66,16 +90,28 @@ export function sample(logits, opts) {
     logits[padTokenId] = -Infinity;
   }
 
+  assertFiniteSamplingCandidates(logits, padTokenId, 'Logits');
+
   // Greedy (argmax) when temperature = 0
   if (temperature === 0) {
-    let maxIdx =
-    let maxVal =
-    for (let i =
-
-
+    let maxIdx = -1;
+    let maxVal = -Infinity;
+    for (let i = 0; i < logits.length; i++) {
+      const value = logits[i];
+      if (!Number.isFinite(value)) {
+        continue;
+      }
+      if (value > maxVal) {
+        maxVal = value;
         maxIdx = i;
       }
     }
+    if (maxIdx < 0) {
+      throw new Error(
+        '[Sampling] Greedy sampling could not find a finite candidate logit. ' +
+        'Upstream decode likely produced NaN/Inf.'
+      );
+    }
     if (debug) {
       const text = decode?.([maxIdx]) ?? '?';
       trace.sample(`Greedy: id=${maxIdx} "${text}" logit=${maxVal.toFixed(4)}`);
```
```diff
@@ -96,7 +132,17 @@ export function sample(logits, opts) {
 
   let candidates = [];
   for (let i = 0; i < probs.length; i++) {
-
+    const probability = probs[i];
+    if (!Number.isFinite(probability) || probability <= 0) {
+      continue;
+    }
+    candidates.push({ token: i, prob: probability });
+  }
+  if (candidates.length === 0) {
+    throw new Error(
+      '[Sampling] Softmax produced no finite candidate probabilities. ' +
+      'Upstream decode likely produced NaN/Inf logits.'
+    );
   }
   candidates.sort((a, b) => b.prob - a.prob);
 
```
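Both failure messages begin with `[Sampling]`, which is precisely the prefix `shouldRetryWithFinitenessFallback` (first hunk) matches on. A caller-side sketch, assuming `logits` and `padTokenId` are already in scope:

```js
let token;
try {
  token = sample(logits, { temperature: 0.7, topP: 0.9, topK: 40, padTokenId, seed: 42 });
} catch (error) {
  if (!String(error?.message).startsWith('[Sampling]')) {
    throw error; // not a finiteness failure
  }
  // Non-finite logits: the generator re-runs the step in F32 and samples again.
}
```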
```diff
@@ -1,7 +1,7 @@
 
 
 import { initDevice, getDevice, getKernelCapabilities } from '../gpu/device.js';
-import { parseManifest } from '../formats/rdrr/index.js';
+import { parseManifest, getExpectedShardHash } from '../formats/rdrr/index.js';
 import { createPipeline } from './pipelines/text.js';
 import { log as debugLog } from '../debug/index.js';
 import { getRuntimeConfig, setRuntimeConfig } from '../config/runtime.js';
```
```diff
@@ -168,7 +168,7 @@ export function createHttpShardLoader(baseUrl, manifest, log) {
     distributionConfig,
     algorithm,
     requiredEncoding,
-    expectedHash: shard
+    expectedHash: getExpectedShardHash(shard, algorithm) || null,
     expectedSize: Number.isFinite(shard.size) ? Math.floor(shard.size) : null,
     expectedManifestVersionSet: manifestVersionSet,
     writeToStore: false,
```
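`getExpectedShardHash` is defined in `package/src/formats/rdrr/parsing.js` (+14 -1 in the file list) but its body is not part of this excerpt. Judging only from the call sites, which pass a shard entry plus the resolved algorithm, one plausible shape is a per-algorithm lookup with a legacy single-hash fallback; the field names below are assumptions, not the package's schema:

```js
// Assumed shape only; the real implementation lives in formats/rdrr/parsing.js.
function getExpectedShardHash(shardInfo, algorithm) {
  if (!shardInfo) return null;
  if (algorithm && shardInfo.hashes?.[algorithm]) {
    return shardInfo.hashes[algorithm]; // per-algorithm map, if the manifest has one
  }
  return shardInfo.hash ?? null; // single legacy hash field
}
```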
```diff
@@ -36,6 +36,8 @@ function isLikelyFinalNormName(name) {
   return (
     lower === 'norm.weight' ||
     lower.includes('model.norm.weight') ||
+    lower.includes('language_model.norm.weight') ||
+    lower.includes('model.language_model.norm.weight') ||
     lower.includes('embedding_norm.weight') ||
     lower.includes('model.embedding_norm.weight') ||
     lower.includes('final_layernorm.weight') ||
```
```diff
@@ -5,6 +5,7 @@ import {
   computeHash,
   getStorageBackendType,
 } from '../storage/shard-manager.js';
+import { getExpectedShardHash } from '../formats/rdrr/index.js';
 import { formatBytes } from '../storage/quota.js';
 import { log, trace as debugTrace } from '../debug/index.js';
 import { getRuntimeConfig } from '../config/runtime.js';
```
```diff
@@ -484,11 +485,11 @@ export class ShardCache {
     // Verify hash if enabled
     if (this.#verifyHashes && this.#manifest) {
       const shardInfo = this.#manifest.shards?.[shardIndex];
-      const
+      const algorithm = shardInfo?.hashAlgorithm ?? this.#manifest.hashAlgorithm;
+      const expectedHash = getExpectedShardHash(shardInfo, algorithm);
       if (!expectedHash) {
         throw new Error(`Shard ${shardIndex} missing hash in manifest.`);
       }
-      const algorithm = shardInfo?.hashAlgorithm ?? this.#manifest.hashAlgorithm;
       if (!algorithm) {
         throw new Error(`Manifest missing hashAlgorithm for shard ${shardIndex}.`);
       }
```
```diff
@@ -309,8 +309,9 @@ export async function loadBF16(shardData, location, name, config) {
   const numElements = location.size / 2;
   const caps = config.gpuCapabilities || getKernelCapabilities();
   const isMatmulWeight = shouldDequantizeToF16(location);
+  const keepF32Weights = config.keepF32Weights === true;
 
-  if (caps?.hasF16 && isMatmulWeight) {
+  if (caps?.hasF16 && isMatmulWeight && !keepF32Weights) {
     const f16Tensor = await runBF16ToF16(srcBuffer, [numElements], name);
     resultBuffer = f16Tensor.buffer;
     releaseOwnedGpuBuffer(srcBuffer, ownsSrcBuffer);
```
```diff
@@ -327,6 +328,10 @@ export async function loadBF16(shardData, location, name, config) {
     };
   }
 
+  if (isMatmulWeight && keepF32Weights) {
+    debugTrace.loader(`Keeping BF16 matmul weight in f32: ${name} (keepF32Weights=true)`);
+  }
+
   const dstBuffer = await convertBF16ToF32GPU(srcBuffer, numElements, name);
   resultBuffer = dstBuffer;
   releaseOwnedGpuBuffer(srcBuffer, ownsSrcBuffer);
```
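Condensed, the dtype decision in `loadBF16` after these two hunks: the f16 fast path now also requires that nothing pinned the weights to f32. A standalone mirror:

```js
function bf16TargetDtype({ hasF16, isMatmulWeight, keepF32Weights }) {
  if (hasF16 && isMatmulWeight && !keepF32Weights) return 'f16';
  return 'f32'; // non-matmul weight, no f16 support, or keepF32Weights === true
}

console.log(bf16TargetDtype({ hasF16: true, isMatmulWeight: true, keepF32Weights: true }));  // 'f32'
console.log(bf16TargetDtype({ hasF16: true, isMatmulWeight: true, keepF32Weights: false })); // 'f16'
```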
```diff
@@ -59,6 +59,11 @@
     { "match": { "useF16": true }, "value": "f16" },
     { "match": {}, "value": { "context": "fallback" } }
   ],
+  "attentionProjectionOutputDtype": [
+    { "match": { "forceF32": true }, "value": "f32" },
+    { "match": { "useF16": true }, "value": "f16" },
+    { "match": {}, "value": { "context": "fallback" } }
+  ],
   "bytesPerElement": [
     { "match": { "dtype": "f16" }, "value": 2 },
     { "match": {}, "value": 4 }
```
```diff
@@ -46,7 +46,7 @@
       "hasSubgroups": false,
       "kernelPathRef": "lfm2-q4k-dequant-f32a-online"
     },
-    "value": "
+    "value": "lfm2-q4k-dequant-f32a-nosubgroups"
   },
   {
     "match": {
```
```diff
@@ -77,7 +77,7 @@
   },
   {
     "match": { "kernelPathId": "lfm2-q4k-dequant-f32a-online" },
-    "value": "
+    "value": "lfm2-q4k-dequant-f32a-nosubgroups"
   },
   {
     "match": { "kernelPathId": "gemma2-f16-f16a" },
```
```diff
@@ -50,6 +50,7 @@ const sampleRules = await loadJson('./kernels/sample.rules.json', import.meta.ur
 const scaleRules = await loadJson('./kernels/scale.rules.json', import.meta.url, 'Failed to load rules');
 const siluRules = await loadJson('./kernels/silu.rules.json', import.meta.url, 'Failed to load rules');
 const splitQkvRules = await loadJson('./kernels/split-qkv.rules.json', import.meta.url, 'Failed to load rules');
+const splitQgRules = await loadJson('./kernels/split-qg.rules.json', import.meta.url, 'Failed to load rules');
 const softmaxRules = await loadJson('./kernels/softmax.rules.json', import.meta.url, 'Failed to load rules');
 const upsample2dRules = await loadJson('./kernels/upsample2d.rules.json', import.meta.url, 'Failed to load rules');
 const configRules = await loadJson('./inference/config.rules.json', import.meta.url, 'Failed to load rules');
```
```diff
@@ -124,6 +125,7 @@ const RULE_SETS = {
     scale: scaleRules,
     silu: siluRules,
     splitQkv: splitQkvRules,
+    splitQg: splitQgRules,
     softmax: softmaxRules,
     upsample2d: upsample2dRules,
   },
```
```diff
@@ -2,6 +2,7 @@
 
 import {
   parseManifest,
+  getExpectedShardHash,
   getManifestUrl,
 } from '../formats/rdrr/index.js';
 
```
```diff
@@ -726,7 +727,7 @@ export async function downloadModel(
   if (!algorithm) {
     throw new Error('Manifest missing hashAlgorithm for download verification.');
   }
-  const expectedHash = shardInfo
+  const expectedHash = getExpectedShardHash(shardInfo, algorithm);
   if (!expectedHash) {
     throw new Error(`Shard ${shardIndex} is missing hash in manifest`);
   }
```
```diff
@@ -1,5 +1,6 @@
 import {
   getManifest,
+  getExpectedShardHash,
   getShardInfo,
   getShardCount,
   generateShardFilename,
```
```diff
@@ -280,7 +281,7 @@ export async function writeShard(shardIndex, data, options = { verify: true }) {
   const manifest = getManifest();
   const algorithm = requireManifestHashAlgorithm(manifest, 'shard write');
   const hash = await computeHash(bytes, algorithm);
-  const expectedHash = shardInfo
+  const expectedHash = getExpectedShardHash(shardInfo, algorithm);
   if (!expectedHash) {
     await backend.deleteFile(shardInfo.filename);
     throw new Error(`Shard ${shardIndex} is missing hash in manifest`);
```
```diff
@@ -369,7 +370,7 @@ export async function loadShard(shardIndex, options = { verify: false }) {
   const manifest = getManifest();
   const algorithm = requireManifestHashAlgorithm(manifest, 'shard load');
   const hash = await computeHash(buffer, algorithm);
-  const expectedHash = shardInfo
+  const expectedHash = getExpectedShardHash(shardInfo, algorithm);
   if (!expectedHash) {
     throw new Error(`Shard ${shardIndex} is missing hash in manifest`);
   }
```
```diff
@@ -531,7 +532,7 @@ export async function verifyIntegrity(options = {}) {
       const buffer = await loadShard(i, { verify: false });
       const hash = await computeHash(buffer, algorithm);
       const shardInfo = getShardInfo(i);
-      const expectedHash = shardInfo
+      const expectedHash = getExpectedShardHash(shardInfo, algorithm);
       if (!expectedHash) {
         corruptShards.push(i);
         continue;
```