@simulatte/doppler 0.1.6 → 0.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +126 -0
- package/README.md +16 -23
- package/package.json +14 -1
- package/src/adapters/adapter-registry.js +12 -1
- package/src/adapters/lora-loader.js +23 -6
- package/src/bridge/extension-client.d.ts +5 -0
- package/src/bridge/extension-client.js +40 -0
- package/src/bridge/index.d.ts +2 -1
- package/src/bridge/index.js +6 -4
- package/src/browser/browser-converter.js +26 -1
- package/src/browser/file-picker.js +6 -0
- package/src/browser/safetensors-parser-browser.js +84 -1
- package/src/browser/shard-io-browser.js +2 -2
- package/src/browser/tensor-source-download.js +8 -2
- package/src/browser/tensor-source-http.d.ts +1 -0
- package/src/browser/tensor-source-http.js +5 -1
- package/src/client/doppler-api.browser.js +20 -4
- package/src/client/doppler-api.js +19 -3
- package/src/client/doppler-provider/generation.js +12 -0
- package/src/client/doppler-provider/model-manager.d.ts +10 -0
- package/src/client/doppler-provider/model-manager.js +91 -19
- package/src/client/doppler-provider/source-runtime.d.ts +2 -1
- package/src/client/doppler-provider/source-runtime.js +132 -13
- package/src/client/doppler-registry.json +8 -7
- package/src/config/backward-registry-loader.js +17 -2
- package/src/config/execution-v0-contract-check.js +113 -15
- package/src/config/kernel-path-contract-check.js +57 -29
- package/src/config/kernel-path-loader.js +5 -36
- package/src/config/kernels/kernel-ref-digests.js +1 -1
- package/src/config/kernels/registry.js +14 -1
- package/src/config/kernels/registry.json +7 -5
- package/src/config/loader.d.ts +1 -1
- package/src/config/loader.js +12 -2
- package/src/config/merge-contract-check.js +59 -4
- package/src/config/merge-helpers.js +128 -7
- package/src/config/merge.d.ts +1 -0
- package/src/config/merge.js +10 -0
- package/src/config/param-validator.js +47 -2
- package/src/config/presets/kernel-paths/{gemma2-q4k-dequant-f32a.json → gemma2-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/gemma3-f16-fused-f32a-online-streamingprefill.json +223 -0
- package/src/config/presets/kernel-paths/{gemma3-q4k-dequant-f32a.json → gemma3-q4k-dequant-f32a-nosubgroups.json} +3 -3
- package/src/config/presets/kernel-paths/registry.json +29 -8
- package/src/config/presets/models/gemma2.json +2 -2
- package/src/config/presets/runtime/experiments/bench/gemma3-bench-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/debug/gemma3-debug-q4k.json +1 -1
- package/src/config/presets/runtime/experiments/verify/gemma3-verify.json +1 -1
- package/src/config/presets/runtime/kernels/dequant-f16-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/dequant-f32-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/embeddinggemma-q4k-dequant-f32a.json +37 -0
- package/src/config/presets/runtime/kernels/fused-q4k.json +6 -13
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f16a.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-dequant-f32a-nosubgroups.json +33 -0
- package/src/config/presets/runtime/kernels/gemma2-q4k-fused-f32a.json +33 -0
- package/src/config/presets/runtime/kernels/safe-q4k.json +6 -13
- package/src/config/presets/runtime/platform/metal-apple-q4k.json +1 -1
- package/src/config/runtime.js +6 -1
- package/src/config/schema/debug.schema.d.ts +5 -0
- package/src/config/schema/doppler.schema.js +16 -21
- package/src/config/schema/inference-defaults.schema.js +3 -3
- package/src/config/schema/kernel-path.schema.d.ts +5 -1
- package/src/config/schema/kernel-thresholds.schema.js +12 -4
- package/src/config/schema/manifest.schema.d.ts +2 -1
- package/src/config/schema/manifest.schema.js +16 -3
- package/src/config/training-defaults.js +30 -22
- package/src/converter/conversion-plan.js +94 -9
- package/src/converter/core.d.ts +7 -0
- package/src/converter/core.js +14 -9
- package/src/converter/execution-v0-manifest.js +4 -1
- package/src/converter/index.d.ts +1 -0
- package/src/converter/index.js +1 -0
- package/src/converter/manifest-inference.js +43 -12
- package/src/converter/parsers/diffusion.js +0 -3
- package/src/converter/quantization-info.js +35 -15
- package/src/converter/shard-packer.d.ts +1 -1
- package/src/converter/shard-packer.js +4 -1
- package/src/debug/config.js +123 -11
- package/src/debug/signals.js +7 -1
- package/src/debug/tensor.d.ts +2 -0
- package/src/debug/tensor.js +13 -2
- package/src/distribution/p2p-control-plane.js +52 -12
- package/src/distribution/p2p-observability.js +43 -7
- package/src/distribution/p2p-webrtc-browser.js +20 -0
- package/src/distribution/shard-delivery.js +77 -26
- package/src/formats/gguf/types.js +33 -16
- package/src/formats/rdrr/groups.d.ts +12 -4
- package/src/formats/rdrr/groups.js +3 -6
- package/src/formats/rdrr/parsing.js +39 -2
- package/src/formats/rdrr/types.d.ts +2 -1
- package/src/gpu/command-recorder.js +86 -61
- package/src/gpu/device.d.ts +1 -0
- package/src/gpu/device.js +73 -19
- package/src/gpu/kernel-tuner/benchmarks.js +326 -316
- package/src/gpu/kernel-tuner/cache.js +71 -4
- package/src/gpu/kernel-tuner/tuner.js +22 -4
- package/src/gpu/kernels/attention.js +15 -34
- package/src/gpu/kernels/backward/adam.js +62 -58
- package/src/gpu/kernels/backward/attention_backward.js +257 -169
- package/src/gpu/kernels/backward/conv2d_backward.js +14 -1
- package/src/gpu/kernels/cast.js +191 -149
- package/src/gpu/kernels/check-stop.js +33 -44
- package/src/gpu/kernels/conv2d.js +27 -17
- package/src/gpu/kernels/cross_entropy_loss.js +21 -15
- package/src/gpu/kernels/depthwise_conv2d.js +36 -26
- package/src/gpu/kernels/dequant.js +178 -126
- package/src/gpu/kernels/energy.d.ts +3 -21
- package/src/gpu/kernels/energy.js +111 -88
- package/src/gpu/kernels/feature-check.js +1 -1
- package/src/gpu/kernels/fused_ffn.js +84 -65
- package/src/gpu/kernels/fused_matmul_residual.js +56 -33
- package/src/gpu/kernels/fused_matmul_rmsnorm.js +62 -45
- package/src/gpu/kernels/gather.js +33 -15
- package/src/gpu/kernels/gelu.js +19 -11
- package/src/gpu/kernels/grouped_pointwise_conv2d.js +33 -23
- package/src/gpu/kernels/groupnorm.js +34 -23
- package/src/gpu/kernels/kv-quantize.js +5 -2
- package/src/gpu/kernels/layernorm.js +35 -19
- package/src/gpu/kernels/logit-merge.js +5 -3
- package/src/gpu/kernels/matmul.js +58 -39
- package/src/gpu/kernels/modulate.js +23 -15
- package/src/gpu/kernels/moe.js +221 -175
- package/src/gpu/kernels/pixel_shuffle.js +22 -14
- package/src/gpu/kernels/relu.js +18 -10
- package/src/gpu/kernels/repeat_channels.js +25 -17
- package/src/gpu/kernels/residual.js +37 -27
- package/src/gpu/kernels/rmsnorm.js +57 -41
- package/src/gpu/kernels/rope.js +3 -0
- package/src/gpu/kernels/sample.js +27 -38
- package/src/gpu/kernels/sana_linear_attention.js +18 -10
- package/src/gpu/kernels/scale.js +18 -11
- package/src/gpu/kernels/shader-cache.js +4 -2
- package/src/gpu/kernels/silu.js +120 -72
- package/src/gpu/kernels/softmax.js +44 -25
- package/src/gpu/kernels/split_qkv.js +23 -13
- package/src/gpu/kernels/transpose.js +18 -10
- package/src/gpu/kernels/transpose.wgsl +5 -3
- package/src/gpu/kernels/upsample2d.js +21 -13
- package/src/gpu/kernels/utils.js +20 -13
- package/src/gpu/partitioned-buffer-pool.js +10 -2
- package/src/gpu/perf-guards.js +2 -9
- package/src/gpu/profiler.js +27 -22
- package/src/gpu/readback-utils.d.ts +16 -0
- package/src/gpu/readback-utils.js +41 -0
- package/src/gpu/submit-tracker.js +13 -0
- package/src/gpu/uniform-cache.d.ts +1 -0
- package/src/gpu/uniform-cache.js +30 -9
- package/src/hotswap/intent-bundle.js +6 -0
- package/src/hotswap/manifest.d.ts +10 -1
- package/src/hotswap/manifest.js +12 -2
- package/src/hotswap/runtime.js +30 -8
- package/src/index-browser.d.ts +44 -0
- package/src/index-browser.js +14 -0
- package/src/inference/browser-harness-contract-helpers.d.ts +5 -0
- package/src/inference/browser-harness-contract-helpers.js +28 -0
- package/src/inference/browser-harness-diffusion-energy-suites.d.ts +2 -0
- package/src/inference/browser-harness-diffusion-energy-suites.js +269 -0
- package/src/inference/browser-harness-model-helpers.d.ts +16 -0
- package/src/inference/browser-harness-model-helpers.js +217 -0
- package/src/inference/browser-harness-report-helpers.d.ts +7 -0
- package/src/inference/browser-harness-report-helpers.js +42 -0
- package/src/inference/browser-harness-runtime-helpers.d.ts +61 -0
- package/src/inference/browser-harness-runtime-helpers.js +415 -0
- package/src/inference/browser-harness-suite-helpers.d.ts +28 -0
- package/src/inference/browser-harness-suite-helpers.js +268 -0
- package/src/inference/browser-harness-text-helpers.d.ts +27 -0
- package/src/inference/browser-harness-text-helpers.js +788 -0
- package/src/inference/browser-harness.d.ts +6 -0
- package/src/inference/browser-harness.js +130 -1996
- package/src/inference/kv-cache/base.js +140 -94
- package/src/inference/kv-cache/tiered.js +5 -3
- package/src/inference/moe-router.js +88 -56
- package/src/inference/multi-model-network.js +5 -3
- package/src/inference/network-evolution.d.ts +11 -2
- package/src/inference/network-evolution.js +20 -21
- package/src/inference/pipelines/context.d.ts +3 -0
- package/src/inference/pipelines/context.js +142 -2
- package/src/inference/pipelines/diffusion/helpers.js +7 -2
- package/src/inference/pipelines/diffusion/pipeline.js +2 -1
- package/src/inference/pipelines/diffusion/sd3-transformer.js +10 -10
- package/src/inference/pipelines/diffusion/vae.js +3 -7
- package/src/inference/pipelines/energy/pipeline.js +27 -21
- package/src/inference/pipelines/energy/quintel.d.ts +5 -0
- package/src/inference/pipelines/energy/quintel.js +11 -0
- package/src/inference/pipelines/energy-head/row-head-pipeline.js +17 -13
- package/src/inference/pipelines/structured/json-head-pipeline.js +26 -11
- package/src/inference/pipelines/text/attention/projections.js +151 -101
- package/src/inference/pipelines/text/attention/record.js +62 -8
- package/src/inference/pipelines/text/attention/run.js +62 -8
- package/src/inference/pipelines/text/config.js +3 -4
- package/src/inference/pipelines/text/embed.js +2 -8
- package/src/inference/pipelines/text/execution-plan.js +41 -19
- package/src/inference/pipelines/text/execution-v0-contract-helpers.d.ts +59 -0
- package/src/inference/pipelines/text/execution-v0-contract-helpers.js +937 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.d.ts +15 -0
- package/src/inference/pipelines/text/execution-v0-runtime-builders.js +279 -0
- package/src/inference/pipelines/text/execution-v0.js +62 -1013
- package/src/inference/pipelines/text/generator-steps.d.ts +46 -0
- package/src/inference/pipelines/text/generator-steps.js +298 -207
- package/src/inference/pipelines/text/generator.js +6 -23
- package/src/inference/pipelines/text/init.js +78 -20
- package/src/inference/pipelines/text/kernel-path-auto-select.js +2 -0
- package/src/inference/pipelines/text/kernel-trace.d.ts +2 -0
- package/src/inference/pipelines/text/kernel-trace.js +6 -0
- package/src/inference/pipelines/text/layer.js +3 -9
- package/src/inference/pipelines/text/linear-attention.d.ts +10 -0
- package/src/inference/pipelines/text/linear-attention.js +80 -6
- package/src/inference/pipelines/text/logits/gpu.js +10 -5
- package/src/inference/pipelines/text/logits/index.js +10 -11
- package/src/inference/pipelines/text/logits/utils.d.ts +7 -0
- package/src/inference/pipelines/text/logits/utils.js +9 -0
- package/src/inference/pipelines/text/lora-apply.js +50 -32
- package/src/inference/pipelines/text/model-load.js +279 -104
- package/src/inference/pipelines/text/moe-cache.js +5 -4
- package/src/inference/pipelines/text/moe-cpu-gptoss.js +74 -69
- package/src/inference/pipelines/text/moe-cpu.js +42 -38
- package/src/inference/pipelines/text/moe-gpu.js +110 -86
- package/src/inference/pipelines/text/ops.js +90 -90
- package/src/inference/pipelines/text/probes.js +9 -9
- package/src/inference/pipelines/text/weights.js +17 -7
- package/src/inference/pipelines/text.js +13 -1
- package/src/inference/speculative.d.ts +2 -2
- package/src/inference/speculative.js +4 -18
- package/src/inference/test-harness.d.ts +1 -1
- package/src/inference/test-harness.js +15 -5
- package/src/inference/tokenizer.d.ts +0 -5
- package/src/inference/tokenizer.js +4 -23
- package/src/inference/tokenizers/bpe.js +9 -0
- package/src/inference/tokenizers/bundled.js +20 -0
- package/src/inference/tokenizers/sentencepiece.js +12 -0
- package/src/loader/doppler-loader.js +38 -22
- package/src/loader/dtype-utils.js +3 -44
- package/src/loader/embedding-loader.js +7 -3
- package/src/loader/experts/expert-cache.js +13 -6
- package/src/loader/experts/expert-loader.js +10 -6
- package/src/loader/final-weights-loader.js +8 -4
- package/src/loader/layer-loader.js +2 -1
- package/src/loader/loader-state.js +2 -2
- package/src/loader/memory-monitor.js +8 -0
- package/src/loader/multi-model-loader.d.ts +14 -0
- package/src/loader/multi-model-loader.js +70 -24
- package/src/loader/shard-cache.js +81 -12
- package/src/loader/shard-resolver.js +25 -3
- package/src/loader/tensors/tensor-loader.js +209 -144
- package/src/loader/tensors/tensor-reader.js +76 -19
- package/src/loader/weight-downcast.js +1 -1
- package/src/memory/buffer-pool.d.ts +9 -1
- package/src/memory/buffer-pool.js +109 -44
- package/src/memory/unified-detect.js +1 -1
- package/src/rules/inference/kernel-path.rules.json +24 -8
- package/src/rules/rule-registry.js +25 -1
- package/src/storage/backends/opfs-store.js +68 -24
- package/src/storage/downloader.js +364 -83
- package/src/storage/index.d.ts +3 -0
- package/src/storage/index.js +3 -0
- package/src/storage/preflight.d.ts +2 -2
- package/src/storage/preflight.js +24 -2
- package/src/storage/quickstart-downloader.js +11 -5
- package/src/storage/registry.js +10 -4
- package/src/storage/reports.js +1 -1
- package/src/storage/shard-manager.d.ts +15 -1
- package/src/storage/shard-manager.js +51 -3
- package/src/storage/source-artifact-store.d.ts +52 -0
- package/src/storage/source-artifact-store.js +234 -0
- package/src/tooling/command-api-constants.d.ts +9 -0
- package/src/tooling/command-api-constants.js +9 -0
- package/src/tooling/command-api-family-normalizers.d.ts +9 -0
- package/src/tooling/command-api-family-normalizers.js +343 -0
- package/src/tooling/command-api-helpers.d.ts +25 -0
- package/src/tooling/command-api-helpers.js +262 -0
- package/src/tooling/command-api.js +16 -602
- package/src/tooling/command-envelope.js +4 -1
- package/src/tooling/command-runner-shared.js +52 -18
- package/src/tooling/lean-execution-contract.js +150 -3
- package/src/tooling/node-browser-command-runner.js +161 -271
- package/src/tooling/node-command-runner.js +29 -3
- package/src/tooling/node-converter.js +27 -1
- package/src/tooling/node-source-runtime.d.ts +1 -1
- package/src/tooling/node-source-runtime.js +84 -3
- package/src/tooling/node-webgpu.js +24 -21
- package/src/tooling/opfs-cache.js +21 -4
- package/src/tooling/runtime-input-composition.d.ts +38 -0
- package/src/tooling/runtime-input-composition.js +86 -0
- package/src/tooling/source-runtime-bundle.d.ts +40 -5
- package/src/tooling/source-runtime-bundle.js +261 -34
- package/src/tooling/source-runtime-materializer.d.ts +6 -0
- package/src/tooling/source-runtime-materializer.js +93 -0
- package/src/training/attention-backward.js +32 -17
- package/src/training/autograd.js +80 -52
- package/src/training/checkpoint-watch.d.ts +2 -1
- package/src/training/checkpoint-watch.js +39 -6
- package/src/training/checkpoint.js +40 -11
- package/src/training/clip.js +2 -1
- package/src/training/datasets/token-batch.js +20 -8
- package/src/training/distillation/checkpoint-watch.js +1 -0
- package/src/training/distillation/student-fixture.d.ts +22 -0
- package/src/training/distillation/student-fixture.js +846 -0
- package/src/training/distillation/suite-data.d.ts +45 -0
- package/src/training/distillation/suite-data.js +189 -0
- package/src/training/lora-pipeline.js +4 -7
- package/src/training/lora.js +26 -12
- package/src/training/loss.js +5 -6
- package/src/training/objectives/cross_entropy.js +2 -5
- package/src/training/objectives/distill_kd.js +4 -8
- package/src/training/objectives/distill_triplet.js +4 -8
- package/src/training/objectives/ul_stage2_base.js +4 -8
- package/src/training/operator-command.js +2 -0
- package/src/training/optimizer.js +19 -7
- package/src/training/runner.js +2 -1
- package/src/training/suite.js +18 -978
- package/src/training/tensor-factory.d.ts +9 -0
- package/src/training/tensor-factory.js +13 -0
- package/src/training/trainer.js +3 -5
- package/src/training/ul_dataset.js +3 -5
- package/src/training/workloads.js +70 -79
- package/src/version.js +1 -1
- package/tools/convert-safetensors-node.js +22 -16
- package/tools/doppler-cli.js +44 -25
package/src/training/autograd.js
CHANGED
|
@@ -6,6 +6,7 @@ import { acquireBuffer, readBuffer, releaseBuffer, uploadData } from '../memory/
|
|
|
6
6
|
import { createTensor } from '../gpu/tensor.js';
|
|
7
7
|
import { attentionBackwardCpu } from './attention-backward.js';
|
|
8
8
|
import { f16ToF32Array, f32ToF16Array } from '../inference/kv-cache/types.js';
|
|
9
|
+
import { createUploadedTensor } from './tensor-factory.js';
|
|
9
10
|
|
|
10
11
|
export const OpType = {
|
|
11
12
|
EMBED: 'embed',
|
|
@@ -35,6 +36,7 @@ export class AutogradTape {
|
|
|
35
36
|
constructor(registry) {
|
|
36
37
|
this.registry = registry;
|
|
37
38
|
this.records = [];
|
|
39
|
+
this.retainedBuffers = new Set();
|
|
38
40
|
}
|
|
39
41
|
|
|
40
42
|
watch(tensor) {
|
|
@@ -43,6 +45,13 @@ export class AutogradTape {
|
|
|
43
45
|
|
|
44
46
|
async record(op, fn, inputs, options = {}) {
|
|
45
47
|
const output = await fn(...inputs);
|
|
48
|
+
if (Array.isArray(options.retainBuffers)) {
|
|
49
|
+
for (const buffer of options.retainBuffers) {
|
|
50
|
+
if (buffer) {
|
|
51
|
+
this.retainedBuffers.add(buffer);
|
|
52
|
+
}
|
|
53
|
+
}
|
|
54
|
+
}
|
|
46
55
|
this.records.push({ op, inputs, output, options });
|
|
47
56
|
return output;
|
|
48
57
|
}
|
|
@@ -50,31 +59,40 @@ export class AutogradTape {
|
|
|
50
59
|
async backward(gradOutput) {
|
|
51
60
|
const grads = new Map();
|
|
52
61
|
const seeds = this.normalizeBackwardSeeds(gradOutput);
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
for (let i = this.records.length - 1; i >= 0; i -= 1) {
|
|
58
|
-
const record = this.records[i];
|
|
59
|
-
const entry = this.registry.ops[record.op];
|
|
60
|
-
if (!entry) {
|
|
61
|
-
continue;
|
|
62
|
+
try {
|
|
63
|
+
for (const seed of seeds) {
|
|
64
|
+
await this.accumulateGrad(grads, seed.tensor, seed.grad);
|
|
62
65
|
}
|
|
63
66
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
67
|
+
for (let i = this.records.length - 1; i >= 0; i -= 1) {
|
|
68
|
+
const record = this.records[i];
|
|
69
|
+
const entry = this.registry.ops[record.op];
|
|
70
|
+
if (!entry) {
|
|
71
|
+
continue;
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
const gradOut = grads.get(record.output);
|
|
75
|
+
if (!gradOut) {
|
|
76
|
+
continue;
|
|
77
|
+
}
|
|
68
78
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
79
|
+
const gradsOut = await this.runBackward(entry.backward, record, gradOut);
|
|
80
|
+
for (const { input, grad } of gradsOut) {
|
|
81
|
+
if (input && grad) {
|
|
82
|
+
await this.accumulateGrad(grads, input, grad);
|
|
83
|
+
}
|
|
73
84
|
}
|
|
74
85
|
}
|
|
75
|
-
}
|
|
76
86
|
|
|
77
|
-
|
|
87
|
+
return grads;
|
|
88
|
+
} finally {
|
|
89
|
+
for (const buffer of this.retainedBuffers) {
|
|
90
|
+
try {
|
|
91
|
+
releaseBuffer(buffer);
|
|
92
|
+
} catch {}
|
|
93
|
+
}
|
|
94
|
+
this.retainedBuffers.clear();
|
|
95
|
+
}
|
|
78
96
|
}
|
|
79
97
|
|
|
80
98
|
isTensorLike(value) {
|
|
@@ -245,9 +263,7 @@ export class AutogradTape {
|
|
|
245
263
|
expanded.set(gradRow.subarray(0, copyCount), rowOffset);
|
|
246
264
|
const dtype = gradOut.dtype === 'f16' ? 'f16' : 'f32';
|
|
247
265
|
const payload = dtype === 'f16' ? f32ToF16Array(expanded) : expanded;
|
|
248
|
-
|
|
249
|
-
uploadData(outBuffer, payload);
|
|
250
|
-
return createTensor(outBuffer, dtype, [rows, cols], 'row_slice_backward_output');
|
|
266
|
+
return createUploadedTensor(payload, dtype, [rows, cols], 'row_slice_backward_output');
|
|
251
267
|
}
|
|
252
268
|
|
|
253
269
|
resolveSiluRowsplitGate(gateValue, activation) {
|
|
@@ -305,9 +321,7 @@ export class AutogradTape {
|
|
|
305
321
|
|
|
306
322
|
const dtype = gradOut.dtype === 'f16' ? 'f16' : 'f32';
|
|
307
323
|
const payload = dtype === 'f16' ? f32ToF16Array(output) : output;
|
|
308
|
-
|
|
309
|
-
uploadData(outBuffer, payload);
|
|
310
|
-
return createTensor(outBuffer, dtype, [numTokens, dim * 2], 'silu_rowsplit_backward_output');
|
|
324
|
+
return createUploadedTensor(payload, dtype, [numTokens, dim * 2], 'silu_rowsplit_backward_output');
|
|
311
325
|
}
|
|
312
326
|
|
|
313
327
|
async accumulateLargeGradF32(existing, grad, size, shape) {
|
|
@@ -317,35 +331,49 @@ export class AutogradTape {
|
|
|
317
331
|
}
|
|
318
332
|
const bytesPerElement = 4;
|
|
319
333
|
const outputBuffer = acquireBuffer(size * bytesPerElement, undefined, 'grad_accum_large_output');
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
334
|
+
try {
|
|
335
|
+
for (let offset = 0; offset < size; offset += MAX_RESIDUAL_ELEMENTS_PER_DISPATCH) {
|
|
336
|
+
const chunkElements = Math.min(MAX_RESIDUAL_ELEMENTS_PER_DISPATCH, size - offset);
|
|
337
|
+
const chunkBytes = chunkElements * bytesPerElement;
|
|
338
|
+
const chunkOffsetBytes = offset * bytesPerElement;
|
|
339
|
+
|
|
340
|
+
let aChunkBuffer = null;
|
|
341
|
+
let bChunkBuffer = null;
|
|
342
|
+
let summedChunkBuffer = null;
|
|
343
|
+
try {
|
|
344
|
+
aChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_a_chunk');
|
|
345
|
+
bChunkBuffer = acquireBuffer(chunkBytes, undefined, 'grad_accum_large_b_chunk');
|
|
346
|
+
const copyIn = device.createCommandEncoder();
|
|
347
|
+
copyIn.copyBufferToBuffer(existing.buffer, chunkOffsetBytes, aChunkBuffer, 0, chunkBytes);
|
|
348
|
+
copyIn.copyBufferToBuffer(grad.buffer, chunkOffsetBytes, bChunkBuffer, 0, chunkBytes);
|
|
349
|
+
device.queue.submit([copyIn.finish()]);
|
|
350
|
+
|
|
351
|
+
const aChunk = createTensor(aChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_a_tensor');
|
|
352
|
+
const bChunk = createTensor(bChunkBuffer, 'f32', [chunkElements], 'grad_accum_large_b_tensor');
|
|
353
|
+
const summedChunk = await runResidualAdd(aChunk, bChunk, chunkElements);
|
|
354
|
+
summedChunkBuffer = summedChunk?.buffer ?? null;
|
|
355
|
+
|
|
356
|
+
const copyOut = device.createCommandEncoder();
|
|
357
|
+
copyOut.copyBufferToBuffer(summedChunk.buffer, 0, outputBuffer, chunkOffsetBytes, chunkBytes);
|
|
358
|
+
device.queue.submit([copyOut.finish()]);
|
|
359
|
+
} finally {
|
|
360
|
+
if (aChunkBuffer) {
|
|
361
|
+
releaseBuffer(aChunkBuffer);
|
|
362
|
+
}
|
|
363
|
+
if (bChunkBuffer) {
|
|
364
|
+
releaseBuffer(bChunkBuffer);
|
|
365
|
+
}
|
|
366
|
+
if (summedChunkBuffer && summedChunkBuffer !== outputBuffer) {
|
|
367
|
+
releaseBuffer(summedChunkBuffer);
|
|
368
|
+
}
|
|
369
|
+
}
|
|
345
370
|
}
|
|
346
|
-
}
|
|
347
371
|
|
|
348
|
-
|
|
372
|
+
return createTensor(outputBuffer, 'f32', [...shape], 'grad_accum_large_output');
|
|
373
|
+
} catch (error) {
|
|
374
|
+
releaseBuffer(outputBuffer);
|
|
375
|
+
throw error;
|
|
376
|
+
}
|
|
349
377
|
}
|
|
350
378
|
|
|
351
379
|
|
|
@@ -3,5 +3,6 @@ export declare function watchFinalizedCheckpoints(options: {
|
|
|
3
3
|
manifestPath: string;
|
|
4
4
|
pollIntervalMs?: number | null;
|
|
5
5
|
stopWhenIdle?: boolean;
|
|
6
|
+
signal?: AbortSignal | null;
|
|
6
7
|
onCheckpoint: (markerPath: string) => Promise<void> | void;
|
|
7
|
-
}): Promise<{ ok: true; processedCount: number; manifestPath: string }>;
|
|
8
|
+
}): Promise<{ ok: true; processedCount: number; manifestPath: string; aborted?: boolean }>;
|
|
@@ -55,6 +55,36 @@ async function readProcessedManifest(manifestPath) {
|
|
|
55
55
|
}
|
|
56
56
|
}
|
|
57
57
|
|
|
58
|
+
function createWatchResult(processed, manifestPath, aborted = false) {
|
|
59
|
+
return {
|
|
60
|
+
ok: true,
|
|
61
|
+
processedCount: processed.size,
|
|
62
|
+
manifestPath,
|
|
63
|
+
aborted,
|
|
64
|
+
};
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
async function waitForPollInterval(pollIntervalMs, signal) {
|
|
68
|
+
if (!signal) {
|
|
69
|
+
await new Promise((resolvePromise) => setTimeout(resolvePromise, pollIntervalMs));
|
|
70
|
+
return true;
|
|
71
|
+
}
|
|
72
|
+
if (signal.aborted) {
|
|
73
|
+
return false;
|
|
74
|
+
}
|
|
75
|
+
return new Promise((resolvePromise) => {
|
|
76
|
+
const onAbort = () => {
|
|
77
|
+
clearTimeout(timer);
|
|
78
|
+
resolvePromise(false);
|
|
79
|
+
};
|
|
80
|
+
const timer = setTimeout(() => {
|
|
81
|
+
signal.removeEventListener('abort', onAbort);
|
|
82
|
+
resolvePromise(true);
|
|
83
|
+
}, pollIntervalMs);
|
|
84
|
+
signal.addEventListener('abort', onAbort, { once: true });
|
|
85
|
+
});
|
|
86
|
+
}
|
|
87
|
+
|
|
58
88
|
export async function watchFinalizedCheckpoints(options) {
|
|
59
89
|
const checkpointsDir = resolve(String(options.checkpointsDir));
|
|
60
90
|
const manifestPath = resolve(String(options.manifestPath));
|
|
@@ -65,6 +95,7 @@ export async function watchFinalizedCheckpoints(options) {
|
|
|
65
95
|
const onCheckpoint = typeof options.onCheckpoint === 'function'
|
|
66
96
|
? options.onCheckpoint
|
|
67
97
|
: null;
|
|
98
|
+
const signal = options.signal ?? null;
|
|
68
99
|
if (!onCheckpoint) {
|
|
69
100
|
throw new Error('watchFinalizedCheckpoints requires onCheckpoint(markerPath).');
|
|
70
101
|
}
|
|
@@ -72,6 +103,9 @@ export async function watchFinalizedCheckpoints(options) {
|
|
|
72
103
|
const processed = await readProcessedManifest(manifestPath);
|
|
73
104
|
let idlePolls = 0;
|
|
74
105
|
for (;;) {
|
|
106
|
+
if (signal?.aborted) {
|
|
107
|
+
return createWatchResult(processed, manifestPath, true);
|
|
108
|
+
}
|
|
75
109
|
const checkpointsExist = await ensureDirectoryExists(checkpointsDir);
|
|
76
110
|
const markers = checkpointsExist
|
|
77
111
|
? await listCheckpointMarkers(checkpointsDir)
|
|
@@ -92,15 +126,14 @@ export async function watchFinalizedCheckpoints(options) {
|
|
|
92
126
|
if (!sawNewMarker) {
|
|
93
127
|
idlePolls += 1;
|
|
94
128
|
if (stopWhenIdle && idlePolls > 0) {
|
|
95
|
-
return
|
|
96
|
-
ok: true,
|
|
97
|
-
processedCount: processed.size,
|
|
98
|
-
manifestPath,
|
|
99
|
-
};
|
|
129
|
+
return createWatchResult(processed, manifestPath);
|
|
100
130
|
}
|
|
101
131
|
} else {
|
|
102
132
|
idlePolls = 0;
|
|
103
133
|
}
|
|
104
|
-
|
|
134
|
+
const shouldContinue = await waitForPollInterval(pollIntervalMs, signal);
|
|
135
|
+
if (!shouldContinue) {
|
|
136
|
+
return createWatchResult(processed, manifestPath, true);
|
|
137
|
+
}
|
|
105
138
|
}
|
|
106
139
|
}
|
|
@@ -31,6 +31,13 @@ function openCheckpointDB(options = {}) {
|
|
|
31
31
|
});
|
|
32
32
|
}
|
|
33
33
|
|
|
34
|
+
function closeCheckpointDB(db) {
|
|
35
|
+
if (!db || typeof db.close !== 'function') {
|
|
36
|
+
return;
|
|
37
|
+
}
|
|
38
|
+
db.close();
|
|
39
|
+
}
|
|
40
|
+
|
|
34
41
|
async function resolveNodeCheckpointPath(key, options = {}) {
|
|
35
42
|
const [{ resolve, join, dirname }, { mkdir }] = await Promise.all([
|
|
36
43
|
import('node:path'),
|
|
@@ -140,9 +147,15 @@ export async function saveCheckpoint(key, payload, options = {}) {
|
|
|
140
147
|
const useNodeStore = isNodeRuntime() && typeof indexedDB === 'undefined';
|
|
141
148
|
const nodePath = useNodeStore ? await resolveNodeCheckpointPath(key, options) : null;
|
|
142
149
|
const browserStore = useNodeStore ? null : await openCheckpointDB(options);
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
150
|
+
let previousData;
|
|
151
|
+
try {
|
|
152
|
+
previousData = useNodeStore
|
|
153
|
+
? await readNodeCheckpointRecord(nodePath)
|
|
154
|
+
: await readCheckpointRecord(browserStore.db, browserStore.storeName, key);
|
|
155
|
+
} catch (error) {
|
|
156
|
+
closeCheckpointDB(browserStore?.db);
|
|
157
|
+
throw error;
|
|
158
|
+
}
|
|
146
159
|
const previousMetadata = previousData?.metadata || {};
|
|
147
160
|
const previousLineage = previousMetadata.lineage || {};
|
|
148
161
|
const previousCheckpointHash = options.priorCheckpointHash
|
|
@@ -194,13 +207,25 @@ export async function saveCheckpoint(key, payload, options = {}) {
|
|
|
194
207
|
|
|
195
208
|
return new Promise((resolve, reject) => {
|
|
196
209
|
const tx = browserStore.db.transaction(browserStore.storeName, 'readwrite');
|
|
197
|
-
tx.oncomplete = () =>
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
210
|
+
tx.oncomplete = () => {
|
|
211
|
+
closeCheckpointDB(browserStore.db);
|
|
212
|
+
resolve({
|
|
213
|
+
key,
|
|
214
|
+
path: null,
|
|
215
|
+
metadata: data.metadata,
|
|
216
|
+
data,
|
|
217
|
+
});
|
|
218
|
+
};
|
|
219
|
+
tx.onerror = () => {
|
|
220
|
+
const error = tx.error;
|
|
221
|
+
closeCheckpointDB(browserStore.db);
|
|
222
|
+
reject(error);
|
|
223
|
+
};
|
|
224
|
+
tx.onabort = () => {
|
|
225
|
+
const error = tx.error ?? new Error('Checkpoint transaction aborted');
|
|
226
|
+
closeCheckpointDB(browserStore.db);
|
|
227
|
+
reject(error);
|
|
228
|
+
};
|
|
204
229
|
const store = tx.objectStore(browserStore.storeName);
|
|
205
230
|
store.put(data, key);
|
|
206
231
|
});
|
|
@@ -213,7 +238,11 @@ export async function loadCheckpoint(key, options = {}) {
|
|
|
213
238
|
? await readNodeCheckpointRecord(nodePath)
|
|
214
239
|
: await (async () => {
|
|
215
240
|
const { db, storeName } = await openCheckpointDB(options);
|
|
216
|
-
|
|
241
|
+
try {
|
|
242
|
+
return await readCheckpointRecord(db, storeName, key);
|
|
243
|
+
} finally {
|
|
244
|
+
closeCheckpointDB(db);
|
|
245
|
+
}
|
|
217
246
|
})();
|
|
218
247
|
|
|
219
248
|
if (!data || !data.metadata || !options.expectedMetadata) {
|
package/src/training/clip.js
CHANGED
|
@@ -12,7 +12,8 @@ async function readGradData(grad) {
|
|
|
12
12
|
}
|
|
13
13
|
|
|
14
14
|
export async function clipGradients(grads, config) {
|
|
15
|
-
const maxNorm = config?.training?.
|
|
15
|
+
const maxNorm = config?.training?.gradientClipping?.maxNorm
|
|
16
|
+
?? config?.training?.gradient?.maxNorm;
|
|
16
17
|
let sumSq = 0;
|
|
17
18
|
let totalParamCount = 0;
|
|
18
19
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
|
|
2
|
-
import { acquireBuffer, uploadData } from '../../memory/buffer-pool.js';
|
|
2
|
+
import { acquireBuffer, uploadData, releaseBuffer } from '../../memory/buffer-pool.js';
|
|
3
3
|
import { createTensor } from '../../gpu/tensor.js';
|
|
4
4
|
|
|
5
5
|
function flattenTokenBatch(samples, key) {
|
|
@@ -27,14 +27,26 @@ export function buildTokenBatch(samples) {
|
|
|
27
27
|
}
|
|
28
28
|
|
|
29
29
|
export function createTokenBatchTensors(batch) {
|
|
30
|
-
|
|
31
|
-
|
|
30
|
+
let inputBuf = null;
|
|
31
|
+
let targetBuf = null;
|
|
32
|
+
try {
|
|
33
|
+
inputBuf = acquireBuffer(batch.inputFlat.byteLength, undefined, 'train_input_tokens');
|
|
34
|
+
uploadData(inputBuf, batch.inputFlat);
|
|
32
35
|
|
|
33
|
-
|
|
34
|
-
|
|
36
|
+
targetBuf = acquireBuffer(batch.targetFlat.byteLength, undefined, 'train_target_tokens');
|
|
37
|
+
uploadData(targetBuf, batch.targetFlat);
|
|
35
38
|
|
|
36
|
-
|
|
37
|
-
|
|
39
|
+
const input = createTensor(inputBuf, 'f32', [batch.inputFlat.length], 'train_input_tokens');
|
|
40
|
+
const targets = createTensor(targetBuf, 'f32', [batch.targetFlat.length], 'train_target_tokens');
|
|
38
41
|
|
|
39
|
-
|
|
42
|
+
return { input, targets, offsets: batch.offsets };
|
|
43
|
+
} catch (error) {
|
|
44
|
+
if (inputBuf) {
|
|
45
|
+
releaseBuffer(inputBuf);
|
|
46
|
+
}
|
|
47
|
+
if (targetBuf) {
|
|
48
|
+
releaseBuffer(targetBuf);
|
|
49
|
+
}
|
|
50
|
+
throw error;
|
|
51
|
+
}
|
|
40
52
|
}
|
|
@@ -14,6 +14,7 @@ export async function watchDistillationCheckpoints(options) {
|
|
|
14
14
|
manifestPath,
|
|
15
15
|
pollIntervalMs: options.pollIntervalMs || 2000,
|
|
16
16
|
stopWhenIdle: options.stopWhenIdle === true,
|
|
17
|
+
signal: options.signal ?? null,
|
|
17
18
|
onCheckpoint: async (markerPath) => {
|
|
18
19
|
const { marker } = await readDistillCheckpointMarker(markerPath);
|
|
19
20
|
const reports = await evaluateDistillationCheckpoint({
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
export interface DistillStudentFixture {
|
|
2
|
+
config: Record<string, unknown>;
|
|
3
|
+
model: {
|
|
4
|
+
forward: (input: unknown, tape: unknown) => Promise<unknown>;
|
|
5
|
+
forwardDistill?: (
|
|
6
|
+
batch: unknown,
|
|
7
|
+
tape: unknown,
|
|
8
|
+
options?: Record<string, unknown>
|
|
9
|
+
) => Promise<{ logits: unknown }>;
|
|
10
|
+
cleanupDistillStep?: () => void;
|
|
11
|
+
loraParams?: () => unknown[];
|
|
12
|
+
paramGroups?: () => Record<string, unknown[]>;
|
|
13
|
+
};
|
|
14
|
+
outputDim?: number;
|
|
15
|
+
embeddingDim?: number;
|
|
16
|
+
cleanup(): void;
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
export declare function createDistillStudentRuntimeModelFixture(
|
|
20
|
+
overrides?: Record<string, unknown>,
|
|
21
|
+
options?: Record<string, unknown>
|
|
22
|
+
): Promise<DistillStudentFixture>;
|