@simulatte/doppler 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +23 -8
- package/package.json +7 -4
- package/src/config/kernels/kernel-ref-digests.js +39 -39
- package/src/config/kernels/registry.json +42 -2
- package/src/config/loader.js +31 -2
- package/src/config/merge.js +18 -0
- package/src/config/presets/models/qwen3.json +9 -2
- package/src/config/presets/models/transformer.json +5 -0
- package/src/config/required-inference-fields-contract-check.js +6 -0
- package/src/config/schema/inference-defaults.schema.js +3 -0
- package/src/config/schema/inference.schema.d.ts +9 -0
- package/src/config/schema/kernel-path.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.d.ts +6 -0
- package/src/config/schema/manifest.schema.js +3 -0
- package/src/converter/rope-config.js +42 -0
- package/src/gpu/device.js +58 -0
- package/src/gpu/kernels/attention.js +98 -0
- package/src/gpu/kernels/bias_add.wgsl +8 -6
- package/src/gpu/kernels/bias_add_f16.wgsl +8 -5
- package/src/gpu/kernels/conv2d.js +1 -1
- package/src/gpu/kernels/conv2d.wgsl +7 -8
- package/src/gpu/kernels/conv2d_f16.wgsl +7 -8
- package/src/gpu/kernels/depthwise_conv2d.js +2 -1
- package/src/gpu/kernels/depthwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/depthwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +2 -1
- package/src/gpu/kernels/grouped_pointwise_conv2d.wgsl +6 -9
- package/src/gpu/kernels/grouped_pointwise_conv2d_f16.wgsl +6 -9
- package/src/gpu/kernels/matmul.js +25 -0
- package/src/gpu/kernels/pixel_shuffle.js +1 -1
- package/src/gpu/kernels/pixel_shuffle.wgsl +4 -5
- package/src/gpu/kernels/pixel_shuffle_f16.wgsl +4 -5
- package/src/gpu/kernels/relu.js +15 -2
- package/src/gpu/kernels/relu.wgsl +2 -1
- package/src/gpu/kernels/relu_f16.wgsl +2 -1
- package/src/gpu/kernels/repeat_channels.js +1 -1
- package/src/gpu/kernels/repeat_channels.wgsl +4 -5
- package/src/gpu/kernels/repeat_channels_f16.wgsl +4 -5
- package/src/gpu/kernels/residual.js +44 -8
- package/src/gpu/kernels/residual.wgsl +6 -3
- package/src/gpu/kernels/residual_f16.wgsl +2 -1
- package/src/gpu/kernels/residual_f16_vec4.wgsl +2 -1
- package/src/gpu/kernels/residual_vec4.wgsl +2 -1
- package/src/gpu/kernels/rmsnorm.js +58 -6
- package/src/gpu/kernels/rmsnorm.wgsl +14 -6
- package/src/gpu/kernels/rmsnorm_f16.wgsl +10 -2
- package/src/gpu/kernels/rope.d.ts +2 -0
- package/src/gpu/kernels/rope.js +11 -1
- package/src/gpu/kernels/rope.wgsl +56 -40
- package/src/gpu/kernels/sana_linear_attention.js +1 -2
- package/src/gpu/kernels/sana_linear_attention_apply.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_apply_f16.wgsl +4 -5
- package/src/gpu/kernels/sana_linear_attention_summary.wgsl +4 -0
- package/src/gpu/kernels/sana_linear_attention_summary_f16.wgsl +4 -0
- package/src/gpu/kernels/silu.d.ts +1 -0
- package/src/gpu/kernels/silu.js +32 -14
- package/src/gpu/kernels/silu.wgsl +19 -9
- package/src/gpu/kernels/silu_f16.wgsl +19 -9
- package/src/gpu/kernels/transpose.js +15 -2
- package/src/gpu/kernels/transpose.wgsl +5 -6
- package/src/gpu/kernels/upsample2d.js +2 -1
- package/src/gpu/kernels/upsample2d.wgsl +6 -9
- package/src/gpu/kernels/upsample2d_f16.wgsl +6 -9
- package/src/gpu/kernels/utils.js +16 -1
- package/src/inference/browser-harness.js +47 -1
- package/src/inference/pipelines/diffusion/pipeline.js +15 -6
- package/src/inference/pipelines/diffusion/text-encoder-gpu.d.ts +5 -0
- package/src/inference/pipelines/diffusion/text-encoder-gpu.js +27 -15
- package/src/inference/pipelines/text/attention/record.js +11 -2
- package/src/inference/pipelines/text/attention/run.js +11 -2
- package/src/inference/pipelines/text/chat-format.js +25 -1
- package/src/inference/pipelines/text/config.d.ts +4 -0
- package/src/inference/pipelines/text/config.js +68 -1
- package/src/inference/pipelines/text/execution-plan.js +23 -31
- package/src/inference/pipelines/text/execution-v0.js +29 -2
- package/src/inference/pipelines/text/ffn/standard.js +3 -0
- package/src/inference/pipelines/text/init.d.ts +4 -0
- package/src/inference/pipelines/text/init.js +56 -9
- package/src/inference/pipelines/text/layer.js +11 -0
- package/src/inference/pipelines/text.js +4 -0
- package/src/inference/tokenizers/bundled.js +156 -33
- package/src/rules/tooling/command-runtime.rules.json +18 -0
- package/src/tooling/command-api.d.ts +27 -1
- package/src/tooling/command-api.js +142 -3
- package/src/tooling/node-browser-command-runner.d.ts +4 -0
- package/src/tooling/node-browser-command-runner.js +58 -3
- package/src/tooling/node-command-runner.js +15 -0
- package/src/tooling/node-webgpu.js +9 -87
- package/src/training/checkpoint-watch.d.ts +7 -0
- package/src/training/checkpoint-watch.js +106 -0
- package/src/training/checkpoint.d.ts +6 -1
- package/src/training/checkpoint.js +12 -2
- package/src/training/distillation/artifacts.d.ts +71 -0
- package/src/training/distillation/artifacts.js +132 -0
- package/src/training/distillation/checkpoint-watch.d.ts +10 -0
- package/src/training/distillation/checkpoint-watch.js +57 -0
- package/src/training/distillation/dataset.d.ts +59 -0
- package/src/training/distillation/dataset.js +337 -0
- package/src/training/distillation/eval.d.ts +34 -0
- package/src/training/distillation/eval.js +310 -0
- package/src/training/distillation/index.d.ts +29 -0
- package/src/training/distillation/index.js +29 -0
- package/src/training/distillation/runtime.d.ts +20 -0
- package/src/training/distillation/runtime.js +121 -0
- package/src/training/distillation/scoreboard.d.ts +6 -0
- package/src/training/distillation/scoreboard.js +8 -0
- package/src/training/distillation/stage-a.d.ts +45 -0
- package/src/training/distillation/stage-a.js +338 -0
- package/src/training/distillation/stage-b.d.ts +24 -0
- package/src/training/distillation/stage-b.js +20 -0
- package/src/training/index.d.ts +10 -0
- package/src/training/index.js +10 -0
- package/src/training/lora-pipeline.d.ts +40 -0
- package/src/training/lora-pipeline.js +796 -0
- package/src/training/operator-artifacts.d.ts +62 -0
- package/src/training/operator-artifacts.js +140 -0
- package/src/training/operator-command.d.ts +5 -0
- package/src/training/operator-command.js +453 -0
- package/src/training/operator-eval.d.ts +48 -0
- package/src/training/operator-eval.js +230 -0
- package/src/training/operator-scoreboard.d.ts +5 -0
- package/src/training/operator-scoreboard.js +44 -0
- package/src/training/runner.d.ts +52 -0
- package/src/training/runner.js +29 -4
- package/src/training/suite.d.ts +112 -0
- package/src/training/suite.js +9 -9
- package/src/training/workloads.d.ts +164 -0
- package/src/training/workloads.js +539 -0
- package/src/version.js +1 -1
- package/tools/doppler-cli.js +137 -40
|
@@ -42,56 +42,48 @@ function resolveFallbackActivationDtype(primaryActivationDtype) {
|
|
|
42
42
|
function resolveFallbackKernelPath(primaryKernelPath) {
|
|
43
43
|
const primaryKernelPathId = primaryKernelPath?.id ?? null;
|
|
44
44
|
if (!primaryKernelPathId) {
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
};
|
|
45
|
+
throw new Error(
|
|
46
|
+
'[ExecutionPlan] F16 finiteness fallback requires a primary kernel path with a stable id. ' +
|
|
47
|
+
'Add a registered kernelPath id and a finiteness fallback rule.'
|
|
48
|
+
);
|
|
50
49
|
}
|
|
51
50
|
|
|
52
|
-
const
|
|
51
|
+
const explicitFallbackKernelPathId = typeof primaryKernelPath?.finitenessFallbackKernelPathId === 'string'
|
|
52
|
+
&& primaryKernelPath.finitenessFallbackKernelPathId.length > 0
|
|
53
|
+
? primaryKernelPath.finitenessFallbackKernelPathId
|
|
54
|
+
: null;
|
|
53
55
|
|
|
54
|
-
const fallbackKernelPathId = selectRuleValue(
|
|
56
|
+
const fallbackKernelPathId = explicitFallbackKernelPathId ?? selectRuleValue(
|
|
55
57
|
'inference',
|
|
56
58
|
'kernelPath',
|
|
57
59
|
'finitenessFallback',
|
|
58
60
|
{ kernelPathId: primaryKernelPathId }
|
|
59
61
|
);
|
|
60
62
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
63
|
+
if (typeof fallbackKernelPathId !== 'string' || fallbackKernelPathId.length === 0) {
|
|
64
|
+
throw new Error(
|
|
65
|
+
`[ExecutionPlan] Missing finiteness fallback kernel path mapping for "${primaryKernelPathId}". ` +
|
|
66
|
+
'Add an explicit rule in src/rules/inference/kernel-path.rules.json.'
|
|
67
|
+
);
|
|
68
|
+
}
|
|
65
69
|
|
|
66
|
-
if (
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
`
|
|
70
|
+
if (fallbackKernelPathId === primaryKernelPathId) {
|
|
71
|
+
throw new Error(
|
|
72
|
+
`[ExecutionPlan] Invalid finiteness fallback mapping for "${primaryKernelPathId}": ` +
|
|
73
|
+
`fallback kernel path resolves to itself. Add an explicit widening path.`
|
|
70
74
|
);
|
|
71
75
|
}
|
|
72
76
|
|
|
73
77
|
try {
|
|
74
|
-
const kernelPath = resolveKernelPath(
|
|
78
|
+
const kernelPath = resolveKernelPath(fallbackKernelPathId);
|
|
75
79
|
return {
|
|
76
80
|
kernelPath,
|
|
77
|
-
kernelPathId:
|
|
78
|
-
kernelPathSource,
|
|
81
|
+
kernelPathId: fallbackKernelPathId,
|
|
82
|
+
kernelPathSource: 'rule',
|
|
79
83
|
};
|
|
80
84
|
} catch (error) {
|
|
81
|
-
if (primaryKernelPathIsObject) {
|
|
82
|
-
log.warn(
|
|
83
|
-
'Pipeline',
|
|
84
|
-
`[ExecutionPlan] Failed to resolve finiteness fallback kernel path "${resolvedKernelPathId}" ` +
|
|
85
|
-
`for "${primaryKernelPathId}", using inline kernel path as fallback. ${error?.message || error}`
|
|
86
|
-
);
|
|
87
|
-
return {
|
|
88
|
-
kernelPath: primaryKernelPath,
|
|
89
|
-
kernelPathId: primaryKernelPathId,
|
|
90
|
-
kernelPathSource,
|
|
91
|
-
};
|
|
92
|
-
}
|
|
93
85
|
throw new Error(
|
|
94
|
-
`[ExecutionPlan] Failed to resolve finiteness fallback kernel path "${
|
|
86
|
+
`[ExecutionPlan] Failed to resolve finiteness fallback kernel path "${fallbackKernelPathId}" ` +
|
|
95
87
|
`(from "${primaryKernelPathId}"): ${error?.message || error}`
|
|
96
88
|
);
|
|
97
89
|
}
|
|
@@ -7,6 +7,7 @@ import {
|
|
|
7
7
|
resolveExecutionV0KVIO,
|
|
8
8
|
resolveExecutionV0Precision,
|
|
9
9
|
} from '../../../config/execution-v0-contract-check.js';
|
|
10
|
+
import { selectRuleValue } from '../../../rules/rule-registry.js';
|
|
10
11
|
import {
|
|
11
12
|
EXECUTION_V0_SCHEMA_ID,
|
|
12
13
|
DEFAULT_EXECUTION_V0_POLICIES,
|
|
@@ -856,7 +857,7 @@ function assertInlineKernelPathSessionCompatibility(path, sessionDefaults) {
|
|
|
856
857
|
}
|
|
857
858
|
}
|
|
858
859
|
|
|
859
|
-
function buildInlineKernelPath(steps, sessionDefaults, modelId, numLayers) {
|
|
860
|
+
function buildInlineKernelPath(steps, sessionDefaults, modelId, numLayers, finitenessFallbackKernelPathId = null) {
|
|
860
861
|
const activationDtype = normalizeDtype(
|
|
861
862
|
sessionDefaults?.compute?.defaults?.activationDtype ?? 'f16',
|
|
862
863
|
'sessionDefaults.compute.defaults.activationDtype'
|
|
@@ -877,6 +878,9 @@ function buildInlineKernelPath(steps, sessionDefaults, modelId, numLayers) {
|
|
|
877
878
|
description: 'Generated from manifest.inference.execution.steps',
|
|
878
879
|
activationDtype,
|
|
879
880
|
kvDtype,
|
|
881
|
+
...(typeof finitenessFallbackKernelPathId === 'string' && finitenessFallbackKernelPathId.length > 0
|
|
882
|
+
? { finitenessFallbackKernelPathId }
|
|
883
|
+
: {}),
|
|
880
884
|
decode: {
|
|
881
885
|
steps: decodeSteps.length > 0 ? decodeSteps : prefillSteps,
|
|
882
886
|
},
|
|
@@ -1107,7 +1111,26 @@ export function compileExecutionV0(options = {}) {
|
|
|
1107
1111
|
...resolvedDecodeSteps.filter((step) => step.phase === 'decode'),
|
|
1108
1112
|
];
|
|
1109
1113
|
|
|
1110
|
-
const
|
|
1114
|
+
const defaultKernelPathId = typeof manifestInference.defaultKernelPath === 'string'
|
|
1115
|
+
&& manifestInference.defaultKernelPath.trim().length > 0
|
|
1116
|
+
? manifestInference.defaultKernelPath.trim()
|
|
1117
|
+
: null;
|
|
1118
|
+
const finitenessFallbackKernelPathId = defaultKernelPathId
|
|
1119
|
+
? selectRuleValue(
|
|
1120
|
+
'inference',
|
|
1121
|
+
'kernelPath',
|
|
1122
|
+
'finitenessFallback',
|
|
1123
|
+
{ kernelPathId: defaultKernelPathId }
|
|
1124
|
+
)
|
|
1125
|
+
: null;
|
|
1126
|
+
|
|
1127
|
+
const kernelPath = buildInlineKernelPath(
|
|
1128
|
+
patchedSteps,
|
|
1129
|
+
resolvedSession,
|
|
1130
|
+
modelId,
|
|
1131
|
+
numLayers,
|
|
1132
|
+
finitenessFallbackKernelPathId
|
|
1133
|
+
);
|
|
1111
1134
|
const layerPipeline = buildLayerPipelineFromExecution(resolvedSteps);
|
|
1112
1135
|
const sessionPatch = buildSessionRuntimePatch(resolvedSession);
|
|
1113
1136
|
const modelOverrides = buildModelRuntimeOverrides(manifestInference);
|
|
@@ -1162,6 +1185,10 @@ export function applyExecutionV0RuntimeConfig(options = {}) {
|
|
|
1162
1185
|
}
|
|
1163
1186
|
|
|
1164
1187
|
const runtimeInferencePatch = { ...executionV0State.runtimeInferencePatch };
|
|
1188
|
+
if (runtimeInference.kernelPath !== undefined) {
|
|
1189
|
+
delete runtimeInferencePatch.kernelPath;
|
|
1190
|
+
delete runtimeInferencePatch.kernelPathSource;
|
|
1191
|
+
}
|
|
1165
1192
|
if (runtimeInferencePatch.modelOverrides) {
|
|
1166
1193
|
runtimeInferencePatch.modelOverrides = mergeRuntimeValues(
|
|
1167
1194
|
runtimeInferencePatch.modelOverrides,
|
|
@@ -42,6 +42,7 @@ export async function processFFNStandard(
|
|
|
42
42
|
hiddenSize,
|
|
43
43
|
probes: context.debugProbes,
|
|
44
44
|
recorder,
|
|
45
|
+
dtype: normedTensor.dtype,
|
|
45
46
|
});
|
|
46
47
|
|
|
47
48
|
// 2. FFN
|
|
@@ -58,6 +59,7 @@ export async function processFFNStandard(
|
|
|
58
59
|
hiddenSize,
|
|
59
60
|
probes: context.debugProbes,
|
|
60
61
|
recorder,
|
|
62
|
+
dtype: ffnOutput.dtype,
|
|
61
63
|
});
|
|
62
64
|
|
|
63
65
|
// 3. Residual add
|
|
@@ -72,6 +74,7 @@ export async function processFFNStandard(
|
|
|
72
74
|
hiddenSize,
|
|
73
75
|
probes: context.debugProbes,
|
|
74
76
|
recorder,
|
|
77
|
+
dtype: output.dtype,
|
|
75
78
|
});
|
|
76
79
|
|
|
77
80
|
if (normedTensor !== postAttn) {
|
|
@@ -71,9 +71,13 @@ export interface PipelineContexts {
|
|
|
71
71
|
*/
|
|
72
72
|
export interface RoPEConfig {
|
|
73
73
|
headDim: number;
|
|
74
|
+
rotaryDim?: number;
|
|
74
75
|
maxSeqLen: number;
|
|
75
76
|
ropeTheta: number;
|
|
76
77
|
ropeLocalTheta?: number | null;
|
|
78
|
+
mropeInterleaved?: boolean;
|
|
79
|
+
mropeSection?: number[] | null;
|
|
80
|
+
partialRotaryFactor?: number | null;
|
|
77
81
|
ropeScale: number;
|
|
78
82
|
ropeLocalScale?: number;
|
|
79
83
|
ropeScalingType?: string | null;
|
|
@@ -206,13 +206,45 @@ function isSameRoPEScalingConfig(
|
|
|
206
206
|
=== (rightScaling?.original_max_position_embeddings ?? null);
|
|
207
207
|
}
|
|
208
208
|
|
|
209
|
+
/**
 * Resolve how many dimensions of each attention head receive rotary embedding.
 *
 * Precedence:
 *   1. an explicit `rotaryDim`,
 *   2. `partialRotaryFactor` applied to `headDim` (truncated toward zero),
 *   3. the full `headDim`.
 *
 * @param {number} headDim - Per-head dimension; must be a positive finite number.
 * @param {number|null|undefined} rotaryDim - Explicit rotary dimension override.
 * @param {number|null|undefined} partialRotaryFactor - Fraction of `headDim` in (0, 1].
 * @returns {number} The resolved rotary dimension.
 * @throws {Error} If `headDim` is not a positive finite number, if `rotaryDim` is not a
 *   positive even integer or exceeds `headDim`, if `partialRotaryFactor` is outside (0, 1],
 *   or if the factor resolves to a non-positive or odd dimension.
 */
function resolveRotaryDim(headDim, rotaryDim, partialRotaryFactor) {
  // Fail fast: without this an undefined/NaN headDim slips past the
  // `rotaryDim > headDim` comparison and can be returned as the rotary dim.
  if (!Number.isFinite(headDim) || headDim <= 0) {
    throw new Error(`RoPE headDim must be a positive finite number; got "${headDim}".`);
  }

  if (rotaryDim != null) {
    // The parity check also rejects fractional values (e.g. 2.5 % 2 === 0.5).
    if (!Number.isFinite(rotaryDim) || rotaryDim <= 0 || (rotaryDim % 2) !== 0) {
      throw new Error(`RoPE rotary dim must be a positive even integer; got "${rotaryDim}".`);
    }
    if (rotaryDim > headDim) {
      throw new Error(`RoPE rotary dim ${rotaryDim} cannot exceed headDim ${headDim}.`);
    }
    return rotaryDim;
  }

  if (partialRotaryFactor == null) {
    return headDim;
  }

  if (!Number.isFinite(partialRotaryFactor) || partialRotaryFactor <= 0 || partialRotaryFactor > 1) {
    throw new Error(
      `RoPE partialRotaryFactor must be a number in (0, 1]; got "${partialRotaryFactor}".`
    );
  }

  const resolved = Math.trunc(headDim * partialRotaryFactor);
  if (resolved <= 0 || (resolved % 2) !== 0) {
    throw new Error(
      `RoPE partialRotaryFactor=${partialRotaryFactor} with headDim=${headDim} resolves ` +
      `to rotaryDim=${resolved}, but rotaryDim must be a positive even integer.`
    );
  }
  return resolved;
}
|
|
236
|
+
|
|
209
237
|
|
|
210
238
|
export async function initRoPEFrequencies(config, useGPU) {
|
|
211
239
|
const {
|
|
212
240
|
headDim,
|
|
241
|
+
rotaryDim,
|
|
213
242
|
maxSeqLen,
|
|
214
243
|
ropeTheta,
|
|
215
244
|
ropeLocalTheta,
|
|
245
|
+
mropeInterleaved,
|
|
246
|
+
mropeSection,
|
|
247
|
+
partialRotaryFactor,
|
|
216
248
|
ropeScale,
|
|
217
249
|
ropeLocalScale,
|
|
218
250
|
ropeScalingType,
|
|
@@ -230,14 +262,23 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
230
262
|
const resolvedLocalTheta = ropeLocalTheta ?? ropeTheta;
|
|
231
263
|
const resolvedLocalScalingType = ropeLocalScalingType ?? ropeScalingType;
|
|
232
264
|
const resolvedLocalScaling = ropeLocalScaling ?? ropeScaling;
|
|
265
|
+
const resolvedRotaryDim = resolveRotaryDim(headDim, rotaryDim, partialRotaryFactor);
|
|
266
|
+
const halfDim = resolvedRotaryDim / 2;
|
|
267
|
+
if (mropeInterleaved === true && Array.isArray(mropeSection)) {
|
|
268
|
+
const expandedDim = mropeSection.reduce((sum, entry) => sum + entry, 0) * 2;
|
|
269
|
+
if (expandedDim !== resolvedRotaryDim) {
|
|
270
|
+
throw new Error(
|
|
271
|
+
`RoPE mropeSection expands to ${expandedDim} dims, but rotaryDim is ${resolvedRotaryDim}.`
|
|
272
|
+
);
|
|
273
|
+
}
|
|
274
|
+
}
|
|
233
275
|
|
|
234
|
-
const halfDim = headDim / 2;
|
|
235
276
|
const isYarn = ropeScalingType === 'yarn';
|
|
236
277
|
const isLocalYarn = resolvedLocalScalingType === 'yarn';
|
|
237
278
|
|
|
238
279
|
// Compute global (full_attention) frequencies
|
|
239
280
|
const globalFreqs = computeRoPEFreqsForTheta(
|
|
240
|
-
ropeTheta,
|
|
281
|
+
ropeTheta, resolvedRotaryDim, maxSeqLen, ropeScale, ropeScalingType, ropeScaling
|
|
241
282
|
);
|
|
242
283
|
|
|
243
284
|
// Compute local (sliding_attention) frequencies if different from global.
|
|
@@ -256,7 +297,7 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
256
297
|
if (hasDistinctLocalTheta || hasDistinctLocalScaling) {
|
|
257
298
|
localFreqs = computeRoPEFreqsForTheta(
|
|
258
299
|
resolvedLocalTheta,
|
|
259
|
-
|
|
300
|
+
resolvedRotaryDim,
|
|
260
301
|
maxSeqLen,
|
|
261
302
|
resolvedLocalScale,
|
|
262
303
|
resolvedLocalScalingType,
|
|
@@ -303,9 +344,10 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
303
344
|
|
|
304
345
|
log.debug(
|
|
305
346
|
'Pipeline',
|
|
306
|
-
`RoPE frequencies initialized (GPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, ` +
|
|
347
|
+
`RoPE frequencies initialized (GPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, rotaryDim=${resolvedRotaryDim}, ` +
|
|
307
348
|
`theta=${ropeTheta}${hasDistinctLocalTheta ? `, localTheta=${resolvedLocalTheta}` : ''}, ` +
|
|
308
|
-
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}`
|
|
349
|
+
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}, ` +
|
|
350
|
+
`interleaved=${mropeInterleaved === true}`
|
|
309
351
|
);
|
|
310
352
|
|
|
311
353
|
return {
|
|
@@ -318,9 +360,10 @@ export async function initRoPEFrequencies(config, useGPU) {
|
|
|
318
360
|
|
|
319
361
|
log.debug(
|
|
320
362
|
'Pipeline',
|
|
321
|
-
`RoPE frequencies initialized (CPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, ` +
|
|
363
|
+
`RoPE frequencies initialized (CPU): ${maxSeqLen} positions, dim=${halfDim}, headDim=${headDim}, rotaryDim=${resolvedRotaryDim}, ` +
|
|
322
364
|
`theta=${ropeTheta}${hasDistinctLocalTheta ? `, localTheta=${resolvedLocalTheta}` : ''}, ` +
|
|
323
|
-
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}`
|
|
365
|
+
`scaling=${ropeScalingType ?? 'none'}:${ropeScale}${hasDistinctLocalScaling ? `, localScaling=${resolvedLocalScalingType ?? 'none'}:${resolvedLocalScale}` : ''}, ` +
|
|
366
|
+
`interleaved=${mropeInterleaved === true}`
|
|
324
367
|
);
|
|
325
368
|
|
|
326
369
|
return {
|
|
@@ -688,6 +731,10 @@ function applyChatMLTemplate(prompt) {
|
|
|
688
731
|
return `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n`;
|
|
689
732
|
}
|
|
690
733
|
|
|
734
|
+
/**
 * Wrap a single user prompt in the Qwen chat format.
 *
 * The layout is ChatML-style (`<|im_start|>` / `<|im_end|>` turns); the
 * assistant turn is opened with an already-closed, empty `<think>` block
 * followed by a blank line.
 *
 * @param {string} prompt - Raw user prompt; interpolated verbatim, not escaped.
 * @returns {string} The formatted prompt, ending inside the assistant turn.
 */
function applyQwenTemplate(prompt) {
  return `<|im_start|>user\n${prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n`;
}
|
|
737
|
+
|
|
691
738
|
function applyTranslateGemmaTemplate() {
|
|
692
739
|
throw new Error(
|
|
693
740
|
'TranslateGemma template requires structured messages. ' +
|
|
@@ -702,7 +749,7 @@ const PROMPT_TEMPLATES = {
|
|
|
702
749
|
'llama3': applyHeaderBasedTemplate,
|
|
703
750
|
'gpt-oss': applyChannelBasedTemplate,
|
|
704
751
|
'chatml': applyChatMLTemplate,
|
|
705
|
-
'qwen':
|
|
752
|
+
'qwen': applyQwenTemplate,
|
|
706
753
|
'translategemma': applyTranslateGemmaTemplate,
|
|
707
754
|
};
|
|
708
755
|
|
|
@@ -721,7 +768,7 @@ export function applyChatTemplate(prompt, templateType) {
|
|
|
721
768
|
export const applyGemmaChatTemplate = applyTurnBasedTemplate;
|
|
722
769
|
export const applyLlama3ChatTemplate = applyHeaderBasedTemplate;
|
|
723
770
|
export const applyGptOssChatTemplate = applyChannelBasedTemplate;
|
|
724
|
-
export const applyQwenChatTemplate =
|
|
771
|
+
export const applyQwenChatTemplate = applyQwenTemplate;
|
|
725
772
|
|
|
726
773
|
|
|
727
774
|
export function isStopToken(token, stopTokenIds, eosTokenId) {
|
|
@@ -259,6 +259,8 @@ export async function processLayerGPU(layerIdx, inputBuffer, numTokens, isPrefil
|
|
|
259
259
|
attentionOutputGate: config.attentionOutputGate,
|
|
260
260
|
causalAttention: config.causalAttention,
|
|
261
261
|
rmsNormWeightOffset: config.rmsNormWeightOffset,
|
|
262
|
+
ropeRotaryDim: config.ropeRotaryDim,
|
|
263
|
+
ropeInterleaved: config.ropeInterleaved,
|
|
262
264
|
tokenIds: context.currentTokenIds ?? null,
|
|
263
265
|
kernelPath: context.kernelPath ?? null,
|
|
264
266
|
disableRoPE,
|
|
@@ -661,6 +663,8 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
661
663
|
attentionOutputGate: config.attentionOutputGate,
|
|
662
664
|
causalAttention: config.causalAttention,
|
|
663
665
|
rmsNormWeightOffset: config.rmsNormWeightOffset,
|
|
666
|
+
ropeRotaryDim: config.ropeRotaryDim,
|
|
667
|
+
ropeInterleaved: config.ropeInterleaved,
|
|
664
668
|
tokenIds: context.currentTokenIds ?? null,
|
|
665
669
|
skipInputNorm: step.skipInputNorm === true,
|
|
666
670
|
activationDtype,
|
|
@@ -690,6 +694,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
690
694
|
hiddenSize,
|
|
691
695
|
probes: context.debugProbes,
|
|
692
696
|
recorder,
|
|
697
|
+
dtype: outputDtype,
|
|
693
698
|
});
|
|
694
699
|
}
|
|
695
700
|
break;
|
|
@@ -733,6 +738,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
733
738
|
hiddenSize,
|
|
734
739
|
probes: context.debugProbes,
|
|
735
740
|
recorder,
|
|
741
|
+
dtype: outputDtype,
|
|
736
742
|
});
|
|
737
743
|
}
|
|
738
744
|
break;
|
|
@@ -767,6 +773,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
767
773
|
hiddenSize,
|
|
768
774
|
probes: context.debugProbes,
|
|
769
775
|
recorder,
|
|
776
|
+
dtype: outputDtype,
|
|
770
777
|
});
|
|
771
778
|
}
|
|
772
779
|
break;
|
|
@@ -801,6 +808,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
801
808
|
hiddenSize,
|
|
802
809
|
probes: context.debugProbes,
|
|
803
810
|
recorder,
|
|
811
|
+
dtype: outputDtype,
|
|
804
812
|
});
|
|
805
813
|
}
|
|
806
814
|
break;
|
|
@@ -825,6 +833,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
825
833
|
hiddenSize,
|
|
826
834
|
probes: context.debugProbes,
|
|
827
835
|
recorder,
|
|
836
|
+
dtype: outputDtype,
|
|
828
837
|
});
|
|
829
838
|
}
|
|
830
839
|
break;
|
|
@@ -851,6 +860,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
851
860
|
hiddenSize,
|
|
852
861
|
probes: context.debugProbes,
|
|
853
862
|
recorder,
|
|
863
|
+
dtype: toDtype,
|
|
854
864
|
});
|
|
855
865
|
}
|
|
856
866
|
break;
|
|
@@ -880,6 +890,7 @@ async function processLayerPlanGPU(layerIdx, inputBuffer, numTokens, isPrefill,
|
|
|
880
890
|
hiddenSize,
|
|
881
891
|
probes: context.debugProbes,
|
|
882
892
|
recorder,
|
|
893
|
+
dtype: getSlotDtype('state') ?? activationDtype,
|
|
883
894
|
});
|
|
884
895
|
|
|
885
896
|
const computeConfig = context.runtimeComputeConfig ?? null;
|
|
@@ -299,9 +299,13 @@ export class InferencePipeline extends PipelineState {
|
|
|
299
299
|
const maxSeqLen = config.maxSeqLen;
|
|
300
300
|
const ropeBuffers = await initRoPEFrequencies({
|
|
301
301
|
headDim: config.headDim,
|
|
302
|
+
rotaryDim: config.ropeRotaryDim,
|
|
302
303
|
maxSeqLen,
|
|
303
304
|
ropeTheta: config.ropeTheta,
|
|
304
305
|
ropeLocalTheta: config.ropeLocalTheta,
|
|
306
|
+
mropeInterleaved: config.ropeInterleaved,
|
|
307
|
+
mropeSection: config.mropeSection,
|
|
308
|
+
partialRotaryFactor: config.partialRotaryFactor,
|
|
305
309
|
ropeScale: config.ropeScale,
|
|
306
310
|
ropeLocalScale: config.ropeLocalScale,
|
|
307
311
|
ropeScalingType: config.ropeScalingType,
|
|
@@ -64,6 +64,68 @@ function resolveSpecialTokens(specialTokensRaw, fallbackTokens, vocab) {
|
|
|
64
64
|
return resolved;
|
|
65
65
|
}
|
|
66
66
|
|
|
67
|
+
/**
 * Inspect a tokenizer.json `pre_tokenizer` entry and report whether it uses
 * ByteLevel pre-tokenization, either directly or nested inside a `Sequence`.
 *
 * @param {object|null|undefined} preTokenizer - The `pre_tokenizer` node (or a nested entry).
 * @returns {{useByteLevel: boolean, addPrefixSpace: (boolean|null)}}
 *   `useByteLevel` is true when a ByteLevel entry is found. `addPrefixSpace` is
 *   that entry's `add_prefix_space === true`, or `null` when no ByteLevel entry
 *   exists, so callers can fall through with `??`.
 */
function resolveByteLevelPretokenizerConfig(preTokenizer) {
  if (!preTokenizer || typeof preTokenizer !== 'object') {
    return { useByteLevel: false, addPrefixSpace: null };
  }

  const { type } = preTokenizer;

  if (type === 'ByteLevel') {
    return {
      useByteLevel: true,
      addPrefixSpace: preTokenizer.add_prefix_space === true,
    };
  }

  if (type === 'Sequence' && Array.isArray(preTokenizer.pretokenizers)) {
    // Depth-first search; the first ByteLevel entry encountered wins.
    for (const entry of preTokenizer.pretokenizers) {
      const nested = resolveByteLevelPretokenizerConfig(entry);
      if (nested.useByteLevel) {
        return nested;
      }
    }
  }

  return { useByteLevel: false, addPrefixSpace: null };
}
|
|
96
|
+
|
|
97
|
+
/**
 * Register tokenizer.json `added_tokens` entries into the vocabulary tables.
 *
 * For every entry with a usable id and a non-empty string `content`:
 *   - adds it to `vocab` / `reverseVocab` unless `content` is already in `vocab`
 *     (an existing mapping is never overwritten);
 *   - pushes `{ content, id }` to `patterns` when `content` is longer than one character;
 *   - when `token.special` is truthy, adds the id to `specialTokenIds` and, if
 *     `derivedSpecialTokens` is supplied, fills at most one still-empty
 *     bos/eos/pad/unk slot using name heuristics.
 *
 * Entries with a missing or unparseable id, or with `content` that is not a
 * non-empty string, are skipped. A non-array `addedTokens` registers nothing.
 *
 * @param {Array<object>} addedTokens - Entries of shape `{ content, id, special? }`.
 * @param {Map<string, number>} vocab - Token text to id; mutated.
 * @param {Map<number, string>} reverseVocab - Id to token text; mutated.
 * @param {Array<{content: string, id: number}>} patterns - Multi-character token list; mutated.
 * @param {Set<number>} specialTokenIds - Ids flagged special; mutated.
 * @param {{bos: ?number, eos: ?number, pad: ?number, unk: ?number}|null} [derivedSpecialTokens=null]
 *   Optional slots filled from special-token names; mutated.
 * @returns {number} The largest id seen among accepted entries, or -1 if none.
 */
function registerAddedTokens(addedTokens, vocab, reverseVocab, patterns, specialTokenIds, derivedSpecialTokens = null) {
  let maxId = -1;
  if (!Array.isArray(addedTokens)) {
    return maxId;
  }

  for (const token of addedTokens) {
    const content = token?.content;
    const id = typeof token?.id === 'number' ? token.id : Number.parseInt(token?.id, 10);

    // Non-string content would be stored under a non-string vocab key and then
    // crash on `content.includes(...)` below, so reject it up front.
    if (!Number.isFinite(id) || typeof content !== 'string' || content.length === 0) {
      continue;
    }

    if (!vocab.has(content)) {
      vocab.set(content, id);
      reverseVocab.set(id, content);
    }
    if (id > maxId) {
      maxId = id;
    }
    if (content.length > 1) {
      patterns.push({ content, id });
    }

    if (!token.special) {
      continue;
    }
    specialTokenIds.add(id);
    if (!derivedSpecialTokens) {
      continue;
    }

    // Each special token fills at most one slot, checked in bos, eos, pad, unk
    // order; an already-filled slot is never replaced.
    if (derivedSpecialTokens.bos == null && (content === '<bos>' || content === '<s>' || content.includes('bos'))) {
      derivedSpecialTokens.bos = id;
    } else if (derivedSpecialTokens.eos == null && (content === '<eos>' || content === '</s>' || content.includes('eos'))) {
      derivedSpecialTokens.eos = id;
    } else if (derivedSpecialTokens.pad == null && (content === '<pad>' || content.includes('pad'))) {
      derivedSpecialTokens.pad = id;
    } else if (derivedSpecialTokens.unk == null && (content === '<unk>' || content.includes('unk'))) {
      derivedSpecialTokens.unk = id;
    }
  }

  return maxId;
}
|
|
128
|
+
|
|
67
129
|
|
|
68
130
|
export class TransformersTokenizer extends BaseTokenizer {
|
|
69
131
|
|
|
@@ -156,6 +218,10 @@ export class BundledTokenizer extends BaseTokenizer {
|
|
|
156
218
|
|
|
157
219
|
#byteDecoder = null;
|
|
158
220
|
|
|
221
|
+
#byteEncoder = null;
|
|
222
|
+
|
|
223
|
+
#useByteLevelEncoding = false;
|
|
224
|
+
|
|
159
225
|
|
|
160
226
|
constructor(config = {}) {
|
|
161
227
|
// BundledTokenizer gets vocabSize from load(), so defer validation
|
|
@@ -199,9 +265,20 @@ export class BundledTokenizer extends BaseTokenizer {
|
|
|
199
265
|
}
|
|
200
266
|
|
|
201
267
|
this.#byteDecoder = new Map();
|
|
268
|
+
this.#byteEncoder = new Map();
|
|
202
269
|
for (let i = 0; i < base.length; i++) {
|
|
203
270
|
this.#byteDecoder.set(String.fromCodePoint(chars[i]), base[i]);
|
|
271
|
+
this.#byteEncoder.set(base[i], String.fromCodePoint(chars[i]));
|
|
272
|
+
}
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
#encodeByteLevelText(text) {
|
|
276
|
+
const bytes = new TextEncoder().encode(text);
|
|
277
|
+
let out = '';
|
|
278
|
+
for (const byte of bytes) {
|
|
279
|
+
out += this.#byteEncoder?.get(byte) ?? String.fromCharCode(byte);
|
|
204
280
|
}
|
|
281
|
+
return out;
|
|
205
282
|
}
|
|
206
283
|
|
|
207
284
|
|
|
@@ -290,30 +367,16 @@ export class BundledTokenizer extends BaseTokenizer {
|
|
|
290
367
|
eos: null,
|
|
291
368
|
unk: null,
|
|
292
369
|
};
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
specialTokenIds.add(id);
|
|
304
|
-
if (content.length > 1) {
|
|
305
|
-
specialTokenPatterns.push({ content, id });
|
|
306
|
-
}
|
|
307
|
-
if (derivedSpecialTokens.bos == null && (content === '<bos>' || content === '<s>' || content.includes('bos'))) {
|
|
308
|
-
derivedSpecialTokens.bos = id;
|
|
309
|
-
} else if (derivedSpecialTokens.eos == null && (content === '<eos>' || content === '</s>' || content.includes('eos'))) {
|
|
310
|
-
derivedSpecialTokens.eos = id;
|
|
311
|
-
} else if (derivedSpecialTokens.pad == null && (content === '<pad>' || content.includes('pad'))) {
|
|
312
|
-
derivedSpecialTokens.pad = id;
|
|
313
|
-
} else if (derivedSpecialTokens.unk == null && (content === '<unk>' || content.includes('unk'))) {
|
|
314
|
-
derivedSpecialTokens.unk = id;
|
|
315
|
-
}
|
|
316
|
-
}
|
|
370
|
+
const addedMaxId = registerAddedTokens(
|
|
371
|
+
addedTokens,
|
|
372
|
+
this.#vocab,
|
|
373
|
+
this.#reverseVocab,
|
|
374
|
+
specialTokenPatterns,
|
|
375
|
+
specialTokenIds,
|
|
376
|
+
derivedSpecialTokens
|
|
377
|
+
);
|
|
378
|
+
if (addedMaxId > maxId) {
|
|
379
|
+
maxId = addedMaxId;
|
|
317
380
|
}
|
|
318
381
|
|
|
319
382
|
const specialTokensRaw = hf.special_tokens_map || hf.specialTokens || hf.special_tokens || null;
|
|
@@ -351,6 +414,7 @@ export class BundledTokenizer extends BaseTokenizer {
|
|
|
351
414
|
|
|
352
415
|
// Handle behavior flags (use HF config if present, else runtime defaults)
|
|
353
416
|
const runtimeDefaults = getRuntimeConfig().inference.tokenizer;
|
|
417
|
+
const byteLevelPretokenizer = resolveByteLevelPretokenizerConfig(hf.pre_tokenizer);
|
|
354
418
|
const configuredAddBosToken = this.addBosToken;
|
|
355
419
|
const configuredAddEosToken = this.addEosToken;
|
|
356
420
|
this.addBosToken =
|
|
@@ -378,9 +442,16 @@ export class BundledTokenizer extends BaseTokenizer {
|
|
|
378
442
|
// - runtime config addSpacePrefix (user override or null for auto-detect)
|
|
379
443
|
const decoderPrepend = hf.decoder?.prepend_scheme === 'always' || hf.decoder?.add_prefix_space === true;
|
|
380
444
|
const normalizerPrepend = hf.normalizer?.prepend_scheme === 'always' || hf.normalizer?.add_prefix_space === true;
|
|
445
|
+
this.#useByteLevelEncoding = byteLevelPretokenizer.useByteLevel;
|
|
381
446
|
const runtimeSpacePrefix = runtimeDefaults.addSpacePrefix;
|
|
382
447
|
// Use explicit runtime config if set (non-null), otherwise auto-detect from tokenizer.json
|
|
383
|
-
this.#addSpacePrefix = runtimeSpacePrefix
|
|
448
|
+
this.#addSpacePrefix = runtimeSpacePrefix
|
|
449
|
+
?? byteLevelPretokenizer.addPrefixSpace
|
|
450
|
+
?? model.add_prefix_space
|
|
451
|
+
?? model.add_dummy_prefix
|
|
452
|
+
?? decoderPrepend
|
|
453
|
+
?? normalizerPrepend
|
|
454
|
+
?? false;
|
|
384
455
|
log.debug('Tokenizer', `addSpacePrefix=${this.#addSpacePrefix} (runtime=${runtimeSpacePrefix}, model=${model.add_prefix_space ?? model.add_dummy_prefix}, decoder=${decoderPrepend}, normalizer=${normalizerPrepend})`);
|
|
385
456
|
|
|
386
457
|
// Detect space prefix style by checking which WORD tokens exist in vocab
|
|
@@ -469,11 +540,47 @@ export class BundledTokenizer extends BaseTokenizer {
|
|
|
469
540
|
this.#tokenTypes = tokenizerJson.tokenTypes;
|
|
470
541
|
}
|
|
471
542
|
|
|
543
|
+
let maxId = -1;
|
|
544
|
+
for (const id of this.#vocab.values()) {
|
|
545
|
+
if (Number.isFinite(id) && id > maxId) {
|
|
546
|
+
maxId = id;
|
|
547
|
+
}
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
const addedTokens = Array.isArray(tokenizerJson.added_tokens) ? tokenizerJson.added_tokens : [];
|
|
551
|
+
const tokenPatterns = [];
|
|
552
|
+
const specialTokenIds = new Set();
|
|
553
|
+
const derivedSpecialTokens = {
|
|
554
|
+
pad: null,
|
|
555
|
+
bos: null,
|
|
556
|
+
eos: null,
|
|
557
|
+
unk: null,
|
|
558
|
+
};
|
|
559
|
+
const addedMaxId = registerAddedTokens(
|
|
560
|
+
addedTokens,
|
|
561
|
+
this.#vocab,
|
|
562
|
+
this.#reverseVocab,
|
|
563
|
+
tokenPatterns,
|
|
564
|
+
specialTokenIds,
|
|
565
|
+
derivedSpecialTokens
|
|
566
|
+
);
|
|
567
|
+
if (addedMaxId > maxId) {
|
|
568
|
+
maxId = addedMaxId;
|
|
569
|
+
}
|
|
570
|
+
|
|
472
571
|
// Set special tokens - support both camelCase and snake_case formats
|
|
473
572
|
const specialTokensRaw = (tokenizerJson.specialTokens || (tokenizerJson).special_tokens);
|
|
474
|
-
this.specialTokens = resolveSpecialTokens(
|
|
573
|
+
this.specialTokens = resolveSpecialTokens(
|
|
574
|
+
specialTokensRaw,
|
|
575
|
+
{
|
|
576
|
+
...derivedSpecialTokens,
|
|
577
|
+
...this.specialTokens,
|
|
578
|
+
},
|
|
579
|
+
this.#vocab
|
|
580
|
+
);
|
|
475
581
|
log.debug('Tokenizer', `Special tokens: BOS=${this.specialTokens.bos}, EOS=${this.specialTokens.eos}`);
|
|
476
|
-
this.#specialTokenIds =
|
|
582
|
+
this.#specialTokenIds = specialTokenIds;
|
|
583
|
+
this.#specialTokenPatterns = tokenPatterns;
|
|
477
584
|
const builtinSpecials = [
|
|
478
585
|
this.specialTokens.pad,
|
|
479
586
|
this.specialTokens.bos,
|
|
@@ -485,8 +592,13 @@ export class BundledTokenizer extends BaseTokenizer {
|
|
|
485
592
|
this.#specialTokenIds.add(id);
|
|
486
593
|
}
|
|
487
594
|
}
|
|
595
|
+
this.#specialTokenPatterns.sort((a, b) => b.content.length - a.content.length);
|
|
596
|
+
if (maxId >= 0) {
|
|
597
|
+
this.vocabSize = Math.max(this.vocabSize, maxId + 1);
|
|
598
|
+
}
|
|
488
599
|
|
|
489
600
|
const runtimeDefaults = getRuntimeConfig().inference.tokenizer;
|
|
601
|
+
const byteLevelPretokenizer = resolveByteLevelPretokenizerConfig(tokenizerJson.pre_tokenizer);
|
|
490
602
|
const configuredAddBosToken = this.addBosToken;
|
|
491
603
|
const configuredAddEosToken = this.addEosToken;
|
|
492
604
|
this.addBosToken =
|
|
@@ -505,9 +617,11 @@ export class BundledTokenizer extends BaseTokenizer {
|
|
|
505
617
|
if (this.addEosToken && this.specialTokens.eos == null) {
|
|
506
618
|
throw new Error('[Tokenizer] addEosToken is enabled but eos token is missing.');
|
|
507
619
|
}
|
|
620
|
+
this.#useByteLevelEncoding = byteLevelPretokenizer.useByteLevel;
|
|
508
621
|
// NOTE: Default to FALSE - first word shouldn't get space prefix
|
|
509
622
|
// Space prefixes are only for words that follow a space in original text
|
|
510
|
-
this.#addSpacePrefix = tokenizerJson.addSpacePrefix === true
|
|
623
|
+
this.#addSpacePrefix = tokenizerJson.addSpacePrefix === true
|
|
624
|
+
|| byteLevelPretokenizer.addPrefixSpace === true;
|
|
511
625
|
|
|
512
626
|
// Detect space prefix style based on vocab tokens
|
|
513
627
|
// GPT-style uses 'Ġ' (U+0120), SentencePiece uses '▁' (U+2581)
|
|
@@ -548,7 +662,8 @@ export class BundledTokenizer extends BaseTokenizer {
|
|
|
548
662
|
ids.push(this.specialTokens.bos);
|
|
549
663
|
}
|
|
550
664
|
|
|
551
|
-
// Split text around
|
|
665
|
+
// Split text around literal added tokens and special tokens, then tokenize
|
|
666
|
+
// the remaining plain-text segments normally.
|
|
552
667
|
const segments = this.#splitOnSpecialTokens(text);
|
|
553
668
|
for (const seg of segments) {
|
|
554
669
|
if (seg.isSpecial && seg.id !== undefined) {
|
|
@@ -690,11 +805,19 @@ export class BundledTokenizer extends BaseTokenizer {
|
|
|
690
805
|
if (text.length === 0) return [];
|
|
691
806
|
|
|
692
807
|
let normalized = text;
|
|
693
|
-
|
|
694
|
-
|
|
808
|
+
let prefixed;
|
|
809
|
+
if (this.#useByteLevelEncoding) {
|
|
810
|
+
if (this.#addSpacePrefix && !normalized.startsWith(' ')) {
|
|
811
|
+
normalized = ` ${normalized}`;
|
|
812
|
+
}
|
|
813
|
+
prefixed = this.#encodeByteLevelText(normalized);
|
|
814
|
+
} else {
|
|
815
|
+
if (this.#addSpacePrefix && !normalized.startsWith(' ')) {
|
|
816
|
+
normalized = ` ${normalized}`;
|
|
817
|
+
}
|
|
818
|
+
const sp = this.#spacePrefixChar;
|
|
819
|
+
prefixed = normalized.replace(/ /g, sp);
|
|
695
820
|
}
|
|
696
|
-
const sp = this.#spacePrefixChar;
|
|
697
|
-
const prefixed = normalized.replace(/ /g, sp);
|
|
698
821
|
|
|
699
822
|
if (this.#mergeRanks.size === 0) {
|
|
700
823
|
return this.#encodeBPEGreedy(prefixed);
|