@seanhogg/builderforce-memory-engine 2026.6.18
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +393 -0
- package/dist/index.d.ts +32 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +40 -0
- package/dist/index.js.map +1 -0
- package/dist/kernels/activations.d.ts +5 -0
- package/dist/kernels/activations.d.ts.map +1 -0
- package/dist/kernels/activations.js +171 -0
- package/dist/kernels/activations.js.map +1 -0
- package/dist/kernels/attention.d.ts +19 -0
- package/dist/kernels/attention.d.ts.map +1 -0
- package/dist/kernels/attention.js +263 -0
- package/dist/kernels/attention.js.map +1 -0
- package/dist/kernels/complex_ssd.d.ts +33 -0
- package/dist/kernels/complex_ssd.d.ts.map +1 -0
- package/dist/kernels/complex_ssd.js +305 -0
- package/dist/kernels/complex_ssd.js.map +1 -0
- package/dist/kernels/conv1d.d.ts +3 -0
- package/dist/kernels/conv1d.d.ts.map +1 -0
- package/dist/kernels/conv1d.js +158 -0
- package/dist/kernels/conv1d.js.map +1 -0
- package/dist/kernels/linear_projection.d.ts +3 -0
- package/dist/kernels/linear_projection.d.ts.map +1 -0
- package/dist/kernels/linear_projection.js +219 -0
- package/dist/kernels/linear_projection.js.map +1 -0
- package/dist/kernels/selective_scan.d.ts +3 -0
- package/dist/kernels/selective_scan.d.ts.map +1 -0
- package/dist/kernels/selective_scan.js +348 -0
- package/dist/kernels/selective_scan.js.map +1 -0
- package/dist/kernels/ssd.d.ts +29 -0
- package/dist/kernels/ssd.d.ts.map +1 -0
- package/dist/kernels/ssd.js +276 -0
- package/dist/kernels/ssd.js.map +1 -0
- package/dist/kernels/weight_update.d.ts +3 -0
- package/dist/kernels/weight_update.d.ts.map +1 -0
- package/dist/kernels/weight_update.js +119 -0
- package/dist/kernels/weight_update.js.map +1 -0
- package/dist/model/attention_block.d.ts +48 -0
- package/dist/model/attention_block.d.ts.map +1 -0
- package/dist/model/attention_block.js +262 -0
- package/dist/model/attention_block.js.map +1 -0
- package/dist/model/mamba1_block.d.ts +70 -0
- package/dist/model/mamba1_block.d.ts.map +1 -0
- package/dist/model/mamba1_block.js +333 -0
- package/dist/model/mamba1_block.js.map +1 -0
- package/dist/model/mamba2_block.d.ts +44 -0
- package/dist/model/mamba2_block.d.ts.map +1 -0
- package/dist/model/mamba2_block.js +252 -0
- package/dist/model/mamba2_block.js.map +1 -0
- package/dist/model/mamba3_block.d.ts +51 -0
- package/dist/model/mamba3_block.d.ts.map +1 -0
- package/dist/model/mamba3_block.js +270 -0
- package/dist/model/mamba3_block.js.map +1 -0
- package/dist/model/mamba_block.d.ts +64 -0
- package/dist/model/mamba_block.d.ts.map +1 -0
- package/dist/model/mamba_block.js +303 -0
- package/dist/model/mamba_block.js.map +1 -0
- package/dist/model/mamba_model.d.ts +140 -0
- package/dist/model/mamba_model.d.ts.map +1 -0
- package/dist/model/mamba_model.js +527 -0
- package/dist/model/mamba_model.js.map +1 -0
- package/dist/model/sequence_layer.d.ts +25 -0
- package/dist/model/sequence_layer.d.ts.map +1 -0
- package/dist/model/sequence_layer.js +8 -0
- package/dist/model/sequence_layer.js.map +1 -0
- package/dist/tokenizer/bpe.d.ts +29 -0
- package/dist/tokenizer/bpe.d.ts.map +1 -0
- package/dist/tokenizer/bpe.js +164 -0
- package/dist/tokenizer/bpe.js.map +1 -0
- package/dist/training/autograd.d.ts +27 -0
- package/dist/training/autograd.d.ts.map +1 -0
- package/dist/training/autograd.js +120 -0
- package/dist/training/autograd.js.map +1 -0
- package/dist/training/trainer.d.ts +36 -0
- package/dist/training/trainer.d.ts.map +1 -0
- package/dist/training/trainer.js +183 -0
- package/dist/training/trainer.js.map +1 -0
- package/dist/utils/gpu_utils.d.ts +21 -0
- package/dist/utils/gpu_utils.d.ts.map +1 -0
- package/dist/utils/gpu_utils.js +111 -0
- package/dist/utils/gpu_utils.js.map +1 -0
- package/dist/utils/quantization.d.ts +26 -0
- package/dist/utils/quantization.d.ts.map +1 -0
- package/dist/utils/quantization.js +116 -0
- package/dist/utils/quantization.js.map +1 -0
- package/dist/utils/rng.d.ts +36 -0
- package/dist/utils/rng.d.ts.map +1 -0
- package/dist/utils/rng.js +61 -0
- package/dist/utils/rng.js.map +1 -0
- package/package.json +99 -0
- package/src/index.ts +114 -0
- package/src/kernels/activations.ts +174 -0
- package/src/kernels/attention.ts +268 -0
- package/src/kernels/complex_ssd.ts +307 -0
- package/src/kernels/conv1d.ts +159 -0
- package/src/kernels/linear_projection.ts +220 -0
- package/src/kernels/selective_scan.ts +350 -0
- package/src/kernels/ssd.ts +278 -0
- package/src/kernels/weight_update.ts +120 -0
- package/src/model/attention_block.ts +344 -0
- package/src/model/mamba1_block.ts +437 -0
- package/src/model/mamba2_block.ts +319 -0
- package/src/model/mamba3_block.ts +335 -0
- package/src/model/mamba_block.ts +401 -0
- package/src/model/mamba_model.ts +678 -0
- package/src/model/sequence_layer.ts +29 -0
- package/src/tokenizer/bpe.ts +186 -0
- package/src/training/autograd.ts +135 -0
- package/src/training/trainer.ts +309 -0
- package/src/utils/gpu_utils.ts +147 -0
- package/src/utils/quantization.ts +154 -0
- package/src/utils/rng.ts +65 -0
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* gpu_utils.ts – WebGPU device management and buffer helpers.
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
/* eslint-disable @typescript-eslint/no-explicit-any */
// WebGPU globals such as GPUBufferUsage / GPUMapMode only exist in WebGPU-capable
// runtimes, so they are read loosely off globalThis. When a global is absent, the
// flag falls back to the bit value defined by the WebGPU specification.
const _gpu = globalThis as any;

// GPUBufferUsage flags, listed in ascending bit order.
const MAP_READ: number = _gpu.GPUBufferUsage?.MAP_READ ?? 0x0001;
const COPY_SRC: number = _gpu.GPUBufferUsage?.COPY_SRC ?? 0x0004;
const COPY_DST: number = _gpu.GPUBufferUsage?.COPY_DST ?? 0x0008;
const UNIFORM: number = _gpu.GPUBufferUsage?.UNIFORM ?? 0x0040;
const STORAGE: number = _gpu.GPUBufferUsage?.STORAGE ?? 0x0080;
|
|
12
|
+
|
|
13
|
+
/** Options accepted by `initWebGPU`. */
export interface InitWebGPUOptions {
  /** Adapter selection hint forwarded to `requestAdapter`; `initWebGPU` defaults it to `'high-performance'`. */
  powerPreference?: 'high-performance' | 'low-power';
}

/** What `initWebGPU` hands back once a device has been acquired. */
export interface InitWebGPUResult {
  /** Logical device used to create buffers and pipelines and to submit work. */
  device: GPUDevice;
  /** Adapter the device was requested from (useful for inspecting `limits`). */
  adapter: GPUAdapter;
}
|
|
21
|
+
|
|
22
|
+
/**
 * Acquires a WebGPU adapter and a logical device for the compute kernels.
 *
 * The device is requested with limits clamped to what the adapter reports:
 * buffer size and storage-binding size of at most 3 GiB, and at most 256
 * invocations per workgroup. If the device is later lost, the event is logged
 * to the console.
 *
 * @param opts Adapter options; `powerPreference` defaults to `'high-performance'`.
 * @returns The device together with the adapter it came from.
 * @throws If `navigator.gpu` is unavailable or no adapter can be obtained.
 */
export async function initWebGPU(opts: InitWebGPUOptions = {}): Promise<InitWebGPUResult> {
  const gpu = typeof navigator === 'undefined' ? undefined : navigator.gpu;
  if (!gpu) {
    throw new Error(
      'WebGPU is not available in this environment. ' +
      'Use Chrome 113+, Edge 113+, or Firefox Nightly with WebGPU enabled.'
    );
  }

  const powerPreference = opts.powerPreference ?? 'high-performance';
  const adapter = await gpu.requestAdapter({ powerPreference });
  if (!adapter) {
    throw new Error('Failed to acquire a GPUAdapter. Your GPU may not support WebGPU.');
  }

  // Never request more than the adapter supports, or requestDevice would reject.
  const threeGiB = 3 * 1024 * 1024 * 1024;
  const { limits } = adapter;
  const requiredLimits = {
    maxBufferSize: Math.min(threeGiB, limits.maxBufferSize),
    maxStorageBufferBindingSize: Math.min(threeGiB, limits.maxStorageBufferBindingSize),
    maxComputeInvocationsPerWorkgroup: Math.min(256, limits.maxComputeInvocationsPerWorkgroup),
  };
  const device = await adapter.requestDevice({ requiredLimits });

  // Fire-and-forget: only report the loss; recovery is left to the caller.
  void device.lost.then((info) => {
    console.error('WebGPU device lost:', info.message);
  });

  return { device, adapter };
}
|
|
63
|
+
|
|
64
|
+
/**
 * Creates a storage buffer and fills it with `data` at creation time.
 *
 * A plain `number[]` is converted to float32. Uint32Array input is written
 * through a Uint32Array view; everything else goes through a Float32Array view.
 *
 * @param readable When true, the buffer also gets COPY_SRC usage so it can be read back.
 * @returns A buffer whose size equals the byte length of the (converted) input.
 */
export function createStorageBuffer(device: GPUDevice, data: Float32Array | Uint32Array | number[], readable = false): GPUBuffer {
  let source: Float32Array | Uint32Array;
  if (data instanceof Float32Array || data instanceof Uint32Array) {
    source = data;
  } else {
    source = new Float32Array(data);
  }

  let usage = STORAGE | COPY_DST;
  if (readable) {
    usage |= COPY_SRC;
  }

  const gpuBuffer = device.createBuffer({ size: source.byteLength, usage, mappedAtCreation: true });
  const mapped = gpuBuffer.getMappedRange();
  // Copy element-wise through a view of the matching element type.
  const target = source instanceof Uint32Array ? new Uint32Array(mapped) : new Float32Array(mapped);
  target.set(source);
  gpuBuffer.unmap();
  return gpuBuffer;
}
|
|
76
|
+
|
|
77
|
+
/**
 * Allocates an uninitialised storage buffer of `byteSize` bytes.
 *
 * @param readable When true, COPY_SRC usage is added so the contents can be read back.
 */
export function createEmptyStorageBuffer(device: GPUDevice, byteSize: number, readable = false): GPUBuffer {
  let usage = STORAGE | COPY_DST;
  if (readable) {
    usage |= COPY_SRC;
  }
  return device.createBuffer({ size: byteSize, usage });
}
|
|
81
|
+
|
|
82
|
+
/**
 * Creates a uniform buffer initialised with the bytes of `data`.
 *
 * For an `ArrayBufferView`, only the bytes inside the view's window
 * (`byteOffset` .. `byteOffset + byteLength`) are uploaded, not the whole
 * backing ArrayBuffer. The buffer size is rounded up to a multiple of 4 bytes,
 * because `mappedAtCreation` requires it; any padding bytes are zero.
 *
 * @param data Raw bytes, or any typed-array/DataView over them.
 * @returns A UNIFORM | COPY_DST buffer holding a copy of the data.
 */
export function createUniformBuffer(device: GPUDevice, data: ArrayBuffer | ArrayBufferView): GPUBuffer {
  // Byte view restricted to exactly the caller's data (subarrays share a larger backing store).
  const bytes = ArrayBuffer.isView(data)
    ? new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
    : new Uint8Array(data);

  // mappedAtCreation buffers must have a size that is a multiple of 4.
  const size = Math.ceil(bytes.byteLength / 4) * 4;

  const buffer = device.createBuffer({
    size,
    usage: UNIFORM | COPY_DST,
    mappedAtCreation: true,
  });
  new Uint8Array(buffer.getMappedRange()).set(bytes);
  buffer.unmap();
  return buffer;
}
|
|
93
|
+
|
|
94
|
+
/**
 * Copies the first `byteSize` bytes of `srcBuffer` back to the CPU as float32 data.
 *
 * A temporary MAP_READ staging buffer receives a GPU-side copy. It is then
 * mapped, and its bytes are copied out before it is released. `srcBuffer` must
 * allow copying from it (COPY_SRC usage, e.g. created with `readable = true`).
 *
 * @param byteSize Number of bytes to read, starting at offset 0.
 * @returns A CPU-owned Float32Array that stays valid after the staging buffer is gone.
 * @throws Rejects if the copy or mapping fails. The staging buffer is destroyed either way.
 */
export async function readBuffer(device: GPUDevice, srcBuffer: GPUBuffer, byteSize: number): Promise<Float32Array> {
  const MAP_READ_FLAG: number = _gpu.GPUMapMode?.READ ?? 0x01;
  const stagingBuffer = device.createBuffer({
    size : byteSize,
    usage : MAP_READ | COPY_DST,
  });

  try {
    const encoder = device.createCommandEncoder();
    encoder.copyBufferToBuffer(srcBuffer, 0, stagingBuffer, 0, byteSize);
    device.queue.submit([encoder.finish()]);

    await stagingBuffer.mapAsync(MAP_READ_FLAG);
    // slice(0) copies the bytes out: the mapped ArrayBuffer is detached by unmap().
    const result = new Float32Array(stagingBuffer.getMappedRange().slice(0));
    stagingBuffer.unmap();
    return result;
  } finally {
    // Runs on both success and failure (e.g. a rejected mapAsync) so the staging buffer never leaks.
    stagingBuffer.destroy();
  }
}
|
|
111
|
+
|
|
112
|
+
/**
 * Writes the contents of `data` into `buffer` starting at `byteOffset`.
 *
 * The backing ArrayBuffer is passed with explicit byte bounds, so only the
 * window covered by `data` is written, even when `data` is a subarray.
 */
export function uploadBuffer(device: GPUDevice, buffer: GPUBuffer, data: Float32Array, byteOffset = 0): void {
  const { buffer: source, byteOffset: sourceOffset, byteLength } = data;
  device.queue.writeBuffer(buffer, byteOffset, source, sourceOffset, byteLength);
}
|
|
115
|
+
|
|
116
|
+
/**
 * Compiles `wgslSource` and builds a compute pipeline for `entryPoint`.
 *
 * The pipeline uses `layout: 'auto'`, so bind-group layouts are derived from
 * the shader and can be fetched with `getBindGroupLayout`.
 */
export function createComputePipeline(device: GPUDevice, wgslSource: string, entryPoint: string): GPUComputePipeline {
  const module = device.createShaderModule({ code: wgslSource });
  const compute = { module, entryPoint };
  return device.createComputePipeline({ layout: 'auto', compute });
}
|
|
123
|
+
|
|
124
|
+
/**
 * Binds `buffers` to consecutive binding slots (0, 1, 2, ...) of bind group
 * `groupIndex`, using the layout the pipeline reports for that group.
 */
export function createBindGroup(device: GPUDevice, pipeline: GPUComputePipeline, buffers: GPUBuffer[], groupIndex = 0): GPUBindGroup {
  const entries = buffers.map((buffer, binding) => ({ binding, resource: { buffer } }));
  const layout = pipeline.getBindGroupLayout(groupIndex);
  return device.createBindGroup({ layout, entries });
}
|
|
134
|
+
|
|
135
|
+
/**
 * Records and submits one compute pass that runs `pipeline` over `workgroups`
 * (x, y, z workgroup counts), with `bindGroup` bound at group index 0.
 */
export function dispatchKernel(device: GPUDevice, pipeline: GPUComputePipeline, bindGroup: GPUBindGroup, workgroups: [number, number, number]): void {
  const commandEncoder = device.createCommandEncoder();
  const computePass = commandEncoder.beginComputePass();
  computePass.setPipeline(pipeline);
  computePass.setBindGroup(0, bindGroup);
  computePass.dispatchWorkgroups(...workgroups);
  computePass.end();
  const commandBuffer = commandEncoder.finish();
  device.queue.submit([commandBuffer]);
}
|
|
144
|
+
|
|
145
|
+
/**
 * Ceiling division: the number of `b`-sized chunks needed to cover `a` items,
 * e.g. how many workgroups of size `b` cover `a` elements.
 */
export function cdiv(a: number, b: number): number {
  const quotient = a / b;
  return Math.ceil(quotient);
}
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* quantization.ts – FP16 and Int8 quantization utilities.
|
|
3
|
+
*/
|
|
4
|
+
|
|
5
|
+
/** Output of per-tensor int8 quantization (`quantizeInt8`). */
export interface QuantizeInt8Result {
  /** Quantized codes; multiply by `scale` to recover approximate values. */
  data: Int8Array;
  /** Single scale factor shared by every element. */
  scale: number;
}

/** Output of per-channel int8 quantization (`quantizeInt8PerChannel`). */
export interface QuantizeInt8PerChannelResult {
  /** Quantized codes laid out channel after channel, in contiguous runs. */
  data: Int8Array;
  /** One scale factor per channel. */
  scales: Float32Array;
}

/** Byte counts needed to store a given number of elements at each precision. */
export interface MemoryEstimate {
  /** Bytes at 4 bytes per element. */
  fp32: number;
  /** Bytes at 2 bytes per element. */
  fp16: number;
  /** Bytes at 1 byte per element. */
  int8: number;
}
|
|
20
|
+
|
|
21
|
+
/**
 * Converts a number to IEEE 754 binary16 bits, returned as an integer in [0, 0xFFFF].
 *
 * The value is first narrowed to float32 (by storing it in a Float32Array).
 * It is then narrowed to half precision with round-to-nearest, ties-to-even,
 * the IEEE default rounding mode:
 * - magnitudes that round past the largest finite half (65504) become ±Infinity;
 * - magnitudes below half the smallest subnormal (2^-25) become ±0;
 * - NaN maps to a quiet NaN.
 *
 * @param val Number to encode.
 * @returns The 16-bit pattern (sign | exponent | mantissa).
 */
export function floatToFp16(val: number): number {
  const buf = new ArrayBuffer(4);
  const f32 = new Float32Array(buf);
  const u32 = new Uint32Array(buf);
  f32[0] = val;
  const bits = u32[0]!;

  const sign = (bits >>> 31) & 0x1;
  const exponent = (bits >>> 23) & 0xFF;
  const mantissa = bits & 0x7FFFFF;

  // Infinity / NaN. Keep a mantissa bit set for NaN so it stays a (quiet) NaN.
  if (exponent === 255) {
    return (sign << 15) | 0x7C00 | (mantissa ? 0x200 : 0);
  }

  // Re-bias the exponent from float32 (bias 127) to float16 (bias 15).
  const expAdj = exponent - 127 + 15;

  if (expAdj >= 31) {
    return (sign << 15) | 0x7C00;
  }

  if (expAdj <= 0) {
    // Below 2^-25 the value rounds to zero even under round-to-nearest.
    if (expAdj < -10) { return sign << 15; }
    // Subnormal half: restore the implicit leading 1 and shift it into the
    // fixed 2^-24 grid. If rounding carries to 0x400, that is the correct
    // encoding of the smallest normal.
    const full = mantissa | 0x800000;
    const shift = 14 - expAdj; // 14..24
    return (sign << 15) | roundShiftRightHalfEven(full, shift);
  }

  // Normal half: drop 13 mantissa bits. A rounding carry out of the mantissa
  // bumps the exponent, and a carry out of exponent 30 yields 0x7C00 (Infinity).
  return (sign << 15) | roundShiftRightHalfEven((expAdj << 23) | mantissa, 13);
}

/** Shifts a non-negative integer right by `shift` (1..30) bits, rounding half to even. */
function roundShiftRightHalfEven(value: number, shift: number): number {
  const kept = value >>> shift;
  const dropped = value & ((1 << shift) - 1);
  const halfway = 1 << (shift - 1);
  if (dropped > halfway || (dropped === halfway && (kept & 1) === 1)) {
    return kept + 1;
  }
  return kept;
}
|
|
50
|
+
|
|
51
|
+
/**
 * Decodes IEEE 754 binary16 bits (an integer in [0, 0xFFFF]) to a number.
 *
 * Handles signed zero, subnormals (value = mantissa / 1024 * 2^-14), normals,
 * ±Infinity, and NaN. A NaN is returned regardless of its sign bit.
 *
 * @param val 16-bit half-precision pattern.
 * @returns The decoded value.
 */
export function fp16ToFloat(val: number): number {
  const sign = (val >>> 15) & 0x1;
  const exponent = (val >>> 10) & 0x1F;
  const mantissa = val & 0x3FF;

  let magnitude: number;
  if (exponent === 0) {
    // Subnormal or zero: no implicit leading 1, and the exponent is fixed at -14.
    magnitude = (mantissa / 1024.0) * 2 ** -14;
  } else if (exponent === 31) {
    // All-ones exponent: any set mantissa bit means NaN (whatever the sign), otherwise Infinity.
    if (mantissa !== 0) {
      return NaN;
    }
    magnitude = Infinity;
  } else {
    magnitude = (1 + mantissa / 1024.0) * 2 ** (exponent - 15);
  }
  return sign ? -magnitude : magnitude;
}
|
|
69
|
+
|
|
70
|
+
/**
 * Encodes every element of `f32` as half-precision bits using `floatToFp16`.
 *
 * @param f32 Values to encode.
 * @returns A Uint16Array of the same length holding the raw 16-bit patterns.
 */
export function quantizeFp16(f32: Float32Array): Uint16Array {
  return Uint16Array.from(f32, (value) => floatToFp16(value));
}
|
|
77
|
+
|
|
78
|
+
/**
 * Decodes an array of half-precision bit patterns back to float32 using `fp16ToFloat`.
 *
 * @param fp16 Raw 16-bit patterns, e.g. the output of `quantizeFp16`.
 * @returns A Float32Array of the same length.
 */
export function dequantizeFp16(fp16: Uint16Array): Float32Array {
  return Float32Array.from(fp16, (bits) => fp16ToFloat(bits));
}
|
|
85
|
+
|
|
86
|
+
/**
 * Symmetric per-tensor int8 quantization.
 *
 * The scale is `max|x| / 127`, or 1 when every element is zero. Each code is
 * `round(x / scale)` clamped to [-128, 127]. NaN inputs never win the
 * max-magnitude comparison and end up stored as 0.
 *
 * @param f32 Values to quantize.
 * @returns The codes plus the scale needed to dequantize them.
 */
export function quantizeInt8(f32: Float32Array): QuantizeInt8Result {
  let peak = 0;
  for (const value of f32) {
    const magnitude = Math.abs(value);
    if (magnitude > peak) peak = magnitude;
  }

  const step = peak / 127.0;
  const scale = step > 0 ? step : 1.0;

  const data = Int8Array.from(f32, (value) =>
    Math.max(-128, Math.min(127, Math.round(value / scale))),
  );
  return { data, scale };
}
|
|
102
|
+
|
|
103
|
+
/**
 * Reconstructs approximate values from per-tensor int8 codes.
 *
 * @param int8 Quantized codes, e.g. `data` from `quantizeInt8`.
 * @param scale The scale the codes were produced with.
 * @returns `int8[i] * scale` for every element, stored as float32.
 */
export function dequantizeInt8(int8: Int8Array, scale: number): Float32Array {
  return Float32Array.from(int8, (code) => code * scale);
}
|
|
110
|
+
|
|
111
|
+
/**
 * Symmetric int8 quantization with one scale per channel.
 *
 * `f32` is split into `numChannels` contiguous runs of equal length. Each
 * channel uses scale `max|x| / 127` (or 1 if that channel is all zeros), and
 * each code is `round(x / scale)` clamped to [-128, 127].
 *
 * @param f32 Values laid out channel after channel.
 * @param numChannels Number of channels; must be a positive integer that divides `f32.length`.
 * @returns The codes plus one float32 scale per channel.
 * @throws RangeError If `numChannels` does not evenly partition the input. Previously this
 *   silently produced garbage from fractional indices.
 */
export function quantizeInt8PerChannel(f32: Float32Array, numChannels: number): QuantizeInt8PerChannelResult {
  if (!Number.isInteger(numChannels) || numChannels <= 0 || f32.length % numChannels !== 0) {
    throw new RangeError(
      `quantizeInt8PerChannel: numChannels must be a positive integer dividing the input length ` +
      `(got numChannels=${numChannels}, length=${f32.length})`,
    );
  }

  const channelSize = f32.length / numChannels;
  const scales = new Float32Array(numChannels);
  const data = new Int8Array(f32.length);

  for (let c = 0; c < numChannels; c++) {
    const base = c * channelSize;
    const end = base + channelSize;

    let maxAbs = 0;
    for (let i = base; i < end; i++) {
      const a = Math.abs(f32[i]!);
      if (a > maxAbs) maxAbs = a;
    }

    scales[c] = maxAbs / 127.0 || 1.0;
    // Divide by the float32-rounded scale that is actually stored, so that
    // dequantizing with `scales` uses exactly the same factor.
    const scale = scales[c]!;
    for (let i = base; i < end; i++) {
      data[i] = Math.max(-128, Math.min(127, Math.round(f32[i]! / scale)));
    }
  }

  return { data, scales };
}
|
|
133
|
+
|
|
134
|
+
/**
 * Reconstructs approximate values from per-channel int8 codes.
 *
 * `int8` is split into `numChannels` contiguous runs of equal length, and
 * every code in channel `c` is multiplied by `scales[c]`.
 *
 * @param int8 Codes laid out channel after channel, e.g. from `quantizeInt8PerChannel`.
 * @param scales Per-channel scales; must contain at least `numChannels` entries.
 * @param numChannels Number of channels; must be a positive integer that divides `int8.length`.
 * @returns A Float32Array of the same length as `int8`.
 * @throws RangeError If the channel layout or the number of scales is inconsistent.
 *   Previously this silently produced NaN or left values unwritten.
 */
export function dequantizeInt8PerChannel(int8: Int8Array, scales: Float32Array, numChannels: number): Float32Array {
  if (!Number.isInteger(numChannels) || numChannels <= 0 || int8.length % numChannels !== 0) {
    throw new RangeError(
      `dequantizeInt8PerChannel: numChannels must be a positive integer dividing the input length ` +
      `(got numChannels=${numChannels}, length=${int8.length})`,
    );
  }
  if (scales.length < numChannels) {
    throw new RangeError(
      `dequantizeInt8PerChannel: expected at least ${numChannels} scales, got ${scales.length}`,
    );
  }

  const channelSize = int8.length / numChannels;
  const out = new Float32Array(int8.length);

  for (let c = 0; c < numChannels; c++) {
    const scale = scales[c]!;
    const base = c * channelSize;
    const end = base + channelSize;
    for (let i = base; i < end; i++) {
      out[i] = int8[i]! * scale;
    }
  }

  return out;
}
|
|
147
|
+
|
|
148
|
+
/**
 * Estimates raw storage, in bytes, for `numElements` values at each supported precision.
 *
 * @param numElements Element count.
 * @returns Byte counts at 4 (fp32), 2 (fp16) and 1 (int8) bytes per element.
 */
export function estimateMemory(numElements: number): MemoryEstimate {
  const bytesAt = (bytesPerElement: number): number => numElements * bytesPerElement;
  return {
    fp32: bytesAt(4),
    fp16: bytesAt(2),
    int8: bytesAt(1),
  };
}
|
package/src/utils/rng.ts
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* rng.ts – shared, optionally-seeded random source for weight initialisation.
|
|
3
|
+
*
|
|
4
|
+
* Weight init across the model and every block used to duplicate the same
|
|
5
|
+
* `Math.random()` Box–Muller draw. That made cold-start weights
|
|
6
|
+
* non-reproducible across machines. This module centralises the draw and lets
|
|
7
|
+
* the model install a deterministic seed for the duration of construction, so
|
|
8
|
+
* the same `seed` yields byte-identical initial weights everywhere.
|
|
9
|
+
*
|
|
10
|
+
* The default (unseeded) source delegates to `Math.random`, preserving the
|
|
11
|
+
* original behaviour for callers that don't request a seed.
|
|
12
|
+
*
|
|
13
|
+
* The seeded generator uses the same LCG constants as tools/generate-bin.js so
|
|
14
|
+
* tooling and runtime agree on what a "seed N" model looks like.
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
/**
 * Deterministic 32-bit linear congruential generator using the Numerical
 * Recipes multiplier (1664525) and increment (1013904223), taken modulo 2^32.
 */
export class SeededRng {
  private _s: number;

  /**
   * @param seed Any number. It is reduced to uint32, and a reduced value of 0
   *   is replaced by 1, so seeds 0 and 1 yield the same stream.
   */
  constructor(seed: number) {
    const reduced = seed >>> 0;
    this._s = reduced === 0 ? 1 : reduced;
  }

  /** Advances the state and returns it scaled into [0, 1). */
  next(): number {
    // Math.imul gives the low 32 bits of the product; >>> 0 keeps the sum unsigned.
    const product = Math.imul(1664525, this._s);
    this._s = (product + 1013904223) >>> 0;
    return this._s / 4294967296;
  }
}
|
|
32
|
+
|
|
33
|
+
/** Uniform [0, 1) source used by `randn`; replaced by `setInitSeed`, initially `Math.random`. */
let _next: () => number = Math.random;

/**
 * Selects the uniform source used for weight-initialisation draws.
 *
 * Given a number, subsequent `randn` / `gaussianArray` calls draw from a fresh
 * `SeededRng` with that seed, so they are reproducible. Given `undefined`
 * (or `null`), the source reverts to `Math.random`.
 *
 * The source is module-wide state. The intended pattern is to set a seed,
 * build a model synchronously, and then clear the seed again.
 */
export function setInitSeed(seed: number | undefined): void {
  if (seed != null) {
    const generator = new SeededRng(seed);
    _next = () => generator.next();
    return;
  }
  _next = Math.random;
}
|
|
52
|
+
|
|
53
|
+
/**
 * Draws one Gaussian sample with mean 0 and standard deviation `std`, using
 * the Box–Muller transform on two draws from the active uniform source.
 * The first draw is floored at 1e-12 so that log(0) cannot occur.
 */
export function randn(std = 1): number {
  const radius = Math.sqrt(-2 * Math.log(Math.max(_next(), 1e-12)));
  const angle = 2 * Math.PI * _next();
  return std * radius * Math.cos(angle);
}
|
|
59
|
+
|
|
60
|
+
/**
 * Builds a Float32Array of `n` independent `randn(std)` samples, drawn in index order.
 */
export function gaussianArray(n: number, std: number): Float32Array {
  return new Float32Array(n).map(() => randn(std));
}
|