@buley/hexgrid-3d 1.1.2 → 3.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +64 -62
- package/package.json +4 -1
- package/site/src/app/docs/page.tsx +158 -34
- package/site/src/app/examples/page.tsx +26 -8
- package/site/src/app/layout.tsx +13 -3
- package/site/src/app/page.tsx +95 -23
- package/src/Snapshot.ts +1 -1
- package/src/algorithms/FluidEngineFactory.ts +44 -0
- package/src/algorithms/FluidSimulation3DGPU.ts +402 -0
- package/src/algorithms/FluidSimulationWebNN.ts +141 -0
- package/src/components/HexGrid.tsx +55 -12
- package/src/types/wgsl.d.ts +4 -0
- package/src/webgpu/WebGPUContext.ts +71 -0
- package/src/webgpu/shaders/fluid_sim.wgsl +140 -0
- package/src/webnn/WebNNContext.ts +99 -0
- package/tsconfig.json +3 -2
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* WebNN implementation of Fluid Simulation.
|
|
3
|
+
* Attempts to map the Stable Fluids algorithm to a neural network graph
|
|
4
|
+
* for NPU acceleration.
|
|
5
|
+
*/
|
|
6
|
+
|
|
7
|
+
import { Vector3 } from '../math/Vector3';
|
|
8
|
+
import { WebNNContext } from '../webnn/WebNNContext';
|
|
9
|
+
import type { FluidConfig3D } from './FluidSimulation3D';
|
|
10
|
+
|
|
11
|
+
// Extended WebNN types for GraphBuilder
// Hand-written ambient declarations for the subset of the WebNN graph-building
// API used by FluidSimulationWebNN. Being global interface declarations, they
// merge with any other global MLContext/MLGraph declarations in the package.
declare global {
  interface MLGraphBuilder {
    // Declares a named graph input operand with the given type and shape.
    input(name: string, descriptor: MLOperandDescriptor): MLOperand;
    // Embeds a constant operand whose contents are read from `buffer`.
    constant(descriptor: MLOperandDescriptor, buffer: ArrayBufferView): MLOperand;
    // Binary arithmetic operations on two operands.
    add(a: MLOperand, b: MLOperand): MLOperand;
    sub(a: MLOperand, b: MLOperand): MLOperand;
    mul(a: MLOperand, b: MLOperand): MLOperand;
    div(a: MLOperand, b: MLOperand): MLOperand;
    // Presumably clamps `x` into the optional [minValue, maxValue] range.
    clamp(x: MLOperand, options?: { minValue?: number; maxValue?: number }): MLOperand;
    // Compiles the graph; the keys of `outputs` name the graph's outputs.
    build(outputs: Record<string, MLOperand>): Promise<MLGraph>;
  }

  interface MLOperandDescriptor {
    dataType: 'float32' | 'float16' | 'int32' | 'uint32';
    // NOTE(review): newer WebNN drafts call this field `shape` — verify
    // against the browser implementation being targeted.
    dimensions: number[];
  }

  interface MLOperand {
    // Opaque handle
  }

  interface Window {
    // Graph-builder constructor exposed on the global object. Declared
    // non-optional here, so code must still check for it at runtime before use.
    MLGraphBuilder: {
      new (context: MLContext): MLGraphBuilder;
    };
  }
}
|
|
39
|
+
|
|
40
|
+
export class FluidSimulationWebNN {
  private width: number;
  private height: number;
  private depth: number;
  private size: number;

  // CPU-side copies of the simulation fields, laid out x-fastest
  // (index = x + width * (y + height * z)), which matches the
  // [1, depth, height, width] operand shape used in buildGraph().
  private density: Float32Array;
  private velocityX: Float32Array;
  private velocityY: Float32Array;
  private velocityZ: Float32Array;

  private context: WebNNContext;
  private graph: MLGraph | null = null;
  private builder: MLGraphBuilder | null = null;

  /**
   * Allocates zeroed CPU-side fields for the grid described by `config`.
   * Grid dimensions are rounded to the nearest integer.
   */
  constructor(config: FluidConfig3D) {
    this.width = Math.round(config.width);
    this.height = Math.round(config.height);
    this.depth = Math.round(config.depth);
    this.size = this.width * this.height * this.depth;

    this.density = new Float32Array(this.size);
    this.velocityX = new Float32Array(this.size);
    this.velocityY = new Float32Array(this.size);
    this.velocityZ = new Float32Array(this.size);

    this.context = WebNNContext.getInstance();
  }

  /**
   * Acquires a WebNN context (NPU preferred) and builds the compute graph.
   *
   * @returns `true` once the graph is built; `false` when WebNN or
   *   `MLGraphBuilder` is unavailable or graph construction throws. Failures
   *   resolve to `false` rather than rejecting, so a caller can fall back to
   *   another engine.
   */
  async initialize(): Promise<boolean> {
    const success = await this.context.initialize('npu');
    if (!success) return false;

    const mlContext = this.context.getContext();
    if (!mlContext || typeof window === 'undefined' || !window.MLGraphBuilder) {
      return false;
    }

    try {
      this.builder = new window.MLGraphBuilder(mlContext);
      await this.buildGraph();
      return true;
    } catch (e: unknown) {
      // e.g. the selected device rejects an op during build(); leave no
      // half-initialized state behind.
      console.error('WebNN graph construction failed', e);
      this.builder = null;
      this.graph = null;
      return false;
    }
  }

  /**
   * Builds a single-step graph: `densityOut = density * 0.99`.
   *
   * NOTE: WebNN 1.0 does not natively support the iterative loops required
   * by the projection step of Stable Fluids within a single graph execution,
   * so this constant decay is a placeholder for the full PDE solver. The plan
   * is to execute the graph repeatedly from `step`.
   */
  private async buildGraph(): Promise<void> {
    if (!this.builder) return;

    const desc: MLOperandDescriptor = {
      dataType: 'float32',
      dimensions: [1, this.depth, this.height, this.width],
    };
    const densityInput = this.builder.input('density', desc);
    const decayConst = this.builder.constant(
      { dataType: 'float32', dimensions: [1] },
      new Float32Array([0.99]),
    );

    const output = this.builder.mul(densityInput, decayConst);

    this.graph = await this.builder.build({ densityOut: output });
  }

  /**
   * Advances the simulation.
   *
   * Currently a no-op. Graph execution (binding `density` as input and
   * reading `densityOut` back through `MLContext.compute`) is not wired up
   * yet, because WebNN API/browser support is still in flux. `dt` is not
   * used.
   */
  async step(dt: number): Promise<void> {
    if (!this.graph || !this.context.getContext()) {
      // No graph or context: nothing to execute.
      return;
    }
    // TODO: this.context.getContext()!.compute(this.graph, inputs, outputs);
  }

  // Public API compatibility with StableFluids3D

  /**
   * Adds `amount` of density around grid position (x, y, z) on the CPU-side
   * field. The weight falls off linearly from 1 at the centre to 0 at
   * `radius` cells. A non-positive radius deposits the full amount into the
   * nearest cell.
   * NOTE(review): positions are treated as grid-cell coordinates — confirm
   * this matches StableFluids3D.
   */
  addDensity(x: number, y: number, z: number, amount: number, radius: number): void {
    this.splat(x, y, z, radius, (index, weight) => {
      this.density[index] = (this.density[index] ?? 0) + amount * weight;
    });
  }

  /**
   * Adds `force` to the CPU-side velocity field around `pos`, using the same
   * linear falloff as `addDensity`.
   */
  addForce(pos: Vector3, force: Vector3, radius: number): void {
    this.splat(pos.x, pos.y, pos.z, radius, (index, weight) => {
      this.velocityX[index] = (this.velocityX[index] ?? 0) + force.x * weight;
      this.velocityY[index] = (this.velocityY[index] ?? 0) + force.y * weight;
      this.velocityZ[index] = (this.velocityZ[index] ?? 0) + force.z * weight;
    });
  }

  /**
   * Density of the cell nearest `pos`, or 0 outside the grid.
   * Reads the CPU-side field. Once `step` runs on the NPU, this will need a
   * readback first.
   */
  getDensityAt(pos: Vector3): number {
    const index = this.cellIndex(pos.x, pos.y, pos.z);
    return index < 0 ? 0 : (this.density[index] ?? 0);
  }

  /** Velocity of the cell nearest `pos`, or the zero vector outside the grid. */
  getVelocityAt(pos: Vector3): Vector3 {
    const index = this.cellIndex(pos.x, pos.y, pos.z);
    if (index < 0) {
      return new Vector3(0, 0, 0);
    }
    return new Vector3(
      this.velocityX[index] ?? 0,
      this.velocityY[index] ?? 0,
      this.velocityZ[index] ?? 0,
    );
  }

  /** Resets density and every velocity component to zero. */
  clear(): void {
    this.density.fill(0);
    this.velocityX.fill(0);
    this.velocityY.fill(0);
    this.velocityZ.fill(0);
  }

  /** Flat index of the cell nearest (x, y, z), or -1 if it lies outside the grid (or is NaN). */
  private cellIndex(x: number, y: number, z: number): number {
    const xi = Math.round(x);
    const yi = Math.round(y);
    const zi = Math.round(z);
    const inside =
      xi >= 0 && xi < this.width &&
      yi >= 0 && yi < this.height &&
      zi >= 0 && zi < this.depth;
    return inside ? xi + this.width * (yi + this.height * zi) : -1;
  }

  /**
   * Calls `apply(index, weight)` for each grid cell within `radius` of
   * (cx, cy, cz), with `weight = 1 - distance / radius`. A non-positive (or
   * NaN) radius targets only the nearest cell, with weight 1.
   */
  private splat(
    cx: number,
    cy: number,
    cz: number,
    radius: number,
    apply: (index: number, weight: number) => void,
  ): void {
    if (!(radius > 0)) {
      const index = this.cellIndex(cx, cy, cz);
      if (index >= 0) apply(index, 1);
      return;
    }

    // Clip the bounding box of the sphere to the grid before iterating.
    const xMin = Math.max(0, Math.ceil(cx - radius));
    const xMax = Math.min(this.width - 1, Math.floor(cx + radius));
    const yMin = Math.max(0, Math.ceil(cy - radius));
    const yMax = Math.min(this.height - 1, Math.floor(cy + radius));
    const zMin = Math.max(0, Math.ceil(cz - radius));
    const zMax = Math.min(this.depth - 1, Math.floor(cz + radius));

    for (let z = zMin; z <= zMax; z++) {
      for (let y = yMin; y <= yMax; y++) {
        for (let x = xMin; x <= xMax; x++) {
          const dist = Math.hypot(x - cx, y - cy, z - cz);
          if (dist <= radius) {
            apply(x + this.width * (y + this.height * z), 1 - dist / radius);
          }
        }
      }
    }
  }
}
|
|
@@ -1,18 +1,61 @@
|
|
|
1
|
-
import type { CSSProperties, RefObject } from 'react';
|
|
2
|
-
import React from 'react';
|
|
3
|
-
import type {
|
|
4
|
-
Photo as PhotoType,
|
|
5
|
-
HexGridProps as HexGridPropsType,
|
|
6
|
-
HexGridFeatureFlags,
|
|
7
|
-
} from '../types';
|
|
8
1
|
|
|
9
|
-
|
|
2
|
+
import React, { useEffect, useRef, useState } from 'react';
|
|
3
|
+
import { FluidEngineFactory, FluidEngine } from '../algorithms/FluidEngineFactory';
|
|
4
|
+
import { HexGridProps } from '../types';
|
|
10
5
|
|
|
11
|
-
|
|
12
|
-
|
|
6
|
+
/**
 * Fluid-simulation view: a canvas plus a status overlay naming the active
 * engine.
 *
 * On mount it creates a 64×64×64 engine through `FluidEngineFactory.create`.
 * If creation fails, it renders the error message instead. `props` is not
 * read yet.
 */
export function HexGrid(props: HexGridProps): React.JSX.Element {
  const canvasRef = useRef<HTMLCanvasElement>(null);
  const [engine, setEngine] = useState<FluidEngine | null>(null);
  const [error, setError] = useState<string | null>(null);

  useEffect(() => {
    // The engine is tracked in effect-local variables rather than read from
    // `engine` state. With an empty dependency list, the cleanup closure only
    // ever sees the first render's `engine`, which is always null.
    let disposed = false;
    let created: FluidEngine | null = null;

    // Initialize Fluid Engine
    const initEngine = async () => {
      try {
        const fluidEngine = await FluidEngineFactory.create({
          width: 64,
          height: 64,
          depth: 64,
          viscosity: 0.0001,
          diffusion: 0.00001
        });
        if (disposed) {
          // Unmounted while creation was in flight: release the engine
          // instead of storing it in state that no longer exists.
          if ('clear' in fluidEngine) {
            fluidEngine.clear();
          }
          return;
        }
        created = fluidEngine;
        setEngine(fluidEngine);
      } catch (err: unknown) {
        console.error("Failed to init fluid engine", err);
        if (!disposed) {
          setError(err instanceof Error ? err.message : String(err));
        }
      }
    };

    void initEngine();

    return () => {
      disposed = true;
      // Cleanup if engine has cleanup method
      if (created && 'clear' in created) {
        created.clear();
      }
      created = null;
    };
  }, []);

  if (error) {
    return <div className="text-red-500">Error initializing simulation: {error}</div>;
  }

  return (
    <div className="hexgrid-container" style={{ width: '100%', height: '100%', position: 'relative' }}>
      {/* Visualization Canvas */}
      <canvas
        ref={canvasRef}
        width={800}
        height={600}
        style={{ width: '100%', height: '100%' }}
      />

      {/* Status Overlay */}
      <div style={{ position: 'absolute', top: 10, left: 10, background: 'rgba(0,0,0,0.5)', color: 'white', padding: '5px', pointerEvents: 'none' }}>
        Engine: {engine ? engine.constructor.name : 'Initializing...'}
      </div>
    </div>
  );
}

export default HexGrid;
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* WebGPU Context Manager
|
|
3
|
+
* Handles the creation and management of the WebGPU Device and Adapter.
|
|
4
|
+
*/
|
|
5
|
+
|
|
6
|
+
export class WebGPUContext {
  private static instance: WebGPUContext;
  private adapter: GPUAdapter | null = null;
  private device: GPUDevice | null = null;
  private isSupported: boolean = false;
  /** Initialization in flight, shared by concurrent `initialize()` callers. */
  private pending: Promise<boolean> | null = null;

  private constructor() {}

  /** Returns the shared instance, creating it on first use. */
  static getInstance(): WebGPUContext {
    if (!WebGPUContext.instance) {
      WebGPUContext.instance = new WebGPUContext();
    }
    return WebGPUContext.instance;
  }

  /**
   * Initialize WebGPU context.
   *
   * Idempotent: resolves `true` immediately while a live device exists.
   * Concurrent callers share one in-flight request, instead of each creating
   * (and leaking) its own adapter and device. After a failure or a lost
   * device, a later call retries.
   *
   * @returns `true` when a device is available; `false` when WebGPU is
   *   unsupported, no adapter was found, or device creation threw.
   */
  async initialize(): Promise<boolean> {
    if (this.isAvailable()) {
      return true;
    }
    if (!this.pending) {
      this.pending = this.createDevice().finally(() => {
        this.pending = null;
      });
    }
    return this.pending;
  }

  /** Requests a high-performance adapter and a device from it. Never rejects. */
  private async createDevice(): Promise<boolean> {
    if (typeof navigator === 'undefined' || !navigator.gpu) {
      console.warn('WebGPU is not supported in this environment.');
      this.isSupported = false;
      return false;
    }

    try {
      this.adapter = await navigator.gpu.requestAdapter({
        powerPreference: 'high-performance'
      });

      if (!this.adapter) {
        console.warn('No WebGPU adapter found.');
        this.isSupported = false;
        return false;
      }

      const device = await this.adapter.requestDevice();
      this.device = device;

      void device.lost.then((info) => {
        console.error(`WebGPU device lost: ${info.message}`);
        // Ignore loss of a device that has already been replaced, so it
        // cannot clobber the current one.
        if (this.device === device) {
          this.device = null;
          this.isSupported = false;
        }
      });

      this.isSupported = true;
      console.log('WebGPU initialized successfully.');
      return true;
    } catch (e) {
      console.error('Failed to initialize WebGPU:', e);
      this.isSupported = false;
      return false;
    }
  }

  getDevice(): GPUDevice | null {
    return this.device;
  }

  getAdapter(): GPUAdapter | null {
    return this.adapter;
  }

  /** `true` while a device created by `initialize()` is alive. */
  isAvailable(): boolean {
    return this.isSupported && this.device !== null;
  }
}
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
// fluid_sim.wgsl
|
|
2
|
+
// 3D Fluid Simulation Compute Shaders
|
|
3
|
+
|
|
4
|
+
// Per-dispatch parameters shared by the compute entry points in this file.
// width/height/depth hold the grid dimensions as f32. decay is the factor
// `advect` multiplies into each advected value.
struct FluidUniforms {
    dt: f32,
    width: f32,
    height: f32,
    depth: f32,
    decay: f32,
};

@group(0) @binding(0) var<uniform> uniforms: FluidUniforms;

// Bindings for Double-Buffering (Read -> Write)
// Group 1: Velocity / Density
// field_in is the read-only source texture; field_out is the write-only
// storage texture the entry points store their results into.
@group(1) @binding(0) var field_in: texture_3d<f32>;
@group(1) @binding(1) var field_out: texture_storage_3d<rgba16float, write>;

// Sampler for linear interpolation
// (the filter mode is configured by the host — presumably linear; confirm).
@group(1) @binding(2) var field_sampler: sampler;
|
|
21
|
+
|
|
22
|
+
// ----------------------------------------------------------------------------
// ADVECTION
// Semi-Lagrangian transport: every cell looks backwards along the velocity
// field and takes the (decayed) value found there.
// ----------------------------------------------------------------------------
@group(2) @binding(0) var velocity_field: texture_3d<f32>;

@compute @workgroup_size(8, 8, 8)
fn advect(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let grid_size = vec3<f32>(uniforms.width, uniforms.height, uniforms.depth);
    let cell = vec3<f32>(global_id);

    // Invocations outside the grid (e.g. from a dispatch rounded up to the
    // workgroup size) do nothing.
    if (any(cell >= grid_size)) {
        return;
    }

    // Velocity at the cell centre; textureSampleLevel takes normalised
    // [0, 1] coordinates, hence the +0.5 and division by the grid size.
    let velocity = textureSampleLevel(velocity_field, field_sampler, (cell + 0.5) / grid_size, 0.0).xyz;

    // Step back by velocity * dt (velocity assumed to be in grid cells per
    // unit time) and fetch the field value there, in normalised space.
    let source_uvw = (cell - velocity * uniforms.dt + 0.5) / grid_size;
    let carried = textureSampleLevel(field_in, field_sampler, source_uvw, 0.0);

    textureStore(field_out, global_id, carried * uniforms.decay);
}
|
|
56
|
+
|
|
57
|
+
// ----------------------------------------------------------------------------
// DIFFUSION (Jacobi Iteration)
// ----------------------------------------------------------------------------
// One Jacobi relaxation step:
//   x_new = (b + alpha * neighbor_sum) * rBeta
// where neighbor_sum is the sum of the six face neighbours of x.
struct JacobiUniforms {
    alpha: f32,
    rBeta: f32, // reciprocal of beta, so the step multiplies instead of dividing
};
@group(3) @binding(0) var<uniform> jacobi: JacobiUniforms;
@group(3) @binding(1) var b_field: texture_3d<f32>; // The 'b' vector in Ax=b (usually previous state or inputs)
@group(3) @binding(2) var x_field: texture_3d<f32>; // The 'x' vector (current guess)

@compute @workgroup_size(8, 8, 8)
fn diffuse(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // WGSL has no implicit f32 -> i32 conversion, so vec3<i32>(f32, f32, f32)
    // does not compile. Build an f32 vector and convert it as a whole
    // (truncating toward zero).
    let dims = vec3<i32>(vec3<f32>(uniforms.width, uniforms.height, uniforms.depth));
    let pos = vec3<i32>(global_id);

    if (any(pos >= dims)) { return; }

    // An out-of-bounds textureLoad returns an indeterminate value, so clamp
    // neighbour coordinates into the grid. At the edges, a cell reuses its own
    // value for the missing neighbour.
    let lo = vec3<i32>(0);
    let hi = dims - vec3<i32>(1);

    let left  = textureLoad(x_field, clamp(pos + vec3<i32>(-1, 0, 0), lo, hi), 0);
    let right = textureLoad(x_field, clamp(pos + vec3<i32>( 1, 0, 0), lo, hi), 0);
    let down  = textureLoad(x_field, clamp(pos + vec3<i32>(0, -1, 0), lo, hi), 0);
    let up    = textureLoad(x_field, clamp(pos + vec3<i32>(0,  1, 0), lo, hi), 0);
    let back  = textureLoad(x_field, clamp(pos + vec3<i32>(0, 0, -1), lo, hi), 0);
    let front = textureLoad(x_field, clamp(pos + vec3<i32>(0, 0,  1), lo, hi), 0);

    let bC = textureLoad(b_field, pos, 0);

    // Jacobi step
    let result = (left + right + down + up + back + front) * jacobi.alpha + bC;
    let next_val = result * jacobi.rBeta;

    textureStore(field_out, global_id, next_val);
}
|
|
92
|
+
|
|
93
|
+
// ----------------------------------------------------------------------------
// DIVERGENCE
// Central-difference divergence of the velocity in field_in. The result is
// written to the .x channel of field_out.
// ----------------------------------------------------------------------------
@compute @workgroup_size(8, 8, 8)
fn divergence(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // WGSL has no implicit f32 -> i32 conversion: convert the whole vector.
    let dims = vec3<i32>(vec3<f32>(uniforms.width, uniforms.height, uniforms.depth));
    let pos = vec3<i32>(global_id);

    if (any(pos >= dims)) { return; }

    // Clamp neighbour coordinates so edge cells never read out of bounds
    // (an out-of-bounds textureLoad returns an indeterminate value).
    let lo = vec3<i32>(0);
    let hi = dims - vec3<i32>(1);

    let left  = textureLoad(field_in, clamp(pos + vec3<i32>(-1, 0, 0), lo, hi), 0).x;
    let right = textureLoad(field_in, clamp(pos + vec3<i32>( 1, 0, 0), lo, hi), 0).x;
    let down  = textureLoad(field_in, clamp(pos + vec3<i32>(0, -1, 0), lo, hi), 0).y;
    let up    = textureLoad(field_in, clamp(pos + vec3<i32>(0,  1, 0), lo, hi), 0).y;
    let back  = textureLoad(field_in, clamp(pos + vec3<i32>(0, 0, -1), lo, hi), 0).z;
    let front = textureLoad(field_in, clamp(pos + vec3<i32>(0, 0,  1), lo, hi), 0).z;

    let div = 0.5 * ((right - left) + (up - down) + (front - back));

    textureStore(field_out, global_id, vec4<f32>(div, 0.0, 0.0, 1.0));
}
|
|
114
|
+
|
|
115
|
+
// ----------------------------------------------------------------------------
// GRADIENT SUBTRACTION
// u_new = u_old - gradient(p)
// ----------------------------------------------------------------------------
// NOTE(review): @group(4) is outside the default WebGPU limit of 4 bind
// groups (indices 0-3). A pipeline using this entry point needs a device
// requested with a raised maxBindGroups, or this binding moved to a lower
// group. Either change must be made together with the host-side bind group
// setup.
@group(4) @binding(0) var pressure_field: texture_3d<f32>;

@compute @workgroup_size(8, 8, 8)
fn subtract_gradient(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // WGSL has no implicit f32 -> i32 conversion: convert the whole vector.
    let dims = vec3<i32>(vec3<f32>(uniforms.width, uniforms.height, uniforms.depth));
    let pos = vec3<i32>(global_id);

    if (any(pos >= dims)) { return; }

    // Clamp neighbour coordinates so edge cells never read out of bounds
    // (an out-of-bounds textureLoad returns an indeterminate value).
    let lo = vec3<i32>(0);
    let hi = dims - vec3<i32>(1);

    let pLeft  = textureLoad(pressure_field, clamp(pos + vec3<i32>(-1, 0, 0), lo, hi), 0).x;
    let pRight = textureLoad(pressure_field, clamp(pos + vec3<i32>( 1, 0, 0), lo, hi), 0).x;
    let pDown  = textureLoad(pressure_field, clamp(pos + vec3<i32>(0, -1, 0), lo, hi), 0).x;
    let pUp    = textureLoad(pressure_field, clamp(pos + vec3<i32>(0,  1, 0), lo, hi), 0).x;
    let pBack  = textureLoad(pressure_field, clamp(pos + vec3<i32>(0, 0, -1), lo, hi), 0).x;
    let pFront = textureLoad(pressure_field, clamp(pos + vec3<i32>(0, 0,  1), lo, hi), 0).x;

    let old_vel = textureLoad(field_in, pos, 0).xyz;
    let grad = vec3<f32>(pRight - pLeft, pUp - pDown, pFront - pBack) * 0.5;
    let new_vel = old_vel - grad;

    textureStore(field_out, global_id, vec4<f32>(new_vel, 1.0));
}
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* WebNN Context Manager
|
|
3
|
+
* Handles the creation and management of the WebNN MLContext.
|
|
4
|
+
* Prioritizes NPU -> GPU -> CPU.
|
|
5
|
+
*/
|
|
6
|
+
|
|
7
|
+
export type WebNNDeviceType = 'cpu' | 'gpu' | 'npu';

export class WebNNContext {
  private static instance: WebNNContext;

  /** Devices in fallback order: NPU -> GPU -> CPU. */
  private static readonly FALLBACK_CHAIN: readonly WebNNDeviceType[] = ['npu', 'gpu', 'cpu'];

  private context: MLContext | null = null;
  private deviceType: WebNNDeviceType = 'cpu';
  private isSupported: boolean = false;

  private constructor() {}

  /** Returns the shared instance, creating it on first use. */
  static getInstance(): WebNNContext {
    if (!WebNNContext.instance) {
      WebNNContext.instance = new WebNNContext();
    }
    return WebNNContext.instance;
  }

  /**
   * Initialize WebNN context with preferred device type.
   *
   * Tries `preference` first, then each device after it in the
   * NPU -> GPU -> CPU chain. For example, 'gpu' falls back only to 'cpu'.
   *
   * @param preference Device to try first; defaults to 'npu'.
   * @returns `true` once a context has been created; `false` when WebNN is
   *   absent or every candidate failed. On failure, any context left over
   *   from an earlier successful call is dropped, so `getContext()` returns
   *   `null` instead of a stale handle.
   */
  async initialize(preference: WebNNDeviceType = 'npu'): Promise<boolean> {
    if (typeof navigator === 'undefined' || !navigator.ml) {
      console.warn('WebNN is not supported in this environment.');
      this.resetState();
      return false;
    }

    const chain = WebNNContext.FALLBACK_CHAIN;
    const candidates = [preference, ...chain.slice(chain.indexOf(preference) + 1)];

    for (const [attempt, deviceType] of candidates.entries()) {
      const isFallback = attempt > 0;
      try {
        this.context = await navigator.ml.createContext({ deviceType });
        this.deviceType = deviceType;
        this.isSupported = true;
        console.log(
          isFallback
            ? `WebNN initialized successfully on fallback ${deviceType}`
            : `WebNN initialized successfully on ${deviceType}`,
        );
        return true;
      } catch (e: unknown) {
        if (isFallback) {
          console.warn(`Failed to initialize WebNN on fallback ${deviceType}`, e);
        } else {
          console.warn(`Failed to initialize WebNN on ${deviceType}, trying fallback chain...`, e);
        }
      }
    }

    this.resetState();
    return false;
  }

  /** Forgets any previously created context and restores the default state. */
  private resetState(): void {
    this.context = null;
    this.deviceType = 'cpu';
    this.isSupported = false;
  }

  getContext(): MLContext | null {
    return this.context;
  }

  getDeviceType(): WebNNDeviceType {
    return this.deviceType;
  }

  isAvailable(): boolean {
    return this.isSupported && this.context !== null;
  }
}
|
|
78
|
+
|
|
79
|
+
// Type definitions for WebNN (since it might not be in standard lib yet)
// Minimal hand-written ambient declarations covering only what this package
// uses. NOTE(review): written against a draft of the WebNN spec — verify
// member names and signatures against the browsers being targeted.
declare global {
  interface Navigator {
    // WebNN entry point. Declared non-optional, so callers must still check
    // `navigator.ml` at runtime before using it.
    ml: {
      // Creates an MLContext; `deviceType` selects the backend ('cpu' | 'gpu' | 'npu' in this package).
      createContext(options?: { deviceType?: string }): Promise<MLContext>;
    };
  }

  interface MLContext {
    // Placeholder for MLContext methods
    // Executes a built graph; input/output buffers are keyed by the operand
    // names used when the graph was built.
    // NOTE(review): some spec drafts transfer (detach) the passed buffers and
    // return fresh views in the result — confirm before reusing them.
    compute(graph: MLGraph, inputs: Record<string, ArrayBufferView>, outputs: Record<string, ArrayBufferView>): Promise<MLComputeResult>;
  }

  interface MLGraph {
    // Opaque
  }

  interface MLComputeResult {
    // Opaque
  }
}
|