@buley/neural 4.1.1 → 4.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,107 +0,0 @@
1
-
2
- import { GPUEngine } from "../engine/gpu";
3
- import { performance } from "perf_hooks";
4
-
5
// WebGPU availability: Bun has no native WebGPU implementation yet, and headless
// Node/CI environments usually lack one as well. This benchmark is written to run
// wherever WebGPU exists (for example, on a Mac via a browser-based test runner or
// the web app). When no WebGPU is present in the global scope, the stub installed
// after `main` below lets the script run end to end so its structure can be verified.
// Timings measured against that stub are meaningless; real numbers require a real GPU.
15
-
16
/**
 * Converts a tick count and elapsed wall time into processed samples per second.
 * Returns Infinity when `elapsedMs` is 0 (e.g. against the CPU stub).
 */
function samplesPerSecond(ticks: number, batchSize: number, elapsedMs: number): number {
  return (ticks * batchSize) / (elapsedMs / 1000);
}

/**
 * Resolves once every command buffer submitted so far has finished on the GPU.
 * The optional call keeps this usable with stub devices that lack
 * `onSubmittedWorkDone`.
 */
async function waitForGpuIdle(gpu: GPUEngine): Promise<void> {
  await gpu.device?.queue.onSubmittedWorkDone?.();
}

/**
 * Benchmarks one network configuration on a fresh GPUEngine and logs the
 * upload time plus inference and training throughput.
 *
 * @param label        Name printed in the section header.
 * @param networkSize  Neuron count N; the weight matrix is N×N (full connectivity).
 * @param batchSize    Samples processed per tick.
 * @param iterations   Number of timed ticks for inference and for training.
 */
async function runBenchmark(label: string, networkSize: number, batchSize: number, iterations: number) {
  console.log(`\n--- Benchmark: ${label} ---`);
  console.log(`Network: ${networkSize} Neurons, Batch: ${batchSize}`);

  const gpu = new GPUEngine();

  try {
    await gpu.init();
  } catch (e) {
    console.error("WebGPU Initialize Failed (Expected in non-browser env):", e);
    return;
  }

  // Zero-filled data: the values are irrelevant for timing, only the sizes matter.
  const weights = new Float32Array(networkSize * networkSize);
  const biases = new Float32Array(networkSize);
  const inputs = new Float32Array(networkSize * batchSize);
  const targets = new Float32Array(networkSize * batchSize);

  // Buffer creation + initial upload
  const startInit = performance.now();
  gpu.prepareBuffers(networkSize, weights, biases, batchSize);
  gpu.prepareTrainingBuffers(targets, 0.01);
  const initTime = performance.now() - startInit;
  console.log(`Initialization/Upload: ${initTime.toFixed(2)}ms`);

  // Inference warmup. runTick awaits its own readback, so every timed tick below
  // includes full GPU completion.
  await gpu.runTick(inputs);

  const startInf = performance.now();
  for (let i = 0; i < iterations; i++) {
    await gpu.runTick(inputs);
  }
  const infTime = performance.now() - startInf;
  console.log(`Inference: ${infTime.toFixed(2)}ms for ${iterations} ticks`);
  console.log(`Throughput: ${samplesPerSecond(iterations, batchSize, infTime).toFixed(0)} samples/sec`);

  // Training warmup. trainTick() resolves as soon as its commands are submitted,
  // so drain the queue before starting the timer.
  await gpu.trainTick();
  await waitForGpuIdle(gpu);

  const startTrain = performance.now();
  for (let i = 0; i < iterations; i++) {
    await gpu.trainTick();
  }
  // Without this wait the timer would stop after the last submit and report
  // CPU-side submission time rather than GPU execution time.
  await waitForGpuIdle(gpu);
  const trainTime = performance.now() - startTrain;
  console.log(`Training: ${trainTime.toFixed(2)}ms for ${iterations} ticks`);
  console.log(`Throughput: ${samplesPerSecond(iterations, batchSize, trainTime).toFixed(0)} samples/sec`);
}
67
-
68
/**
 * Runs every benchmark configuration in sequence (never concurrently, so the
 * runs do not compete for the GPU).
 */
async function main() {
  const configurations: Array<[label: string, networkSize: number, batchSize: number, iterations: number]> = [
    ["Small", 100, 1, 100],
    ["Medium (Batched)", 1000, 32, 50],
    ["Large (Batched)", 5000, 64, 20],
  ];

  for (const [label, networkSize, batchSize, iterations] of configurations) {
    await runBenchmark(label, networkSize, batchSize, iterations);
  }
}
78
-
79
// When the runtime exposes no WebGPU, install a minimal CPU-side stub so the
// benchmark's control flow can still be exercised. Timings against the stub
// are meaningless; they only show that the script runs end to end.
if (!globalThis.navigator?.gpu) {
  console.log("No WebGPU detected in global scope. Mocking for CLI structure verification...");

  // Builds a fresh stub device exposing only the calls GPUEngine makes.
  const createStubDevice = () => ({
    createShaderModule: () => ({}),
    createComputePipeline: () => ({ getBindGroupLayout: () => ({}) }),
    createBuffer: (descriptor: { size: number }) => ({
      getMappedRange: () => new ArrayBuffer(descriptor.size),
      unmap: () => {},
      mapAsync: async () => {},
      destroy: () => {},
    }),
    createBindGroup: () => ({}),
    createCommandEncoder: () => ({
      beginComputePass: () => ({
        setPipeline: () => {},
        setBindGroup: () => {},
        dispatchWorkgroups: () => {},
        end: () => {},
      }),
      copyBufferToBuffer: () => {},
      finish: () => ({}),
    }),
    queue: {
      writeBuffer: () => {},
      submit: () => {},
      onSubmittedWorkDone: async () => {},
    },
  });

  const stubGpu = {
    requestAdapter: async () => ({ requestDevice: async () => createStubDevice() }),
  } as unknown as GPU;

  // defineProperty (rather than plain assignment) also succeeds where the
  // runtime already defines `navigator` as a configurable accessor property.
  Object.defineProperty(globalThis, "navigator", {
    value: { gpu: stubGpu },
    configurable: true,
    writable: true,
  });

  // Flag values follow the WebGPU specification, matching the test polyfill.
  Object.assign(globalThis, {
    GPUBufferUsage: {
      MAP_READ: 0x1,
      MAP_WRITE: 0x2,
      COPY_SRC: 0x4,
      COPY_DST: 0x8,
      INDEX: 0x10,
      VERTEX: 0x20,
      UNIFORM: 0x40,
      STORAGE: 0x80,
      INDIRECT: 0x100,
      QUERY_RESOLVE: 0x200,
    },
    GPUMapMode: { READ: 0x1, WRITE: 0x2 },
  });
}

// Report failures instead of leaving an unhandled rejection, and signal them
// through the exit code.
main().catch((err: unknown) => {
  console.error("Benchmark failed:", err);
  process.exitCode = 1;
});
@@ -1,67 +0,0 @@
1
- import { expect, test, describe, mock } from "bun:test";
2
- import { NeuronRepository, SynapseRepository } from "./repository";
3
-
4
// Canned result sets for the SELECT statements the repositories issue, matched
// by substring in order. Each entry is a factory so every call returns a fresh array.
const selectFixtures: ReadonlyArray<readonly [string, () => unknown[]]> = [
  ["SELECT * FROM neurons", () => [{ id: "n1", type: "input", bias: 0.1, activation: "tanh" }]],
  ["SELECT * FROM synapses", () => [{ id: "s1", from_id: "n1", to_id: "n2", weight: 0.5 }]],
];

// Stand-in for the `dash` client from @buley/dash:
// - INSERT statements resolve with no value.
// - Known SELECTs resolve with a fixture.
// - Anything else resolves with [].
const mockDash = {
  execute: mock((query: string, _params?: unknown[]) => {
    if (query.includes("INSERT")) {
      return Promise.resolve();
    }
    const fixture = selectFixtures.find(([fragment]) => query.includes(fragment));
    return Promise.resolve(fixture ? fixture[1]() : []);
  }),
  addWithEmbedding: mock(() => Promise.resolve()),
};
19
-
20
// Route every import of @buley/dash to the stub client so the repositories
// under test issue their queries against mockDash.
mock.module("@buley/dash", () => {
  return { dash: mockDash };
});
24
-
25
// The dash mocks are module-level and shared by every test in this file, so each
// test clears the ones it inspects. Results then do not depend on test order.
describe("NeuronRepository", () => {
  test("create() executes INSERT query", async () => {
    mockDash.execute.mockClear();
    const repo = new NeuronRepository();
    await repo.create({ id: "n1", type: "input", bias: 0.1, activation: "tanh" });

    expect(mockDash.execute).toHaveBeenCalledTimes(1);
    const call = mockDash.execute.mock.calls[0];
    expect(call[0]).toContain("INSERT INTO neurons");
    expect(call[1]).toEqual(["n1", "input", 0.1, "tanh"]);
  });

  test("createWithSemantics() inserts the row and calls addWithEmbedding", async () => {
    mockDash.execute.mockClear();
    mockDash.addWithEmbedding.mockClear();
    const repo = new NeuronRepository();
    await repo.createWithSemantics(
      { id: "n2", type: "hidden", bias: 0, activation: "relu" },
      "detects curves"
    );

    expect(mockDash.addWithEmbedding).toHaveBeenCalledWith("n2", "detects curves");
    const insertCall = mockDash.execute.mock.calls[0];
    expect(insertCall[0]).toContain("INSERT INTO neurons");
    expect(insertCall[1]).toEqual(["n2", "hidden", 0, "relu"]);
  });

  test("getAll() returns neurons", async () => {
    mockDash.execute.mockClear();
    const repo = new NeuronRepository();
    const results = await repo.getAll();

    expect(results.length).toBe(1);
    expect(results[0].id).toBe("n1");
  });
});
52
-
53
// Clears the shared dash mock before each test, so call assertions only see this test's calls.
describe("SynapseRepository", () => {
  test("create() executes INSERT query", async () => {
    mockDash.execute.mockClear();
    const repo = new SynapseRepository();
    await repo.create({ id: "s1", from_id: "n1", to_id: "n2", weight: 0.5 });

    expect(mockDash.execute).toHaveBeenCalledTimes(1);
    const call = mockDash.execute.mock.calls[0];
    expect(call[0]).toContain("INSERT INTO synapses");
    expect(call[1]).toEqual(["s1", "n1", "n2", 0.5]);
  });

  test("getAll() returns synapses", async () => {
    mockDash.execute.mockClear();
    const repo = new SynapseRepository();
    const results = await repo.getAll();

    expect(results.length).toBe(1);
    expect(results[0].weight).toBe(0.5);
  });
});
@@ -1,44 +0,0 @@
1
- import { dash } from "@buley/dash";
2
- import { Neuron, Synapse } from "../types";
3
-
4
/** Persists rows of the `neurons` table through the shared `dash` SQL client. */
export class NeuronRepository {
  /** Inserts one neuron row (id, type, bias, activation). */
  async create(neuron: Neuron): Promise<void> {
    await dash.execute(
      "INSERT INTO neurons (id, type, bias, activation) VALUES (?, ?, ?, ?)",
      [neuron.id, neuron.type, neuron.bias, neuron.activation]
    );
  }

  /**
   * Inserts the neuron, then registers `description` with
   * `dash.addWithEmbedding` under the neuron's id.
   *
   * The two steps are not atomic. If the embedding step fails, the row
   * inserted by this call is deleted again (best effort) so the table does not
   * keep a neuron without its description. The embedding error is then rethrown.
   *
   * @throws whatever `dash.execute` or `dash.addWithEmbedding` rejects with.
   */
  async createWithSemantics(neuron: Neuron, description: string): Promise<void> {
    await this.create(neuron);
    try {
      await dash.addWithEmbedding(neuron.id, description);
    } catch (err) {
      // A rollback failure is swallowed so the caller sees the original cause.
      await this.delete(neuron.id).catch(() => undefined);
      throw err;
    }
  }

  /** Returns every row of `neurons` (the result is cast, not validated). */
  async getAll(): Promise<Neuron[]> {
    return await dash.execute("SELECT * FROM neurons") as Neuron[];
  }

  /** Deletes the neuron with the given id; synapses referencing it are not touched here. */
  async delete(id: string): Promise<void> {
    await dash.execute("DELETE FROM neurons WHERE id = ?", [id]);
  }
}
28
-
29
/** Persists rows of the `synapses` table through the shared `dash` SQL client. */
export class SynapseRepository {
  /** Inserts one synapse row (id, from_id, to_id, weight) using bound parameters. */
  async create(synapse: Synapse): Promise<void> {
    const insertSql = "INSERT INTO synapses (id, from_id, to_id, weight) VALUES (?, ?, ?, ?)";
    const { id, from_id, to_id, weight } = synapse;
    await dash.execute(insertSql, [id, from_id, to_id, weight]);
  }

  /** Returns every row of `synapses`; the result is cast, not validated. */
  async getAll(): Promise<Synapse[]> {
    const rows = await dash.execute("SELECT * FROM synapses");
    return rows as Synapse[];
  }

  /** Removes the synapse with the given id. */
  async delete(id: string): Promise<void> {
    const deleteSql = "DELETE FROM synapses WHERE id = ?";
    await dash.execute(deleteSql, [id]);
  }
}
package/src/db/schema.ts DELETED
@@ -1,40 +0,0 @@
1
-
2
- import { dash } from "@buley/dash";
3
-
4
/**
 * Creates the `neurons` and `synapses` tables and the indexes on the synapse
 * endpoint columns. Every statement uses IF NOT EXISTS, so calling this again
 * is harmless.
 *
 * NOTE(review): the DDL uses SQLite syntax (`unixepoch()`, REAL/TEXT). This
 * assumes dash is backed by SQLite 3.38+ — confirm.
 */
export async function initializeSchema() {
  console.log("Initializing Neural Schema...");

  // Neurons. type is one of input/hidden/output; activation names a function
  // such as tanh, relu or sigmoid.
  await dash.execute(`
    CREATE TABLE IF NOT EXISTS neurons (
      id TEXT PRIMARY KEY,
      type TEXT NOT NULL,
      bias REAL DEFAULT 0.0,
      activation TEXT DEFAULT 'tanh',
      created_at INTEGER DEFAULT (unixepoch())
    )
  `);

  // Synapses: directed weighted edges from_id -> to_id between neuron ids.
  await dash.execute(`
    CREATE TABLE IF NOT EXISTS synapses (
      id TEXT PRIMARY KEY,
      from_id TEXT NOT NULL,
      to_id TEXT NOT NULL,
      weight REAL DEFAULT 0.0,
      created_at INTEGER DEFAULT (unixepoch()),
      FOREIGN KEY(from_id) REFERENCES neurons(id),
      FOREIGN KEY(to_id) REFERENCES neurons(id)
    )
  `);

  // SQLite does not index foreign-key child columns automatically. Without these
  // indexes, looking up a neuron's synapses, or deleting a neuron while FK
  // enforcement is on, scans the whole synapses table.
  await dash.execute(`CREATE INDEX IF NOT EXISTS idx_synapses_from_id ON synapses(from_id)`);
  await dash.execute(`CREATE INDEX IF NOT EXISTS idx_synapses_to_id ON synapses(to_id)`);

  console.log("Schema initialized.");
}
@@ -1,120 +0,0 @@
1
- import { expect, test, describe, mock } from "bun:test";
2
- import { GPUEngine } from "./gpu";
3
-
4
// Spy-backed compute pass: one fresh set of spies per beginComputePass() call.
const makeComputePass = () => ({
  setPipeline: mock(() => {}),
  setBindGroup: mock(() => {}),
  dispatchWorkgroups: mock(() => {}),
  end: mock(() => {}),
});

// Spy-backed command encoder: one fresh instance per createCommandEncoder() call.
const makeCommandEncoder = () => ({
  beginComputePass: mock(makeComputePass),
  copyBufferToBuffer: mock(() => {}),
  finish: mock(() => ({})),
});

// Stand-in GPUDevice exposing only what GPUEngine calls. Buffers hand out a
// zeroed ArrayBuffer of the requested size from getMappedRange().
const mockDevice = {
  createShaderModule: mock(() => ({})),
  createComputePipeline: mock(() => ({
    getBindGroupLayout: mock(() => ({})),
  })),
  createBuffer: mock((descriptor: any) => ({
    getMappedRange: () => new ArrayBuffer(descriptor.size),
    unmap: () => {},
    mapAsync: async () => {},
  })),
  createBindGroup: mock(() => ({})),
  createCommandEncoder: mock(makeCommandEncoder),
  queue: {
    writeBuffer: mock(() => {}),
    submit: mock(() => {}),
  },
};
31
-
32
// Adapter whose requestDevice() always resolves to the shared mockDevice.
const mockAdapter = {
  requestDevice: mock(async () => mockDevice),
};

// Fake navigator.gpu: requestAdapter() resolves to mockAdapter.
const fakeGpu = {
  requestAdapter: mock(async () => mockAdapter as unknown as GPUAdapter),
} as unknown as GPU;

// @ts-ignore -- navigator is typed read-only by the DOM lib
global.navigator = { gpu: fakeGpu };
43
-
44
// Install the WebGPU flag namespaces GPUEngine reads at runtime.
// Bit values follow the WebGPU specification.
Object.assign(globalThis, {
  GPUBufferUsage: {
    MAP_READ: 1,
    MAP_WRITE: 2,
    COPY_SRC: 4,
    COPY_DST: 8,
    INDEX: 16,
    VERTEX: 32,
    UNIFORM: 64,
    STORAGE: 128,
    INDIRECT: 256,
    QUERY_RESOLVE: 512,
  },
  GPUMapMode: {
    READ: 1,
    WRITE: 2,
  },
});
63
-
64
// The device/adapter mocks are module-level and shared. Tests that assert call
// counts clear those spies first, so results do not depend on test order.
describe("GPUEngine", () => {
  test("init() requests adapter and device", async () => {
    const gpu = new GPUEngine();
    await gpu.init();
    expect(navigator.gpu.requestAdapter).toHaveBeenCalled();
    expect(mockAdapter.requestDevice).toHaveBeenCalled();
    // `device` starts as null, and toBeDefined() accepts null, so pin the
    // actual object instead.
    expect(gpu.device).toBe(mockDevice as unknown as GPUDevice);
  });

  test("prepareBuffers() creates GPU buffers", async () => {
    const gpu = new GPUEngine();
    await gpu.init();
    mockDevice.createBuffer.mockClear();
    mockDevice.createBindGroup.mockClear();

    const weights = new Float32Array([1, 2, 3, 4]);
    const biases = new Float32Array([0, 0]);

    gpu.prepareBuffers(2, weights, biases);

    expect(mockDevice.createBuffer).toHaveBeenCalledTimes(5); // W, I, B, O, Uniforms
    expect(mockDevice.createBindGroup).toHaveBeenCalledTimes(1);
  });

  test("runTick() dispatches compute shader and returns one value per neuron", async () => {
    const gpu = new GPUEngine();
    await gpu.init();
    gpu.prepareBuffers(2, new Float32Array(4), new Float32Array(2));
    mockDevice.queue.writeBuffer.mockClear();
    mockDevice.createCommandEncoder.mockClear();

    const output = await gpu.runTick(new Float32Array([1, 0]));

    expect(mockDevice.queue.writeBuffer).toHaveBeenCalledTimes(1);
    expect(mockDevice.createCommandEncoder).toHaveBeenCalledTimes(1);
    expect(output.length).toBe(2);
  });

  test("prepareBuffers() and runTick() with Batch Size > 1", async () => {
    const gpu = new GPUEngine();
    await gpu.init();

    const N = 4;
    const B = 2; // Batch size 2

    // Weights are a full N×N matrix; biases are shared across the batch.
    gpu.prepareBuffers(N, new Float32Array(N * N), new Float32Array(N), B);

    // Inputs and outputs are batched: N * B = 8 values.
    const output = await gpu.runTick(new Float32Array(N * B));
    expect(output.length).toBe(N * B);

    // A single-sample input must be rejected for a batched network.
    await expect(gpu.runTick(new Float32Array(N))).rejects.toThrow("Input size mismatch");
  });
});
package/src/engine/gpu.ts DELETED
@@ -1,266 +0,0 @@
1
- import shaderCode from './shaders/brain.wgsl?raw';
2
- import trainingShaderCode from './shaders/training.wgsl?raw';
3
- // import trainingShaderCode from './shaders/training.wgsl?raw'; // Handled in replace block above, this is safety check
4
-
5
/**
 * WebGPU compute engine for a densely connected network of `networkSize` neurons.
 *
 * - Inference runs the `main` entry point of `brain.wgsl`.
 * - Training runs `calculate_deltas` and then `update_weights` from `training.wgsl`.
 *
 * Expected call order:
 * `init()` → `prepareBuffers()` → optionally `prepareTrainingBuffers()`
 * → `runTick()` / `train()` / `trainTick()`.
 */
export class GPUEngine {
  device: GPUDevice | null = null;
  pipeline: GPUComputePipeline | null = null;
  bindGroup: GPUBindGroup | null = null;

  // Training buffers (created by prepareTrainingBuffers)
  deltaBuffer: GPUBuffer | null = null;
  targetBuffer: GPUBuffer | null = null;
  paramBuffer: GPUBuffer | null = null;

  trainingPipeline: GPUComputePipeline | null = null;
  deltaPipeline: GPUComputePipeline | null = null;
  trainingBindGroup: GPUBindGroup | null = null;

  // Inference buffers (created by prepareBuffers)
  weightBuffer: GPUBuffer | null = null;
  inputBuffer: GPUBuffer | null = null;
  biasBuffer: GPUBuffer | null = null;
  outputBuffer: GPUBuffer | null = null;
  uniformBuffer: GPUBuffer | null = null;

  networkSize: number = 0;
  batchSize: number = 1;

  private subscribers: ((event: { type: 'loss' | 'epoch', value: number }) => void)[] = [];

  /**
   * Acquires an adapter and device and compiles the inference and training
   * pipelines.
   * @throws Error when WebGPU or a GPU adapter is unavailable.
   */
  async init() {
    if (!navigator.gpu) throw new Error("WebGPU not supported");
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) throw new Error("No GPU adapter found");
    this.device = await adapter.requestDevice();

    const shaderModule = this.device.createShaderModule({ code: shaderCode });
    const trainingModule = this.device.createShaderModule({ code: trainingShaderCode });

    this.pipeline = this.device.createComputePipeline({
      layout: 'auto',
      compute: { module: shaderModule, entryPoint: 'main' }
    });

    this.trainingPipeline = this.device.createComputePipeline({
      layout: 'auto',
      compute: { module: trainingModule, entryPoint: 'update_weights' }
    });

    this.deltaPipeline = this.device.createComputePipeline({
      layout: 'auto',
      compute: { module: trainingModule, entryPoint: 'calculate_deltas' }
    });

    console.log("GPUEngine initialized");
  }

  /**
   * Allocates the inference buffers and bind group.
   *
   * @param size       Neuron count N.
   * @param weights    Weights uploaded as-is. Presumably N*N values (full
   *                   connectivity); not validated here.
   * @param biases     Bias values, shared by every sample in the batch.
   * @param batchSize  Samples per tick B. Inputs and outputs hold N*B floats.
   * @throws Error when init() has not completed.
   */
  prepareBuffers(size: number, weights: Float32Array, biases: Float32Array, batchSize: number = 1) {
    if (!this.device || !this.pipeline) throw new Error("GPUEngine not initialized");
    this.networkSize = size;
    this.batchSize = batchSize;

    // Weights and biases are shared across the batch.
    this.weightBuffer = this.createBuffer(weights, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST);
    this.biasBuffer = this.createBuffer(biases, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST);

    // Inputs and outputs are batched (N * B floats).
    const batchedSize = size * batchSize;
    this.inputBuffer = this.createBuffer(new Float32Array(batchedSize), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST);
    this.outputBuffer = this.createBuffer(new Float32Array(batchedSize), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST);

    // Dimensions uniform: [size, batchSize] as u32.
    const dimArray = new Uint32Array([size, batchSize]);
    this.uniformBuffer = this.createBuffer(dimArray, GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST);

    // Binding indices must match the @binding declarations in brain.wgsl
    // (not visible here).
    this.bindGroup = this.device.createBindGroup({
      layout: this.pipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: this.weightBuffer } },
        { binding: 1, resource: { buffer: this.inputBuffer } },
        { binding: 2, resource: { buffer: this.biasBuffer } },
        { binding: 3, resource: { buffer: this.outputBuffer } },
        { binding: 4, resource: { buffer: this.uniformBuffer } },
      ]
    });
  }

  /** Creates a GPU buffer sized to `data`, fills it through mappedAtCreation, then unmaps it. */
  private createBuffer(data: Float32Array | Uint32Array, usage: number): GPUBuffer {
    if (!this.device) throw new Error("Device null");
    const buffer = this.device.createBuffer({
      size: data.byteLength,
      usage: usage,
      mappedAtCreation: true
    });
    if (data instanceof Float32Array) {
      new Float32Array(buffer.getMappedRange()).set(data);
    } else {
      new Uint32Array(buffer.getMappedRange()).set(data);
    }
    buffer.unmap();
    return buffer;
  }

  /**
   * Uploads `inputs`, runs one forward pass and reads the outputs back.
   *
   * @param inputs  Exactly networkSize * batchSize floats.
   * @returns A CPU-side copy of the output buffer (same length as `inputs`).
   * @throws Error when buffers are not prepared or `inputs` has the wrong length.
   */
  async runTick(inputs: Float32Array): Promise<Float32Array> {
    if (!this.device || !this.pipeline || !this.bindGroup || !this.inputBuffer || !this.outputBuffer) {
      throw new Error("GPU buffers not ready");
    }

    if (inputs.length !== this.networkSize * this.batchSize) {
      throw new Error(`Input size mismatch. Expected ${this.networkSize * this.batchSize}, got ${inputs.length}`);
    }

    this.device.queue.writeBuffer(this.inputBuffer, 0, inputs as BufferSource);

    const commandEncoder = this.device.createCommandEncoder();
    const passEncoder = commandEncoder.beginComputePass();
    passEncoder.setPipeline(this.pipeline);
    passEncoder.setBindGroup(0, this.bindGroup);

    // x: ceil(N / 64) workgroups over neurons (64 presumably matches
    // @workgroup_size in brain.wgsl — confirm). z: one slice per batch sample.
    const workgroupSize = 64;
    const workgroupCount = Math.ceil(this.networkSize / workgroupSize);
    passEncoder.dispatchWorkgroups(workgroupCount, 1, this.batchSize);
    passEncoder.end();

    // Read back through a dedicated staging buffer: in WebGPU, MAP_READ may
    // only be combined with COPY_DST, so the storage buffer cannot be mapped.
    const size = inputs.byteLength;
    const gpuReadBuffer = this.device.createBuffer({
      size: size,
      usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
    });

    commandEncoder.copyBufferToBuffer(this.outputBuffer, 0, gpuReadBuffer, 0, size);
    this.device.queue.submit([commandEncoder.finish()]);

    try {
      await gpuReadBuffer.mapAsync(GPUMapMode.READ);
      // Copy before unmapping: unmap() detaches the mapped ArrayBuffer.
      const output = new Float32Array(new Float32Array(gpuReadBuffer.getMappedRange()));
      gpuReadBuffer.unmap();
      return output;
    } finally {
      // A new staging buffer is created on every tick. Release its GPU memory
      // now instead of waiting for garbage collection. The optional call
      // tolerates test doubles without destroy().
      gpuReadBuffer.destroy?.();
    }
  }

  /**
   * Allocates the delta, target and learning-rate buffers and the training
   * bind group. This reuses the weight, output, bias and uniform buffers from
   * prepareBuffers().
   *
   * @param targets       Exactly networkSize * batchSize floats.
   * @param learningRate  Uploaded once as a single f32 uniform.
   * @throws Error when inference buffers or pipelines are missing, or
   *   `targets` has the wrong length.
   */
  prepareTrainingBuffers(targets: Float32Array, learningRate: number) {
    if (!this.device || !this.trainingPipeline || !this.weightBuffer || !this.outputBuffer || !this.biasBuffer || !this.uniformBuffer) {
      throw new Error("GPU not ready for training");
    }

    if (targets.length !== this.networkSize * this.batchSize) {
      throw new Error(`Target size mismatch. Expected ${this.networkSize * this.batchSize}, got ${targets.length}`);
    }

    // Deltas and targets are batched (N * B floats).
    const batchedSize = this.networkSize * this.batchSize;
    this.deltaBuffer = this.createBuffer(new Float32Array(batchedSize), GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC);
    this.targetBuffer = this.createBuffer(targets, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST);
    this.paramBuffer = this.createBuffer(new Float32Array([learningRate]), GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST);

    // The layout comes from the update_weights pipeline. The same bind group is
    // also used with the calculate_deltas pipeline in trainTick().
    this.trainingBindGroup = this.device.createBindGroup({
      layout: this.trainingPipeline.getBindGroupLayout(0),
      entries: [
        { binding: 0, resource: { buffer: this.weightBuffer } },
        { binding: 1, resource: { buffer: this.outputBuffer } },
        { binding: 2, resource: { buffer: this.biasBuffer } },
        { binding: 3, resource: { buffer: this.deltaBuffer } },
        { binding: 4, resource: { buffer: this.targetBuffer } },
        { binding: 5, resource: { buffer: this.uniformBuffer } },
        { binding: 6, resource: { buffer: this.paramBuffer } }
      ]
    });
  }

  /**
   * Registers a listener for 'loss' and 'epoch' events emitted by train().
   * @returns A function that removes the listener.
   */
  subscribe(callback: (event: { type: 'loss' | 'epoch', value: number }) => void) {
    this.subscribers.push(callback);
    return () => {
      this.subscribers = this.subscribers.filter(s => s !== callback);
    };
  }

  /** Delivers `event` synchronously to every current subscriber. */
  private emit(event: { type: 'loss' | 'epoch', value: number }) {
    this.subscribers.forEach(cb => cb(event));
  }

  /**
   * One training step: forward pass, CPU-side loss report, target upload, and
   * the GPU delta/update passes. Assumes prepareTrainingBuffers() was already
   * called once.
   *
   * @param inputs   Exactly networkSize * batchSize floats.
   * @param targets  Exactly networkSize * batchSize floats. Entries <= -998 are
   *                 skipped in the loss (presumably a "no target" sentinel
   *                 mirrored in training.wgsl — confirm).
   * @returns The forward-pass outputs.
   * @throws Error on input/target size mismatch or when training is not prepared.
   */
  async train(inputs: Float32Array, targets: Float32Array): Promise<Float32Array> {
    // Validate up front. A short target array would leave stale targets in the
    // GPU buffer, and a long one would overrun it.
    const expectedLength = this.networkSize * this.batchSize;
    if (targets.length !== expectedLength) {
      throw new Error(`Target size mismatch. Expected ${expectedLength}, got ${targets.length}`);
    }

    // 1. Forward pass
    const outputs = await this.runTick(inputs);

    // 2. Loss for UI feedback: sum of 0.5 * diff^2 over targeted outputs,
    //    divided by batch size (a per-sample sum, not a per-element mean).
    let totalLoss = 0;
    for (let i = 0; i < outputs.length; i++) {
      const t = targets[i];
      if (t > -998) {
        const diff = outputs[i] - t;
        totalLoss += 0.5 * diff * diff;
      }
    }
    const meanLoss = totalLoss / this.batchSize;
    this.emit({ type: 'loss', value: meanLoss });

    // 3. Backward pass: refresh targets, then run the training shaders.
    if (this.targetBuffer) {
      this.device?.queue.writeBuffer(this.targetBuffer, 0, targets as BufferSource);
    }
    await this.trainTick();

    this.emit({ type: 'epoch', value: 1 }); // Signals one completed step
    return outputs;
  }

  /**
   * Encodes and submits the training passes: calculate_deltas (batched over z),
   * then update_weights (not batched). Resolves once the commands are
   * submitted; it does not wait for the GPU to finish.
   *
   * @param deltas  Optional values written into the delta buffer before the passes run.
   * @throws Error when prepareTrainingBuffers() has not been called.
   */
  async trainTick(deltas?: Float32Array): Promise<void> {
    if (!this.device || !this.trainingPipeline || !this.deltaPipeline || !this.trainingBindGroup || !this.deltaBuffer) {
      throw new Error("Training not ready");
    }

    if (deltas && deltas.length > 0) {
      this.device.queue.writeBuffer(this.deltaBuffer, 0, deltas as BufferSource);
    }

    const commandEncoder = this.device.createCommandEncoder();
    const workgroupSize = 64;
    const workgroupCount = Math.ceil(this.networkSize / workgroupSize);

    // Pass 1: calculate deltas (batched)
    const deltaPass = commandEncoder.beginComputePass();
    deltaPass.setPipeline(this.deltaPipeline);
    deltaPass.setBindGroup(0, this.trainingBindGroup);
    deltaPass.dispatchWorkgroups(workgroupCount, 1, this.batchSize);
    deltaPass.end();

    // Pass 2: update weights. It runs in a separate pass so it observes the
    // deltas written by pass 1.
    const updatePass = commandEncoder.beginComputePass();
    updatePass.setPipeline(this.trainingPipeline);
    updatePass.setBindGroup(0, this.trainingBindGroup);
    updatePass.dispatchWorkgroups(workgroupCount, 1, 1);
    updatePass.end();

    this.device.queue.submit([commandEncoder.finish()]);
  }

  /**
   * Writes `data` to the start of the input buffer without dispatching anything.
   * A shorter array overwrites only the leading values. Does nothing if the
   * engine or input buffer is not ready.
   */
  async injectInput(data: Float32Array): Promise<void> {
    if (!this.device || !this.inputBuffer) return;
    this.device.queue.writeBuffer(this.inputBuffer, 0, data as BufferSource);
  }
}