@simulatte/doppler 0.1.6 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +16 -23
- package/package.json +14 -1
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +7 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +12 -2
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +2 -1
- package/src/config/schema/manifest.schema.js +16 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +58 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +57 -41
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +62 -8
- package/src/inference/pipelines/text/attention/run.js +62 -8
- package/src/inference/pipelines/text/config.js +3 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +41 -19
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.js +78 -20
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +3 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +44 -25
|
@@ -1,33 +1,29 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
|
|
5
|
-
const values = new Uint32Array(1);
|
|
6
|
-
crypto.getRandomValues(values);
|
|
7
|
-
return values[0] / 4294967296;
|
|
1
|
+
function requireRandomSource(random) {
|
|
2
|
+
if (typeof random !== 'function') {
|
|
3
|
+
throw new Error('network evolution requires an explicit random() source.');
|
|
8
4
|
}
|
|
9
|
-
|
|
10
|
-
return fallbackRandomState / 4294967296;
|
|
5
|
+
return random;
|
|
11
6
|
}
|
|
12
7
|
|
|
13
|
-
export const mutateGenome = (genome, mutationRate = 0.1) => {
|
|
8
|
+
export const mutateGenome = (genome, mutationRate = 0.1, random = null) => {
|
|
9
|
+
const sample = requireRandomSource(random);
|
|
14
10
|
|
|
15
11
|
const mutated = JSON.parse(JSON.stringify(genome));
|
|
16
|
-
if (
|
|
12
|
+
if (sample() < mutationRate) {
|
|
17
13
|
|
|
18
14
|
const types = ['chain', 'tree', 'mesh', 'dag'];
|
|
19
|
-
mutated.topology.type = types[Math.floor(
|
|
15
|
+
mutated.topology.type = types[Math.floor(sample() * types.length)];
|
|
20
16
|
}
|
|
21
17
|
|
|
22
18
|
for (const node of mutated.nodes) {
|
|
23
|
-
if (
|
|
24
|
-
node.temperature = Math.min(1, Math.max(0, node.temperature + (
|
|
19
|
+
if (sample() < mutationRate && typeof node.temperature === 'number') {
|
|
20
|
+
node.temperature = Math.min(1, Math.max(0, node.temperature + (sample() - 0.5) * 0.2));
|
|
25
21
|
}
|
|
26
22
|
}
|
|
27
23
|
|
|
28
24
|
for (const edge of mutated.edges) {
|
|
29
|
-
if (
|
|
30
|
-
edge.weight = Math.min(1, Math.max(0, edge.weight + (
|
|
25
|
+
if (sample() < mutationRate) {
|
|
26
|
+
edge.weight = Math.min(1, Math.max(0, edge.weight + (sample() - 0.5) * 0.4));
|
|
31
27
|
}
|
|
32
28
|
}
|
|
33
29
|
|
|
@@ -35,8 +31,9 @@ export const mutateGenome = (genome, mutationRate = 0.1) => {
|
|
|
35
31
|
};
|
|
36
32
|
|
|
37
33
|
|
|
38
|
-
export const crossoverGenome = (a, b) => {
|
|
39
|
-
|
|
34
|
+
export const crossoverGenome = (a, b, random = null) => {
|
|
35
|
+
const sample = requireRandomSource(random);
|
|
36
|
+
return sample() < 0.5 ? JSON.parse(JSON.stringify(a)) : JSON.parse(JSON.stringify(b));
|
|
40
37
|
};
|
|
41
38
|
|
|
42
39
|
|
|
@@ -48,7 +45,9 @@ export async function evolveNetwork(config) {
|
|
|
48
45
|
mutationRate = 0.1,
|
|
49
46
|
evaluate,
|
|
50
47
|
randomGenome,
|
|
48
|
+
random,
|
|
51
49
|
} = config;
|
|
50
|
+
const sample = requireRandomSource(random);
|
|
52
51
|
|
|
53
52
|
let population = Array.from({ length: populationSize }, () => randomGenome());
|
|
54
53
|
|
|
@@ -63,9 +62,9 @@ export async function evolveNetwork(config) {
|
|
|
63
62
|
const offspring = [];
|
|
64
63
|
|
|
65
64
|
while (offspring.length < populationSize - eliteCount) {
|
|
66
|
-
const parentA = scored[Math.floor(
|
|
67
|
-
const parentB = scored[Math.floor(
|
|
68
|
-
const child = mutateGenome(crossoverGenome(parentA, parentB), mutationRate);
|
|
65
|
+
const parentA = scored[Math.floor(sample() * scored.length)].genome;
|
|
66
|
+
const parentB = scored[Math.floor(sample() * scored.length)].genome;
|
|
67
|
+
const child = mutateGenome(crossoverGenome(parentA, parentB, sample), mutationRate, sample);
|
|
69
68
|
offspring.push(child);
|
|
70
69
|
}
|
|
71
70
|
|
|
@@ -8,6 +8,8 @@ export type PipelineContextOptions = {
|
|
|
8
8
|
assignProgress?: boolean;
|
|
9
9
|
};
|
|
10
10
|
|
|
11
|
+
export declare function restorePipelineContexts(target: Record<string, unknown>): boolean;
|
|
12
|
+
|
|
11
13
|
export declare function applyPipelineContexts(
|
|
12
14
|
target: Record<string, unknown>,
|
|
13
15
|
contexts?: Record<string, unknown>,
|
|
@@ -15,4 +17,5 @@ export declare function applyPipelineContexts(
|
|
|
15
17
|
): {
|
|
16
18
|
runtimeConfig: Record<string, unknown>;
|
|
17
19
|
sharedDebug: Record<string, unknown> | null | undefined;
|
|
20
|
+
restore: () => void;
|
|
18
21
|
};
|
|
@@ -1,8 +1,115 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import {
|
|
2
|
+
getDevice,
|
|
3
|
+
getKernelCapabilities,
|
|
4
|
+
getPlatformConfig,
|
|
5
|
+
setDevice,
|
|
6
|
+
} from '../../gpu/device.js';
|
|
2
7
|
import { applyDebugConfig, setGPUDevice } from '../../debug/index.js';
|
|
3
8
|
import { getRuntimeConfig, setRuntimeConfig } from '../../config/runtime.js';
|
|
9
|
+
import {
|
|
10
|
+
getLogLevel,
|
|
11
|
+
getTrace,
|
|
12
|
+
isSilentMode,
|
|
13
|
+
setLogLevel,
|
|
14
|
+
setSilentMode,
|
|
15
|
+
setTrace,
|
|
16
|
+
} from '../../debug/config.js';
|
|
17
|
+
import {
|
|
18
|
+
gpuDevice as debugGpuDevice,
|
|
19
|
+
traceBreakOnAnomaly,
|
|
20
|
+
traceLayerFilter,
|
|
21
|
+
traceMaxDecodeSteps,
|
|
22
|
+
} from '../../debug/config.js';
|
|
23
|
+
|
|
24
|
+
const RESTORE_PIPELINE_CONTEXTS = Symbol('restorePipelineContexts');
|
|
25
|
+
|
|
26
|
+
function captureTargetField(target, key) {
|
|
27
|
+
return {
|
|
28
|
+
present: Object.prototype.hasOwnProperty.call(target, key),
|
|
29
|
+
value: target[key],
|
|
30
|
+
};
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
function restoreTargetField(target, key, snapshot) {
|
|
34
|
+
if (snapshot.present) {
|
|
35
|
+
target[key] = snapshot.value;
|
|
36
|
+
return;
|
|
37
|
+
}
|
|
38
|
+
delete target[key];
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
function captureDebugState() {
|
|
42
|
+
return {
|
|
43
|
+
logLevel: getLogLevel(),
|
|
44
|
+
traceCategories: getTrace(),
|
|
45
|
+
traceLayers: [...traceLayerFilter],
|
|
46
|
+
traceMaxDecodeSteps,
|
|
47
|
+
traceBreakOnAnomaly,
|
|
48
|
+
silentMode: isSilentMode(),
|
|
49
|
+
gpuDevice: debugGpuDevice,
|
|
50
|
+
};
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
function restoreDebugState(snapshot) {
|
|
54
|
+
if (snapshot.silentMode !== isSilentMode()) {
|
|
55
|
+
setSilentMode(snapshot.silentMode);
|
|
56
|
+
}
|
|
57
|
+
if (getLogLevel() !== snapshot.logLevel) {
|
|
58
|
+
setLogLevel(snapshot.logLevel);
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
const traceCategories = getTrace();
|
|
62
|
+
const traceChanged = traceCategories.length !== snapshot.traceCategories.length
|
|
63
|
+
|| traceCategories.some((category, idx) => category !== snapshot.traceCategories[idx])
|
|
64
|
+
|| traceLayerFilter.length !== snapshot.traceLayers.length
|
|
65
|
+
|| traceLayerFilter.some((layer, idx) => layer !== snapshot.traceLayers[idx])
|
|
66
|
+
|| traceMaxDecodeSteps !== snapshot.traceMaxDecodeSteps
|
|
67
|
+
|| traceBreakOnAnomaly !== snapshot.traceBreakOnAnomaly;
|
|
68
|
+
|
|
69
|
+
if (traceChanged) {
|
|
70
|
+
if (snapshot.traceCategories.length > 0) {
|
|
71
|
+
setTrace(snapshot.traceCategories.join(','), {
|
|
72
|
+
layers: snapshot.traceLayers.length > 0 ? snapshot.traceLayers : undefined,
|
|
73
|
+
maxDecodeSteps: snapshot.traceMaxDecodeSteps > 0 ? snapshot.traceMaxDecodeSteps : undefined,
|
|
74
|
+
breakOnAnomaly: snapshot.traceBreakOnAnomaly,
|
|
75
|
+
});
|
|
76
|
+
} else {
|
|
77
|
+
setTrace(false);
|
|
78
|
+
}
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
setGPUDevice(snapshot.gpuDevice ?? null);
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
export function restorePipelineContexts(target) {
|
|
85
|
+
const restore = target?.[RESTORE_PIPELINE_CONTEXTS];
|
|
86
|
+
if (typeof restore !== 'function') {
|
|
87
|
+
return false;
|
|
88
|
+
}
|
|
89
|
+
delete target[RESTORE_PIPELINE_CONTEXTS];
|
|
90
|
+
restore();
|
|
91
|
+
return true;
|
|
92
|
+
}
|
|
4
93
|
|
|
5
94
|
export function applyPipelineContexts(target, contexts = {}, options = {}) {
|
|
95
|
+
restorePipelineContexts(target);
|
|
96
|
+
|
|
97
|
+
const previousRuntimeConfig = getRuntimeConfig();
|
|
98
|
+
const previousDevice = getDevice();
|
|
99
|
+
const previousPlatformConfig = getPlatformConfig();
|
|
100
|
+
const previousAdapterInfo = previousDevice
|
|
101
|
+
? (getKernelCapabilities().adapterInfo ?? null)
|
|
102
|
+
: null;
|
|
103
|
+
const previousDebugState = captureDebugState();
|
|
104
|
+
const targetSnapshot = {
|
|
105
|
+
gpuContext: captureTargetField(target, 'gpuContext'),
|
|
106
|
+
useGPU: captureTargetField(target, 'useGPU'),
|
|
107
|
+
memoryContext: captureTargetField(target, 'memoryContext'),
|
|
108
|
+
storageContext: captureTargetField(target, 'storageContext'),
|
|
109
|
+
baseUrl: captureTargetField(target, 'baseUrl'),
|
|
110
|
+
_onProgress: captureTargetField(target, '_onProgress'),
|
|
111
|
+
};
|
|
112
|
+
|
|
6
113
|
const runtimeConfig = contexts.runtimeConfig
|
|
7
114
|
? setRuntimeConfig(contexts.runtimeConfig)
|
|
8
115
|
: getRuntimeConfig();
|
|
@@ -40,5 +147,38 @@ export function applyPipelineContexts(target, contexts = {}, options = {}) {
|
|
|
40
147
|
target._onProgress = contexts.onProgress;
|
|
41
148
|
}
|
|
42
149
|
|
|
43
|
-
|
|
150
|
+
let restored = false;
|
|
151
|
+
const restore = () => {
|
|
152
|
+
if (restored) {
|
|
153
|
+
return;
|
|
154
|
+
}
|
|
155
|
+
restored = true;
|
|
156
|
+
delete target[RESTORE_PIPELINE_CONTEXTS];
|
|
157
|
+
|
|
158
|
+
setRuntimeConfig(previousRuntimeConfig);
|
|
159
|
+
if (previousDevice) {
|
|
160
|
+
setDevice(previousDevice, {
|
|
161
|
+
platformConfig: previousPlatformConfig,
|
|
162
|
+
adapterInfo: previousAdapterInfo,
|
|
163
|
+
});
|
|
164
|
+
} else {
|
|
165
|
+
setDevice(null);
|
|
166
|
+
}
|
|
167
|
+
restoreDebugState(previousDebugState);
|
|
168
|
+
restoreTargetField(target, 'gpuContext', targetSnapshot.gpuContext);
|
|
169
|
+
restoreTargetField(target, 'useGPU', targetSnapshot.useGPU);
|
|
170
|
+
restoreTargetField(target, 'memoryContext', targetSnapshot.memoryContext);
|
|
171
|
+
restoreTargetField(target, 'storageContext', targetSnapshot.storageContext);
|
|
172
|
+
restoreTargetField(target, 'baseUrl', targetSnapshot.baseUrl);
|
|
173
|
+
restoreTargetField(target, '_onProgress', targetSnapshot._onProgress);
|
|
174
|
+
};
|
|
175
|
+
|
|
176
|
+
Object.defineProperty(target, RESTORE_PIPELINE_CONTEXTS, {
|
|
177
|
+
value: restore,
|
|
178
|
+
configurable: true,
|
|
179
|
+
enumerable: false,
|
|
180
|
+
writable: false,
|
|
181
|
+
});
|
|
182
|
+
|
|
183
|
+
return { runtimeConfig, sharedDebug, restore };
|
|
44
184
|
}
|
|
@@ -54,8 +54,13 @@ export function createDiffusionIndexBuffer(device, indices, label) {
|
|
|
54
54
|
size: indices.byteLength,
|
|
55
55
|
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
|
|
56
56
|
});
|
|
57
|
-
|
|
58
|
-
|
|
57
|
+
try {
|
|
58
|
+
device.queue.writeBuffer(buffer, 0, indices);
|
|
59
|
+
return buffer;
|
|
60
|
+
} catch (error) {
|
|
61
|
+
buffer.destroy();
|
|
62
|
+
throw error;
|
|
63
|
+
}
|
|
59
64
|
}
|
|
60
65
|
|
|
61
66
|
export function expectDiffusionWeight(weight, label) {
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { getDevice, getKernelCapabilities } from '../../../gpu/device.js';
|
|
2
2
|
import { log, trace } from '../../../debug/index.js';
|
|
3
3
|
import { registerPipeline } from '../registry.js';
|
|
4
|
-
import { applyPipelineContexts } from '../context.js';
|
|
4
|
+
import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
|
|
5
5
|
import { createInitializedPipeline } from '../factory.js';
|
|
6
6
|
import { createRng, sampleNormal } from '../rng.js';
|
|
7
7
|
import { initializeDiffusion } from './init.js';
|
|
@@ -319,6 +319,7 @@ export class DiffusionPipeline {
|
|
|
319
319
|
this.vaeWeights = null;
|
|
320
320
|
this.textEncoderWeights = null;
|
|
321
321
|
this.transformerWeights = null;
|
|
322
|
+
restorePipelineContexts(this);
|
|
322
323
|
}
|
|
323
324
|
|
|
324
325
|
async ensureVaeWeights() {
|
|
@@ -299,26 +299,26 @@ function resolveModulationSegments(weight, hiddenSize, fallbackSegments, resolve
|
|
|
299
299
|
if (Number.isInteger(segments) && segments > 0) {
|
|
300
300
|
return segments;
|
|
301
301
|
}
|
|
302
|
-
|
|
303
|
-
'
|
|
304
|
-
`
|
|
302
|
+
throw new Error(
|
|
303
|
+
`Modulation segments mismatch for ${name || 'unknown'}: rows=${rows}, hidden=${hiddenSize}, ` +
|
|
304
|
+
`expected an integer multiple instead of falling back to ${fallbackSegments}.`
|
|
305
305
|
);
|
|
306
306
|
}
|
|
307
|
-
|
|
307
|
+
throw new Error(
|
|
308
|
+
`Modulation tensor "${name || 'unknown'}" is missing shape metadata. ` +
|
|
309
|
+
`Runtime cannot fall back to ${fallbackSegments} segments.`
|
|
310
|
+
);
|
|
308
311
|
}
|
|
309
312
|
|
|
310
313
|
function resolveModulationOffsets(segments, hiddenSize) {
|
|
311
|
-
if (segments
|
|
314
|
+
if (segments === 9) {
|
|
312
315
|
return {
|
|
313
316
|
attn: { scale: 0, shift: hiddenSize, gate: hiddenSize * 2 },
|
|
314
317
|
attn2: { scale: hiddenSize * 3, shift: hiddenSize * 4, gate: hiddenSize * 5 },
|
|
315
318
|
ff: { scale: hiddenSize * 6, shift: hiddenSize * 7, gate: hiddenSize * 8 },
|
|
316
319
|
};
|
|
317
320
|
}
|
|
318
|
-
if (segments
|
|
319
|
-
if (segments !== 6) {
|
|
320
|
-
log.warn('Diffusion', `Unexpected modulation segment count=${segments}; using 6-segment layout.`);
|
|
321
|
-
}
|
|
321
|
+
if (segments === 6) {
|
|
322
322
|
const attn = { scale: 0, shift: hiddenSize, gate: hiddenSize * 2 };
|
|
323
323
|
return {
|
|
324
324
|
attn,
|
|
@@ -326,7 +326,7 @@ function resolveModulationOffsets(segments, hiddenSize) {
|
|
|
326
326
|
ff: { scale: hiddenSize * 3, shift: hiddenSize * 4, gate: hiddenSize * 5 },
|
|
327
327
|
};
|
|
328
328
|
}
|
|
329
|
-
throw new Error(`Unsupported modulation segments=${segments} (expected
|
|
329
|
+
throw new Error(`Unsupported modulation segments=${segments} (expected 6 or 9).`);
|
|
330
330
|
}
|
|
331
331
|
|
|
332
332
|
async function buildModulation(timeText, weight, bias, hiddenSize, segments, runtime, matmul, weightName, ops) {
|
|
@@ -118,13 +118,9 @@ function resolveAttentionHeadShape(channels, config) {
|
|
|
118
118
|
headDim: channels / configuredNumHeads,
|
|
119
119
|
};
|
|
120
120
|
}
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
return {
|
|
125
|
-
numHeads: Math.max(1, channels / headDim),
|
|
126
|
-
headDim,
|
|
127
|
-
};
|
|
121
|
+
throw new Error(
|
|
122
|
+
`VAE attention requires explicit compatible attention_head_dim or num_attention_heads for channels=${channels}.`
|
|
123
|
+
);
|
|
128
124
|
}
|
|
129
125
|
|
|
130
126
|
function createBiasTensor(weight, label, fallbackDtype = 'f16') {
|
|
@@ -16,10 +16,10 @@ import { log, trace } from '../../../debug/index.js';
|
|
|
16
16
|
import { DEFAULT_ENERGY_CONFIG } from '../../../config/schema/energy.schema.js';
|
|
17
17
|
import { f32ToF16Array, f16ToF32Array } from '../../kv-cache/types.js';
|
|
18
18
|
import { registerPipeline } from '../registry.js';
|
|
19
|
-
import { applyPipelineContexts } from '../context.js';
|
|
19
|
+
import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
|
|
20
20
|
import { createInitializedPipeline } from '../factory.js';
|
|
21
21
|
import { createRng, sampleNormal } from '../rng.js';
|
|
22
|
-
import { mergeQuintelConfig, runQuintelEnergyLoop } from './quintel.js';
|
|
22
|
+
import { buildQuintelKernelFlags, mergeQuintelConfig, runQuintelEnergyLoop } from './quintel.js';
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
function generateRandomArray(count, mode, seed, scale) {
|
|
@@ -140,24 +140,28 @@ async function createEnergyTensor(device, data, dtype, shape, label) {
|
|
|
140
140
|
const byteLength = data.byteLength;
|
|
141
141
|
const alignedSize = Math.ceil(byteLength / 4) * 4;
|
|
142
142
|
const buffer = acquireBuffer(alignedSize, undefined, label);
|
|
143
|
+
try {
|
|
144
|
+
let payload = data;
|
|
145
|
+
if (alignedSize !== byteLength) {
|
|
146
|
+
const padded = new Uint8Array(alignedSize);
|
|
147
|
+
const view = data instanceof ArrayBuffer
|
|
148
|
+
? new Uint8Array(data)
|
|
149
|
+
: new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
150
|
+
padded.set(view);
|
|
151
|
+
payload = padded;
|
|
152
|
+
}
|
|
143
153
|
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
const
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
device.queue.writeBuffer(buffer, 0, payload);
|
|
155
|
-
const tensor = createTensor(buffer, dtype, shape, label);
|
|
156
|
-
const expectedBytes = tensorBytes(shape, dtype);
|
|
157
|
-
if (expectedBytes !== byteLength) {
|
|
158
|
-
log.warn('Energy', `${label} byte length mismatch: expected ${expectedBytes}, got ${byteLength}`);
|
|
154
|
+
device.queue.writeBuffer(buffer, 0, payload);
|
|
155
|
+
const tensor = createTensor(buffer, dtype, shape, label);
|
|
156
|
+
const expectedBytes = tensorBytes(shape, dtype);
|
|
157
|
+
if (expectedBytes !== byteLength) {
|
|
158
|
+
log.warn('Energy', `${label} byte length mismatch: expected ${expectedBytes}, got ${byteLength}`);
|
|
159
|
+
}
|
|
160
|
+
return tensor;
|
|
161
|
+
} catch (error) {
|
|
162
|
+
releaseBuffer(buffer);
|
|
163
|
+
throw error;
|
|
159
164
|
}
|
|
160
|
-
return tensor;
|
|
161
165
|
}
|
|
162
166
|
|
|
163
167
|
async function readTensorToFloat32(tensor) {
|
|
@@ -202,6 +206,7 @@ export class EnergyPipeline {
|
|
|
202
206
|
|
|
203
207
|
async unload() {
|
|
204
208
|
this.manifest = null;
|
|
209
|
+
restorePipelineContexts(this);
|
|
205
210
|
}
|
|
206
211
|
|
|
207
212
|
async generate(request = {}) {
|
|
@@ -336,6 +341,7 @@ export class EnergyPipeline {
|
|
|
336
341
|
const centerWeight = Number.isFinite(weights.center) ? weights.center : 1.0;
|
|
337
342
|
const binarizeWeight = Number.isFinite(weights.binarize) ? weights.binarize : 0.0;
|
|
338
343
|
const centerTarget = Number.isFinite(quintelConfig.centerTarget) ? quintelConfig.centerTarget : 1.0;
|
|
344
|
+
const flags = buildQuintelKernelFlags(rules, binarizeWeight);
|
|
339
345
|
const energyHistory = [];
|
|
340
346
|
const stepTimesMs = [];
|
|
341
347
|
let lastEnergy = null;
|
|
@@ -387,11 +393,11 @@ export class EnergyPipeline {
|
|
|
387
393
|
await runEnergyQuintelReduce(stateTensor, {
|
|
388
394
|
count: elementCount,
|
|
389
395
|
size,
|
|
396
|
+
flags,
|
|
390
397
|
symmetryWeight,
|
|
391
398
|
centerWeight,
|
|
392
399
|
binarizeWeight,
|
|
393
400
|
centerTarget,
|
|
394
|
-
rules,
|
|
395
401
|
outputBuffer: reduceBuffer,
|
|
396
402
|
});
|
|
397
403
|
|
|
@@ -447,13 +453,13 @@ export class EnergyPipeline {
|
|
|
447
453
|
await runEnergyQuintelGrad(stateTensor, {
|
|
448
454
|
count: elementCount,
|
|
449
455
|
size,
|
|
456
|
+
flags,
|
|
450
457
|
countDiff: safeCountDiff,
|
|
451
458
|
symmetryWeight,
|
|
452
459
|
countWeight,
|
|
453
460
|
centerWeight,
|
|
454
461
|
binarizeWeight,
|
|
455
462
|
centerTarget,
|
|
456
|
-
rules,
|
|
457
463
|
outputBuffer: gradBuffer,
|
|
458
464
|
});
|
|
459
465
|
|
|
@@ -471,6 +477,7 @@ export class EnergyPipeline {
|
|
|
471
477
|
await runEnergyQuintelUpdate(stateTensor, {
|
|
472
478
|
count: elementCount,
|
|
473
479
|
size,
|
|
480
|
+
flags,
|
|
474
481
|
stepSize,
|
|
475
482
|
gradientScale,
|
|
476
483
|
countDiff: safeCountDiff,
|
|
@@ -481,7 +488,6 @@ export class EnergyPipeline {
|
|
|
481
488
|
centerTarget,
|
|
482
489
|
clampMin,
|
|
483
490
|
clampMax,
|
|
484
|
-
rules,
|
|
485
491
|
});
|
|
486
492
|
}
|
|
487
493
|
|
|
@@ -84,4 +84,9 @@ export function mergeQuintelConfig(
|
|
|
84
84
|
override?: Partial<QuintelEnergyConfig> | null
|
|
85
85
|
): QuintelEnergyConfig;
|
|
86
86
|
|
|
87
|
+
/**
 * Builds the numeric flag word handed to the Quintel energy kernels.
 *
 * @param rules - Rule toggles to encode; `null`/`undefined` is accepted.
 * @param binarizeWeight - Optional binarize weight that also contributes to the flags.
 * @returns The packed flags as a number.
 */
export function buildQuintelKernelFlags(
  rules: Partial<QuintelRuleConfig> | null | undefined,
  binarizeWeight?: number
): number;
|
|
91
|
+
|
|
87
92
|
export function runQuintelEnergyLoop(options: QuintelEnergyLoopOptions): QuintelEnergyLoopResult;
|
|
@@ -22,6 +22,17 @@ export function mergeQuintelConfig(base, override) {
|
|
|
22
22
|
};
|
|
23
23
|
}
|
|
24
24
|
|
|
25
|
+
/**
 * Bit contributed by each Quintel rule toggle to the packed kernel flag word.
 * Entries are listed in ascending bit order; the values must stay in sync with
 * whatever consumes the flag word.
 */
const QUINTEL_RULE_FLAG_BITS = Object.freeze({
  mirrorX: 1,
  mirrorY: 2,
  diagonal: 4,
  count: 8,
  center: 16,
});

/** Bit set when a finite, non-zero binarize weight is supplied. */
const QUINTEL_BINARIZE_FLAG_BIT = 32;

/**
 * Packs the enabled Quintel rules into a single unsigned bitmask.
 *
 * A rule contributes its bit when the matching property on `rules` is truthy.
 * The binarize bit is set only when `binarizeWeight` is a finite number other
 * than zero, so `undefined`, `NaN`, infinities and numeric strings leave it clear.
 *
 * @param {object|null|undefined} rules - Rule toggles (`mirrorX`, `mirrorY`,
 *   `diagonal`, `count`, `center`); a nullish value yields no rule bits.
 * @param {number} [binarizeWeight] - Weight of the binarize term.
 * @returns {number} Non-negative integer bitmask.
 */
export function buildQuintelKernelFlags(rules, binarizeWeight) {
  let flags = 0;
  for (const [ruleName, bit] of Object.entries(QUINTEL_RULE_FLAG_BITS)) {
    if (rules?.[ruleName]) {
      flags |= bit;
    }
  }
  if (Number.isFinite(binarizeWeight) && binarizeWeight !== 0) {
    flags |= QUINTEL_BINARIZE_FLAG_BIT;
  }
  // Coerce to an unsigned 32-bit value before returning.
  return flags >>> 0;
}
|
|
35
|
+
|
|
25
36
|
function applyPairEnergy(state, gradients, indexA, indexB, weight) {
|
|
26
37
|
const diff = state[indexA] - state[indexB];
|
|
27
38
|
const energy = weight * diff * diff;
|
|
@@ -5,7 +5,7 @@ import { runEnergyEval, runEnergyUpdate } from '../../../gpu/kernels/index.js';
|
|
|
5
5
|
import { log } from '../../../debug/index.js';
|
|
6
6
|
import { f16ToF32Array, f32ToF16Array } from '../../kv-cache/types.js';
|
|
7
7
|
import { registerPipeline } from '../registry.js';
|
|
8
|
-
import { applyPipelineContexts } from '../context.js';
|
|
8
|
+
import { applyPipelineContexts, restorePipelineContexts } from '../context.js';
|
|
9
9
|
import { createInitializedPipeline } from '../factory.js';
|
|
10
10
|
import { selectRuleValue } from '../../../rules/rule-registry.js';
|
|
11
11
|
|
|
@@ -165,19 +165,22 @@ async function createFeatureTensor(device, values, dtype, label) {
|
|
|
165
165
|
const byteLength = payload.byteLength;
|
|
166
166
|
const alignedSize = Math.ceil(byteLength / 4) * 4;
|
|
167
167
|
const buffer = acquireBuffer(alignedSize, undefined, label);
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
168
|
+
try {
|
|
169
|
+
if (alignedSize === byteLength) {
|
|
170
|
+
device.queue.writeBuffer(buffer, 0, payload);
|
|
171
|
+
} else {
|
|
172
|
+
const bytes = payload instanceof Uint16Array
|
|
173
|
+
? new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength)
|
|
174
|
+
: new Uint8Array(payload.buffer, payload.byteOffset, payload.byteLength);
|
|
175
|
+
const padded = new Uint8Array(alignedSize);
|
|
176
|
+
padded.set(bytes);
|
|
177
|
+
device.queue.writeBuffer(buffer, 0, padded);
|
|
178
|
+
}
|
|
179
|
+
return createTensor(buffer, dtype, [values.length], label);
|
|
180
|
+
} catch (error) {
|
|
181
|
+
releaseBuffer(buffer);
|
|
182
|
+
throw error;
|
|
178
183
|
}
|
|
179
|
-
|
|
180
|
-
return createTensor(buffer, dtype, [values.length], label);
|
|
181
184
|
}
|
|
182
185
|
|
|
183
186
|
async function readTensorF32(tensor) {
|
|
@@ -307,6 +310,7 @@ export class EnergyRowHeadPipeline {
|
|
|
307
310
|
this.manifest = null;
|
|
308
311
|
this.model = null;
|
|
309
312
|
this.stats = {};
|
|
313
|
+
restorePipelineContexts(this);
|
|
310
314
|
}
|
|
311
315
|
|
|
312
316
|
async scoreRows(request = {}) {
|
|
@@ -84,20 +84,35 @@ function parseStructuredJSONObject(rawText) {
|
|
|
84
84
|
/**
 * Picks the first finite candidate (runtime override first, then the manifest
 * value) and passes it through `normalize`; yields null when neither is finite.
 */
function pickFiniteStructuredSetting(runtimeValue, modelValue, normalize) {
  if (Number.isFinite(runtimeValue)) {
    return normalize(runtimeValue);
  }
  if (Number.isFinite(modelValue)) {
    return normalize(modelValue);
  }
  return null;
}

/**
 * Resolves the effective structured-JSON-head settings.
 *
 * The manifest must carry an object at `inference.structuredJsonHead`. Values
 * from `runtimeConfig.inference.structuredJsonHead` take precedence whenever
 * they are finite numbers; otherwise the manifest value is used. No built-in
 * default exists, so a setting that is finite in neither place is an error.
 *
 * @param {object} manifest - Model manifest.
 * @param {object} [runtimeConfig] - Optional runtime overrides.
 * @returns {{maxTokens: number, temperature: number, maxOutputChars: number}}
 *   `maxTokens` is floored and at least 1; `maxOutputChars` is floored and at
 *   least 4096; `temperature` is passed through unchanged.
 * @throws {Error} When the manifest section is missing or any of the three
 *   settings cannot be resolved.
 */
function resolveStructuredRuntime(manifest, runtimeConfig) {
  const manifestSection = manifest?.inference?.structuredJsonHead;
  const modelCfg = isObj(manifestSection) ? manifest.inference.structuredJsonHead : null;
  if (!modelCfg) {
    throw new Error('StructuredJsonHeadPipeline: manifest.inference.structuredJsonHead is required.');
  }

  const runtimeSection = runtimeConfig?.inference?.structuredJsonHead;
  const runtimeCfg = isObj(runtimeSection) ? runtimeConfig.inference.structuredJsonHead : {};

  const maxTokens = pickFiniteStructuredSetting(
    runtimeCfg.maxTokens,
    modelCfg.maxTokens,
    (value) => Math.max(1, Math.floor(value))
  );
  const temperature = pickFiniteStructuredSetting(
    runtimeCfg.temperature,
    modelCfg.temperature,
    (value) => Number(value)
  );
  const maxOutputChars = pickFiniteStructuredSetting(
    runtimeCfg.maxOutputChars,
    modelCfg.maxOutputChars,
    (value) => Math.max(4096, Math.floor(value))
  );

  // Report missing settings in the same order as the fields are resolved.
  if (maxTokens === null) {
    throw new Error('StructuredJsonHeadPipeline: structuredJsonHead.maxTokens is required.');
  }
  if (temperature === null) {
    throw new Error('StructuredJsonHeadPipeline: structuredJsonHead.temperature is required.');
  }
  if (maxOutputChars === null) {
    throw new Error('StructuredJsonHeadPipeline: structuredJsonHead.maxOutputChars is required.');
  }

  return { maxTokens, temperature, maxOutputChars };
}
|
|
103
118
|
|