opendi-js 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +278 -0
- package/build/opendi.js +2 -0
- package/build/opendi.wasm +0 -0
- package/opendi.d.ts +157 -0
- package/package.json +40 -0
- package/src/opendi.cjs +4 -0
- package/src/opendi.mjs +613 -0
package/README.md
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
# opendi-js
|
|
2
|
+
|
|
3
|
+
WebAssembly bindings for the OpenDI math and machine learning library. The package compiles the pure C99 engine to WASM via Emscripten and wraps it in an idiomatic JavaScript API.
|
|
4
|
+
|
|
5
|
+
Works in Node.js, Deno, Bun, and browsers.
|
|
6
|
+
|
|
7
|
+
## Features
|
|
8
|
+
|
|
9
|
+
- **Primitive Operations** - add, subtract, multiply, divide, abs, min, max, round, pow
|
|
10
|
+
- **Calculus** - Forward/backward/central difference, second derivative, nth derivative, Romberg integration
|
|
11
|
+
- **Linear Algebra** - Vectors (add, dot, cross, norm, scale) and matrices (multiply, add, scale, transpose)
|
|
12
|
+
- **Activations** - relu, sigmoid, softmax
|
|
13
|
+
- **Loss Functions** - MSE, cross-entropy
|
|
14
|
+
- **Backward Functions** - Gradient computation for activations and matrix operations
|
|
15
|
+
- **Optimizers** - SGD weight updates
|
|
16
|
+
- **Random** - Seeded uniform and normal distributions
|
|
17
|
+
- **Statistics** - Z-score normalization
|
|
18
|
+
- **Pipeline** - Batch activations, dense layers, loss gradients, weight initialization
|
|
19
|
+
- **Session API** - Pointer-level arena control for training loops
|
|
20
|
+
- **Zero Copy Overhead** - WASM heap operations, no serialization
|
|
21
|
+
|
|
22
|
+
## Two-Tier API
|
|
23
|
+
|
|
24
|
+
### Tier 1: Simple Functions
|
|
25
|
+
|
|
26
|
+
Pass JavaScript arrays, get results back. Memory is managed automatically.
|
|
27
|
+
|
|
28
|
+
```js
|
|
29
|
+
import { init } from 'opendi-js'
|
|
30
|
+
|
|
31
|
+
const o = await init()
|
|
32
|
+
|
|
33
|
+
o.add(1, 2, 3) // 6
|
|
34
|
+
o.softmax([1, 2, 3]) // Float64Array [~0.09, ~0.24, ~0.67]
|
|
35
|
+
o.matmul([1,2,3,4], [5,6,7,8], 2, 2, 2) // Float64Array [19,22,43,50]
|
|
36
|
+
o.forwarddiff(x => x * x, 3, 0.001) // ~6.0
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
### Tier 2: Session API
|
|
40
|
+
|
|
41
|
+
For training loops where performance matters. Work with WASM pointers directly and control arena memory manually.
|
|
42
|
+
|
|
43
|
+
```js
|
|
44
|
+
const session = o.createSession(65536)
|
|
45
|
+
|
|
46
|
+
const inputPtr = session.writeArray([0.1, 0.3, 0.7, 0.9])
|
|
47
|
+
const targetPtr = session.writeArray([0, 0, 1, 1])
|
|
48
|
+
|
|
49
|
+
const fwd = session.denseForward(inputPtr, weights._ptr, 4, 1, 1, 'sigmoid')
|
|
50
|
+
const loss = session.mseLoss(fwd.outputPtr, targetPtr, 4)
|
|
51
|
+
|
|
52
|
+
session.clear() // reuse arena memory between epochs
|
|
53
|
+
session.destroy() // free arena when done
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
## Quick Start
|
|
57
|
+
|
|
58
|
+
### Install
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
npm install opendi-js
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
### ESM
|
|
65
|
+
|
|
66
|
+
```js
|
|
67
|
+
import { init } from 'opendi-js'
|
|
68
|
+
const o = await init()
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
### CommonJS
|
|
72
|
+
|
|
73
|
+
```js
|
|
74
|
+
const { init } = require('opendi-js')
|
|
75
|
+
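// CommonJS modules cannot use top-level await, so call init() inside an async function.
async function main() {
const o = await init()
}
main().catch(console.error)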
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
## API Reference
|
|
79
|
+
|
|
80
|
+
### Primitives
|
|
81
|
+
|
|
82
|
+
```js
|
|
83
|
+
o.add(1, 2, 3) // 6 (variadic)
|
|
84
|
+
o.subtract(10, 3) // 7 (variadic)
|
|
85
|
+
o.multiply(2, 3, 4) // 24 (variadic)
|
|
86
|
+
o.divide(100, 5, 2) // 10 (variadic)
|
|
87
|
+
o.abs(-5) // 5
|
|
88
|
+
o.min(3, 1, 2) // 1 (variadic)
|
|
89
|
+
o.max(3, 1, 2) // 3 (variadic)
|
|
90
|
+
o.round('floor', 3.7) // 3 ('floor', 'ceil', 'round')
|
|
91
|
+
o.pow(2, 10) // 1024
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
### Activations
|
|
95
|
+
|
|
96
|
+
```js
|
|
97
|
+
o.relu(5) // 5
|
|
98
|
+
o.relu(-5) // 0
|
|
99
|
+
o.sigmoid(0) // 0.5
|
|
100
|
+
o.softmax([1, 2, 3]) // Float64Array (sums to 1.0)
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
### Vectors
|
|
104
|
+
|
|
105
|
+
```js
|
|
106
|
+
o.vecadd([1,2,3], [4,5,6]) // Float64Array [5,7,9]
|
|
107
|
+
o.vecscale([1,2,3], 2) // Float64Array [2,4,6]
|
|
108
|
+
o.veccross([1,0,0], [0,1,0]) // Float64Array [0,0,1]
|
|
109
|
+
o.vecdot([1,2,3], [4,5,6]) // 32
|
|
110
|
+
o.vecnorm([3, 4]) // 5
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
### Matrices
|
|
114
|
+
|
|
115
|
+
All matrices are flat arrays in row-major order.
|
|
116
|
+
|
|
117
|
+
```js
|
|
118
|
+
o.matmul([1,2,3,4], [5,6,7,8], 2, 2, 2) // Float64Array [19,22,43,50]
|
|
119
|
+
o.matadd([1,2,3,4], [10,20,30,40], 2, 2) // Float64Array [11,22,33,44]
|
|
120
|
+
o.matscale([1,2,3,4], 3, 2, 2) // Float64Array [3,6,9,12]
|
|
121
|
+
o.mattranspose([1,2,3,4,5,6], 2, 3) // Float64Array [1,4,2,5,3,6]
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
### Loss Functions
|
|
125
|
+
|
|
126
|
+
```js
|
|
127
|
+
o.mseLoss([1,2,3], [1,2,3]) // 0 (perfect prediction)
|
|
128
|
+
o.crossEntropy(pred, targets) // cross-entropy loss
|
|
129
|
+
```
|
|
130
|
+
|
|
131
|
+
### Backward (Gradients)
|
|
132
|
+
|
|
133
|
+
```js
|
|
134
|
+
o.reluBackward(dout, input) // Float64Array
|
|
135
|
+
o.sigmoidBackward(dout, output) // Float64Array
|
|
136
|
+
o.softmaxBackward(dout, output) // Float64Array
|
|
137
|
+
o.matmulBackwardA(dout, b, m, n, p) // Float64Array (gradient w.r.t. A)
|
|
138
|
+
o.matmulBackwardB(a, dout, m, n, p) // Float64Array (gradient w.r.t. B)
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
### Optimizer
|
|
142
|
+
|
|
143
|
+
```js
|
|
144
|
+
o.sgdUpdate([1,2,3], [0.1,0.2,0.3], 1.0) // Float64Array [0.9, 1.8, 2.7]
|
|
145
|
+
```
|
|
146
|
+
|
|
147
|
+
### Random
|
|
148
|
+
|
|
149
|
+
```js
|
|
150
|
+
o.randomSeed(42)
|
|
151
|
+
o.randomUniform(0, 1, 5) // Float64Array of 5 values in [0, 1]
|
|
152
|
+
o.randomNormal(0, 1, 5) // Float64Array of 5 values ~ N(0, 1)
|
|
153
|
+
```
|
|
154
|
+
|
|
155
|
+
### Statistics
|
|
156
|
+
|
|
157
|
+
```js
|
|
158
|
+
o.normalize([2,4,4,4,5,5,7,9]) // Float64Array (mean ~0, std ~1)
|
|
159
|
+
```
|
|
160
|
+
|
|
161
|
+
### Calculus
|
|
162
|
+
|
|
163
|
+
Pass regular JavaScript functions. Internally they are wrapped and registered in the WASM function table so the C code can call them through a function pointer.
|
|
164
|
+
|
|
165
|
+
```js
|
|
166
|
+
o.forwarddiff(x => x*x, 3, 0.001) // ~6.0
|
|
167
|
+
o.backwarddiff(x => x*x, 3, 0.001) // ~6.0
|
|
168
|
+
o.centralDifference(x => x*x, 3, 0.001) // ~6.0
|
|
169
|
+
o.secondDerivative(x => x*x, 3, 0.001) // ~2.0
|
|
170
|
+
o.centralNth(x => x*x, 3, 0.001, 2) // ~2.0 (nth derivative)
|
|
171
|
+
o.rombergIntegrate(x => x*x, 0, 1, 1e-8, 10) // ~0.333
|
|
172
|
+
```
|
|
173
|
+
|
|
174
|
+
### Pipeline
|
|
175
|
+
|
|
176
|
+
```js
|
|
177
|
+
o.batchRelu([-1, 2, -3, 4]) // Float64Array [0,2,0,4]
|
|
178
|
+
o.batchSigmoid([0, 0]) // Float64Array [0.5, 0.5]
|
|
179
|
+
o.batchSoftmax(input, rows, cols) // Float64Array (per-row softmax)
|
|
180
|
+
o.batchNormalize(features, n, nFeat) // Float64Array (z-score normalized)
|
|
181
|
+
o.mseBackward(pred, targets) // Float64Array (MSE gradient)
|
|
182
|
+
o.crossEntropyBackward(pred, targets, nSamples, nClasses) // Float64Array
|
|
183
|
+
o.accuracy(pred, labels, nSamples, nClasses) // number
|
|
184
|
+
```
|
|
185
|
+
|
|
186
|
+
### Weights
|
|
187
|
+
|
|
188
|
+
```js
|
|
189
|
+
o.randomSeed(42)
|
|
190
|
+
const w = o.initWeights(10, 0.0, 0.1)
|
|
191
|
+
w.data // Float64Array of 10 values
|
|
192
|
+
w._ptr // WASM pointer (for session API)
|
|
193
|
+
w.length // 10
|
|
194
|
+
w.free() // free the malloc'd memory
|
|
195
|
+
```
|
|
196
|
+
|
|
197
|
+
### Session API
|
|
198
|
+
|
|
199
|
+
```js
|
|
200
|
+
const session = o.createSession(65536)
|
|
201
|
+
|
|
202
|
+
session.writeArray([1, 2, 3]) // returns WASM pointer
|
|
203
|
+
session.writeI32Array([0, 1, 2]) // returns WASM pointer (int32)
|
|
204
|
+
session.readArray(ptr, 3) // Float64Array from pointer
|
|
205
|
+
session.copyToPtr(destPtr, srcArray) // write JS array to existing pointer
|
|
206
|
+
|
|
207
|
+
// All math functions available with pointer args:
|
|
208
|
+
session.vecadd(aPtr, bPtr, n)
|
|
209
|
+
session.matmul(aPtr, bPtr, m, n, p)
|
|
210
|
+
session.denseForward(inputPtr, weightsPtr, m, n, p, 'sigmoid')
|
|
211
|
+
session.denseBackward(doutPtr, inputPtr, weightsPtr, cachePtr, m, n, p, 'sigmoid')
|
|
212
|
+
session.sgdUpdate(weightsPtr, gradsPtr, lr, n)
|
|
213
|
+
session.mseLoss(predPtr, targetsPtr, n)
|
|
214
|
+
|
|
215
|
+
session.clear() // arena_clear() - reuse memory
|
|
216
|
+
session.destroy() // arena_destroy() - free arena
|
|
217
|
+
```
|
|
218
|
+
|
|
219
|
+
## Training Loop Example
|
|
220
|
+
|
|
221
|
+
```js
|
|
222
|
+
import { init } from 'opendi-js'
|
|
223
|
+
|
|
224
|
+
const o = await init()
|
|
225
|
+
o.randomSeed(42)
|
|
226
|
+
|
|
227
|
+
const features = [0.1, 0.3, 0.7, 0.9]
|
|
228
|
+
const targets = [0, 0, 1, 1]
|
|
229
|
+
const nSamples = 4, nFeatures = 1
|
|
230
|
+
|
|
231
|
+
const weights = o.initWeights(nFeatures, 0.0, 0.1)
|
|
232
|
+
const session = o.createSession(65536)
|
|
233
|
+
|
|
234
|
+
for (let epoch = 0; epoch < 200; epoch++) {
|
|
235
|
+
const featPtr = session.writeArray(features)
|
|
236
|
+
const targPtr = session.writeArray(targets)
|
|
237
|
+
|
|
238
|
+
const fwd = session.denseForward(featPtr, weights._ptr, nSamples, nFeatures, 1, 'sigmoid')
|
|
239
|
+
const loss = session.mseLoss(fwd.outputPtr, targPtr, nSamples)
|
|
240
|
+
|
|
241
|
+
const dLoss = session.mseBackward(fwd.outputPtr, targPtr, nSamples)
|
|
242
|
+
const grad = session.denseBackward(dLoss, featPtr, weights._ptr, fwd.cachePtr, nSamples, nFeatures, 1, 'sigmoid')
|
|
243
|
+
|
|
244
|
+
const newWPtr = session.sgdUpdate(weights._ptr, grad.dWeightsPtr, 1.0, nFeatures)
|
|
245
|
+
session.copyToPtr(weights._ptr, session.readArray(newWPtr, nFeatures))
|
|
246
|
+
|
|
247
|
+
session.clear()
|
|
248
|
+
}
|
|
249
|
+
|
|
250
|
+
session.destroy()
|
|
251
|
+
weights.free()
|
|
252
|
+
```
|
|
253
|
+
|
|
254
|
+
## Building from Source
|
|
255
|
+
|
|
256
|
+
Requires [Emscripten](https://emscripten.org/).
|
|
257
|
+
|
|
258
|
+
```bash
|
|
259
|
+
cd opendi-js
|
|
260
|
+
bash scripts/build.sh
|
|
261
|
+
```
|
|
262
|
+
|
|
263
|
+
Produces `build/opendi.js` and `build/opendi.wasm`.
|
|
264
|
+
|
|
265
|
+
## Testing
|
|
266
|
+
|
|
267
|
+
```bash
|
|
268
|
+
node test/test.mjs
|
|
269
|
+
```
|
|
270
|
+
|
|
271
|
+
## Documentation
|
|
272
|
+
|
|
273
|
+
API documentation is in `docs/`. Each function group has its own documentation file organized by module.
|
|
274
|
+
|
|
275
|
+
## See Also
|
|
276
|
+
|
|
277
|
+
- [OpenDI C Library](../README.md) - The underlying C99 engine
|
|
278
|
+
- [API Documentation](docs/) - Detailed function reference
|
package/build/opendi.js
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
1
|
+
/**
 * Emscripten-generated ES-module loader for `opendi.wasm` (minified build output).
 * Do not edit by hand; per the README this file is produced by `scripts/build.sh`.
 *
 * `createOpenDI(moduleArg = {})` is an async factory. It instantiates the WASM
 * binary, runs the runtime start-up (`preRun` → `initRuntime` → `postRun`) and
 * returns the populated `Module` object (`moduleArg` itself, mutated in place).
 *
 * - The binary is located through `Module.locateFile("opendi.wasm", dir)` when that
 *   hook is supplied, otherwise relative to `import.meta.url`. `Module.wasmBinary`
 *   and `Module.instantiateWasm` are also honoured when present.
 * - Every listed WASM export is mirrored onto the module as `Module["_<name>"]`
 *   (see `assignWasmExports`), and the memory is exposed as `Module.wasmMemory`.
 * - Runtime helpers attached to the module: `ccall`, `cwrap`, `addFunction`,
 *   `removeFunction`, `setValue`, `getValue`, `UTF8ToString`, `stringToUTF8`,
 *   `lengthBytesUTF8`.
 * - The only import handed to the WASM instance is `emscripten_resize_heap`; heap
 *   growth is capped by `getHeapMax()` at 2147483648 bytes.
 * - `abort()` logs, sets `ABORT`, and throws a `WebAssembly.RuntimeError`.
 */
async function createOpenDI(moduleArg={}){var moduleRtn;var Module=moduleArg;var ENVIRONMENT_IS_WEB=!!globalThis.window;var ENVIRONMENT_IS_WORKER=!!globalThis.WorkerGlobalScope;var ENVIRONMENT_IS_NODE=globalThis.process?.versions?.node&&globalThis.process?.type!="renderer";if(ENVIRONMENT_IS_NODE){const{createRequire}=await import("module");var require=createRequire(import.meta.url)}var arguments_=[];var thisProgram="./this.program";var quit_=(status,toThrow)=>{throw toThrow};var _scriptName=import.meta.url;var scriptDirectory="";function locateFile(path){if(Module["locateFile"]){return Module["locateFile"](path,scriptDirectory)}return scriptDirectory+path}var readAsync,readBinary;if(ENVIRONMENT_IS_NODE){var fs=require("fs");if(_scriptName.startsWith("file:")){scriptDirectory=require("path").dirname(require("url").fileURLToPath(_scriptName))+"/"}readBinary=filename=>{filename=isFileURI(filename)?new URL(filename):filename;var ret=fs.readFileSync(filename);return ret};readAsync=async(filename,binary=true)=>{filename=isFileURI(filename)?new URL(filename):filename;var ret=fs.readFileSync(filename,binary?undefined:"utf8");return ret};if(process.argv.length>1){thisProgram=process.argv[1].replace(/\\/g,"/")}arguments_=process.argv.slice(2);quit_=(status,toThrow)=>{process.exitCode=status;throw toThrow}}else if(ENVIRONMENT_IS_WEB||ENVIRONMENT_IS_WORKER){try{scriptDirectory=new URL(".",_scriptName).href}catch{}{readAsync=async url=>{var response=await fetch(url,{credentials:"same-origin"});if(response.ok){return response.arrayBuffer()}throw new Error(response.status+" : "+response.url)}}}else{}var out=console.log.bind(console);var err=console.error.bind(console);var wasmBinary;var ABORT=false;var isFileURI=filename=>filename.startsWith("file://");var readyPromiseResolve,readyPromiseReject;var HEAP8,HEAPU8,HEAP16,HEAPU16,HEAP32,HEAPU32,HEAPF32,HEAPF64;var HEAP64,HEAPU64;var runtimeInitialized=false;function updateMemoryViews(){var b=wasmMemory.buffer;HEAP8=new 
Int8Array(b);HEAP16=new Int16Array(b);HEAPU8=new Uint8Array(b);HEAPU16=new Uint16Array(b);HEAP32=new Int32Array(b);HEAPU32=new Uint32Array(b);HEAPF32=new Float32Array(b);HEAPF64=new Float64Array(b);HEAP64=new BigInt64Array(b);HEAPU64=new BigUint64Array(b)}function preRun(){if(Module["preRun"]){if(typeof Module["preRun"]=="function")Module["preRun"]=[Module["preRun"]];while(Module["preRun"].length){addOnPreRun(Module["preRun"].shift())}}callRuntimeCallbacks(onPreRuns)}function initRuntime(){runtimeInitialized=true;wasmExports["__wasm_call_ctors"]()}function postRun(){if(Module["postRun"]){if(typeof Module["postRun"]=="function")Module["postRun"]=[Module["postRun"]];while(Module["postRun"].length){addOnPostRun(Module["postRun"].shift())}}callRuntimeCallbacks(onPostRuns)}function abort(what){Module["onAbort"]?.(what);what="Aborted("+what+")";err(what);ABORT=true;what+=". Build with -sASSERTIONS for more info.";var e=new WebAssembly.RuntimeError(what);readyPromiseReject?.(e);throw e}var wasmBinaryFile;function findWasmBinary(){if(Module["locateFile"]){return locateFile("opendi.wasm")}return new URL("opendi.wasm",import.meta.url).href}function getBinarySync(file){if(file==wasmBinaryFile&&wasmBinary){return new Uint8Array(wasmBinary)}if(readBinary){return readBinary(file)}throw"both async and sync fetching of the wasm failed"}async function getWasmBinary(binaryFile){if(!wasmBinary){try{var response=await readAsync(binaryFile);return new Uint8Array(response)}catch{}}return getBinarySync(binaryFile)}async function instantiateArrayBuffer(binaryFile,imports){try{var binary=await getWasmBinary(binaryFile);var instance=await WebAssembly.instantiate(binary,imports);return instance}catch(reason){err(`failed to asynchronously prepare wasm: ${reason}`);abort(reason)}}async function instantiateAsync(binary,binaryFile,imports){if(!binary&&!ENVIRONMENT_IS_NODE){try{var response=fetch(binaryFile,{credentials:"same-origin"});var instantiationResult=await 
WebAssembly.instantiateStreaming(response,imports);return instantiationResult}catch(reason){err(`wasm streaming compile failed: ${reason}`);err("falling back to ArrayBuffer instantiation")}}return instantiateArrayBuffer(binaryFile,imports)}function getWasmImports(){var imports={env:wasmImports,wasi_snapshot_preview1:wasmImports};return imports}async function createWasm(){function receiveInstance(instance,module){wasmExports=instance.exports;assignWasmExports(wasmExports);updateMemoryViews();return wasmExports}function receiveInstantiationResult(result){return receiveInstance(result["instance"])}var info=getWasmImports();if(Module["instantiateWasm"]){return new Promise((resolve,reject)=>{Module["instantiateWasm"](info,(inst,mod)=>{resolve(receiveInstance(inst,mod))})})}wasmBinaryFile??=findWasmBinary();var result=await instantiateAsync(wasmBinary,wasmBinaryFile,info);var exports=receiveInstantiationResult(result);return exports}class ExitStatus{name="ExitStatus";constructor(status){this.message=`Program terminated with exit(${status})`;this.status=status}}var callRuntimeCallbacks=callbacks=>{while(callbacks.length>0){callbacks.shift()(Module)}};var onPostRuns=[];var addOnPostRun=cb=>onPostRuns.push(cb);var onPreRuns=[];var addOnPreRun=cb=>onPreRuns.push(cb);function getValue(ptr,type="i8"){if(type.endsWith("*"))type="*";switch(type){case"i1":return HEAP8[ptr];case"i8":return HEAP8[ptr];case"i16":return HEAP16[ptr>>1];case"i32":return HEAP32[ptr>>2];case"i64":return HEAP64[ptr>>3];case"float":return HEAPF32[ptr>>2];case"double":return HEAPF64[ptr>>3];case"*":return HEAPU32[ptr>>2];default:abort(`invalid type for getValue: ${type}`)}}var noExitRuntime=true;function 
setValue(ptr,value,type="i8"){if(type.endsWith("*"))type="*";switch(type){case"i1":HEAP8[ptr]=value;break;case"i8":HEAP8[ptr]=value;break;case"i16":HEAP16[ptr>>1]=value;break;case"i32":HEAP32[ptr>>2]=value;break;case"i64":HEAP64[ptr>>3]=BigInt(value);break;case"float":HEAPF32[ptr>>2]=value;break;case"double":HEAPF64[ptr>>3]=value;break;case"*":HEAPU32[ptr>>2]=value;break;default:abort(`invalid type for setValue: ${type}`)}}var stackRestore=val=>__emscripten_stack_restore(val);var stackSave=()=>_emscripten_stack_get_current();var getHeapMax=()=>2147483648;var alignMemory=(size,alignment)=>Math.ceil(size/alignment)*alignment;var growMemory=size=>{var oldHeapSize=wasmMemory.buffer.byteLength;var pages=(size-oldHeapSize+65535)/65536|0;try{wasmMemory.grow(pages);updateMemoryViews();return 1}catch(e){}};var _emscripten_resize_heap=requestedSize=>{var oldSize=HEAPU8.length;requestedSize>>>=0;var maxHeapSize=getHeapMax();if(requestedSize>maxHeapSize){return false}for(var cutDown=1;cutDown<=4;cutDown*=2){var overGrownHeapSize=oldSize*(1+.2/cutDown);overGrownHeapSize=Math.min(overGrownHeapSize,requestedSize+100663296);var newSize=Math.min(maxHeapSize,alignMemory(Math.max(requestedSize,overGrownHeapSize),65536));var replacement=growMemory(newSize);if(replacement){return true}}return false};var getCFunc=ident=>{var func=Module["_"+ident];return func};var writeArrayToMemory=(array,buffer)=>{HEAP8.set(array,buffer)};var lengthBytesUTF8=str=>{var len=0;for(var i=0;i<str.length;++i){var c=str.charCodeAt(i);if(c<=127){len++}else if(c<=2047){len+=2}else if(c>=55296&&c<=57343){len+=4;++i}else{len+=3}}return len};var stringToUTF8Array=(str,heap,outIdx,maxBytesToWrite)=>{if(!(maxBytesToWrite>0))return 0;var startIdx=outIdx;var endIdx=outIdx+maxBytesToWrite-1;for(var i=0;i<str.length;++i){var u=str.codePointAt(i);if(u<=127){if(outIdx>=endIdx)break;heap[outIdx++]=u}else if(u<=2047){if(outIdx+1>=endIdx)break;heap[outIdx++]=192|u>>6;heap[outIdx++]=128|u&63}else 
if(u<=65535){if(outIdx+2>=endIdx)break;heap[outIdx++]=224|u>>12;heap[outIdx++]=128|u>>6&63;heap[outIdx++]=128|u&63}else{if(outIdx+3>=endIdx)break;heap[outIdx++]=240|u>>18;heap[outIdx++]=128|u>>12&63;heap[outIdx++]=128|u>>6&63;heap[outIdx++]=128|u&63;i++}}heap[outIdx]=0;return outIdx-startIdx};var stringToUTF8=(str,outPtr,maxBytesToWrite)=>stringToUTF8Array(str,HEAPU8,outPtr,maxBytesToWrite);var stackAlloc=sz=>__emscripten_stack_alloc(sz);var stringToUTF8OnStack=str=>{var size=lengthBytesUTF8(str)+1;var ret=stackAlloc(size);stringToUTF8(str,ret,size);return ret};var UTF8Decoder=globalThis.TextDecoder&&new TextDecoder;var findStringEnd=(heapOrArray,idx,maxBytesToRead,ignoreNul)=>{var maxIdx=idx+maxBytesToRead;if(ignoreNul)return maxIdx;while(heapOrArray[idx]&&!(idx>=maxIdx))++idx;return idx};var UTF8ArrayToString=(heapOrArray,idx=0,maxBytesToRead,ignoreNul)=>{var endPtr=findStringEnd(heapOrArray,idx,maxBytesToRead,ignoreNul);if(endPtr-idx>16&&heapOrArray.buffer&&UTF8Decoder){return UTF8Decoder.decode(heapOrArray.subarray(idx,endPtr))}var str="";while(idx<endPtr){var u0=heapOrArray[idx++];if(!(u0&128)){str+=String.fromCharCode(u0);continue}var u1=heapOrArray[idx++]&63;if((u0&224)==192){str+=String.fromCharCode((u0&31)<<6|u1);continue}var u2=heapOrArray[idx++]&63;if((u0&240)==224){u0=(u0&15)<<12|u1<<6|u2}else{u0=(u0&7)<<18|u1<<12|u2<<6|heapOrArray[idx++]&63}if(u0<65536){str+=String.fromCharCode(u0)}else{var ch=u0-65536;str+=String.fromCharCode(55296|ch>>10,56320|ch&1023)}}return str};var UTF8ToString=(ptr,maxBytesToRead,ignoreNul)=>ptr?UTF8ArrayToString(HEAPU8,ptr,maxBytesToRead,ignoreNul):"";var ccall=(ident,returnType,argTypes,args,opts)=>{var toC={string:str=>{var ret=0;if(str!==null&&str!==undefined&&str!==0){ret=stringToUTF8OnStack(str)}return ret},array:arr=>{var ret=stackAlloc(arr.length);writeArrayToMemory(arr,ret);return ret}};function convertReturnValue(ret){if(returnType==="string"){return UTF8ToString(ret)}if(returnType==="boolean")return 
Boolean(ret);return ret}var func=getCFunc(ident);var cArgs=[];var stack=0;if(args){for(var i=0;i<args.length;i++){var converter=toC[argTypes[i]];if(converter){if(stack===0)stack=stackSave();cArgs[i]=converter(args[i])}else{cArgs[i]=args[i]}}}var ret=func(...cArgs);function onDone(ret){if(stack!==0)stackRestore(stack);return convertReturnValue(ret)}ret=onDone(ret);return ret};var cwrap=(ident,returnType,argTypes,opts)=>{var numericArgs=!argTypes||argTypes.every(type=>type==="number"||type==="boolean");var numericRet=returnType!=="string";if(numericRet&&numericArgs&&!opts){return getCFunc(ident)}return(...args)=>ccall(ident,returnType,argTypes,args,opts)};var wasmTableMirror=[];var getWasmTableEntry=funcPtr=>{var func=wasmTableMirror[funcPtr];if(!func){wasmTableMirror[funcPtr]=func=wasmTable.get(funcPtr)}return func};var updateTableMap=(offset,count)=>{if(functionsInTableMap){for(var i=offset;i<offset+count;i++){var item=getWasmTableEntry(i);if(item){functionsInTableMap.set(item,i)}}}};var functionsInTableMap;var getFunctionAddress=func=>{if(!functionsInTableMap){functionsInTableMap=new WeakMap;updateTableMap(0,wasmTable.length)}return functionsInTableMap.get(func)||0};var freeTableIndexes=[];var getEmptyTableSlot=()=>{if(freeTableIndexes.length){return freeTableIndexes.pop()}return wasmTable["grow"](1)};var setWasmTableEntry=(idx,func)=>{wasmTable.set(idx,func);wasmTableMirror[idx]=wasmTable.get(idx)};var uleb128EncodeWithLen=arr=>{const n=arr.length;return[n%128|128,n>>7,...arr]};var wasmTypeCodes={i:127,p:127,j:126,f:125,d:124,e:111};var generateTypePack=types=>uleb128EncodeWithLen(Array.from(types,type=>{var code=wasmTypeCodes[type];return code}));var convertJsFunctionToWasm=(func,sig)=>{var bytes=Uint8Array.of(0,97,115,109,1,0,0,0,1,...uleb128EncodeWithLen([1,96,...generateTypePack(sig.slice(1)),...generateTypePack(sig[0]==="v"?"":sig[0])]),2,7,1,1,101,1,102,0,0,7,5,1,1,102,0,0);var module=new WebAssembly.Module(bytes);var instance=new 
WebAssembly.Instance(module,{e:{f:func}});var wrappedFunc=instance.exports["f"];return wrappedFunc};var addFunction=(func,sig)=>{var rtn=getFunctionAddress(func);if(rtn){return rtn}var ret=getEmptyTableSlot();try{setWasmTableEntry(ret,func)}catch(err){if(!(err instanceof TypeError)){throw err}var wrapped=convertJsFunctionToWasm(func,sig);setWasmTableEntry(ret,wrapped)}functionsInTableMap.set(func,ret);return ret};var removeFunction=index=>{functionsInTableMap.delete(getWasmTableEntry(index));setWasmTableEntry(index,null);freeTableIndexes.push(index)};{if(Module["noExitRuntime"])noExitRuntime=Module["noExitRuntime"];if(Module["print"])out=Module["print"];if(Module["printErr"])err=Module["printErr"];if(Module["wasmBinary"])wasmBinary=Module["wasmBinary"];if(Module["arguments"])arguments_=Module["arguments"];if(Module["thisProgram"])thisProgram=Module["thisProgram"];if(Module["preInit"]){if(typeof Module["preInit"]=="function")Module["preInit"]=[Module["preInit"]];while(Module["preInit"].length>0){Module["preInit"].shift()()}}}Module["ccall"]=ccall;Module["cwrap"]=cwrap;Module["addFunction"]=addFunction;Module["removeFunction"]=removeFunction;Module["setValue"]=setValue;Module["getValue"]=getValue;Module["UTF8ToString"]=UTF8ToString;Module["stringToUTF8"]=stringToUTF8;Module["lengthBytesUTF8"]=lengthBytesUTF8;var 
_exponents,_roundval,_absolute,_random_normal,_random_seed,_random_uniform,_cross_entropy,_mse_loss,_centralnth,_forwarddiff,_central_difference,_secondderivative,_backwarddiff,_romberg_integrate,_free,_matmul_backward_a,_mattranspose,_matmul,_matmul_backward_b,_softmax_backward,_sigmoid_backward,_relu_backward,_normalize,_batch_sigmoid,_sigmoid,_batch_normalize,_init_weights,_malloc,_batch_softmax,_dense_forward,_relu,_dense_backward,_accuracy,_cross_entropy_backward,_batch_relu,_mse_backward,_matadd,_matscale,_vecdot,_vecscale,_vecnorm,_veccross,_vecadd,_sgd_update,_softmax,_glue_arena_create,_glue_arena_push,_glue_arena_clear,_glue_arena_destroy,_glue_add,_glue_subtract,_glue_multiply,_glue_divide,_glue_min,_glue_max,_glue_dense_forward,_glue_get_cache_ptr,_glue_dense_backward,_glue_get_dw_ptr,_glue_get_di_ptr,_glue_free,__emscripten_stack_restore,__emscripten_stack_alloc,_emscripten_stack_get_current,memory,__indirect_function_table,wasmMemory,wasmTable;function assignWasmExports(wasmExports){_exponents=Module["_exponents"]=wasmExports["exponents"];_roundval=Module["_roundval"]=wasmExports["roundval"];_absolute=Module["_absolute"]=wasmExports["absolute"];_random_normal=Module["_random_normal"]=wasmExports["random_normal"];_random_seed=Module["_random_seed"]=wasmExports["random_seed"];_random_uniform=Module["_random_uniform"]=wasmExports["random_uniform"];_cross_entropy=Module["_cross_entropy"]=wasmExports["cross_entropy"];_mse_loss=Module["_mse_loss"]=wasmExports["mse_loss"];_centralnth=Module["_centralnth"]=wasmExports["centralnth"];_forwarddiff=Module["_forwarddiff"]=wasmExports["forwarddiff"];_central_difference=Module["_central_difference"]=wasmExports["central_difference"];_secondderivative=Module["_secondderivative"]=wasmExports["secondderivative"];_backwarddiff=Module["_backwarddiff"]=wasmExports["backwarddiff"];_romberg_integrate=Module["_romberg_integrate"]=wasmExports["romberg_integrate"];_free=Module["_free"]=wasmExports["free"];_matmul_backward_a=Mod
ule["_matmul_backward_a"]=wasmExports["matmul_backward_a"];_mattranspose=Module["_mattranspose"]=wasmExports["mattranspose"];_matmul=Module["_matmul"]=wasmExports["matmul"];_matmul_backward_b=Module["_matmul_backward_b"]=wasmExports["matmul_backward_b"];_softmax_backward=Module["_softmax_backward"]=wasmExports["softmax_backward"];_sigmoid_backward=Module["_sigmoid_backward"]=wasmExports["sigmoid_backward"];_relu_backward=Module["_relu_backward"]=wasmExports["relu_backward"];_normalize=Module["_normalize"]=wasmExports["normalize"];_batch_sigmoid=Module["_batch_sigmoid"]=wasmExports["batch_sigmoid"];_sigmoid=Module["_sigmoid"]=wasmExports["sigmoid"];_batch_normalize=Module["_batch_normalize"]=wasmExports["batch_normalize"];_init_weights=Module["_init_weights"]=wasmExports["init_weights"];_malloc=Module["_malloc"]=wasmExports["malloc"];_batch_softmax=Module["_batch_softmax"]=wasmExports["batch_softmax"];_dense_forward=Module["_dense_forward"]=wasmExports["dense_forward"];_relu=Module["_relu"]=wasmExports["relu"];_dense_backward=Module["_dense_backward"]=wasmExports["dense_backward"];_accuracy=Module["_accuracy"]=wasmExports["accuracy"];_cross_entropy_backward=Module["_cross_entropy_backward"]=wasmExports["cross_entropy_backward"];_batch_relu=Module["_batch_relu"]=wasmExports["batch_relu"];_mse_backward=Module["_mse_backward"]=wasmExports["mse_backward"];_matadd=Module["_matadd"]=wasmExports["matadd"];_matscale=Module["_matscale"]=wasmExports["matscale"];_vecdot=Module["_vecdot"]=wasmExports["vecdot"];_vecscale=Module["_vecscale"]=wasmExports["vecscale"];_vecnorm=Module["_vecnorm"]=wasmExports["vecnorm"];_veccross=Module["_veccross"]=wasmExports["veccross"];_vecadd=Module["_vecadd"]=wasmExports["vecadd"];_sgd_update=Module["_sgd_update"]=wasmExports["sgd_update"];_softmax=Module["_softmax"]=wasmExports["softmax"];_glue_arena_create=Module["_glue_arena_create"]=wasmExports["glue_arena_create"];_glue_arena_push=Module["_glue_arena_push"]=wasmExports["glue_arena_push"];_gl
ue_arena_clear=Module["_glue_arena_clear"]=wasmExports["glue_arena_clear"];_glue_arena_destroy=Module["_glue_arena_destroy"]=wasmExports["glue_arena_destroy"];_glue_add=Module["_glue_add"]=wasmExports["glue_add"];_glue_subtract=Module["_glue_subtract"]=wasmExports["glue_subtract"];_glue_multiply=Module["_glue_multiply"]=wasmExports["glue_multiply"];_glue_divide=Module["_glue_divide"]=wasmExports["glue_divide"];_glue_min=Module["_glue_min"]=wasmExports["glue_min"];_glue_max=Module["_glue_max"]=wasmExports["glue_max"];_glue_dense_forward=Module["_glue_dense_forward"]=wasmExports["glue_dense_forward"];_glue_get_cache_ptr=Module["_glue_get_cache_ptr"]=wasmExports["glue_get_cache_ptr"];_glue_dense_backward=Module["_glue_dense_backward"]=wasmExports["glue_dense_backward"];_glue_get_dw_ptr=Module["_glue_get_dw_ptr"]=wasmExports["glue_get_dw_ptr"];_glue_get_di_ptr=Module["_glue_get_di_ptr"]=wasmExports["glue_get_di_ptr"];_glue_free=Module["_glue_free"]=wasmExports["glue_free"];__emscripten_stack_restore=wasmExports["_emscripten_stack_restore"];__emscripten_stack_alloc=wasmExports["_emscripten_stack_alloc"];_emscripten_stack_get_current=wasmExports["emscripten_stack_get_current"];memory=wasmMemory=Module["wasmMemory"]=wasmExports["memory"];__indirect_function_table=wasmTable=wasmExports["__indirect_function_table"]}var wasmImports={emscripten_resize_heap:_emscripten_resize_heap};function run(){preRun();function doRun(){Module["calledRun"]=true;if(ABORT)return;initRuntime();readyPromiseResolve?.(Module);Module["onRuntimeInitialized"]?.();postRun()}if(Module["setStatus"]){Module["setStatus"]("Running...");setTimeout(()=>{setTimeout(()=>Module["setStatus"](""),1);doRun()},1)}else{doRun()}}var wasmExports;wasmExports=await (createWasm());run();if(runtimeInitialized){moduleRtn=Module}else{moduleRtn=new Promise((resolve,reject)=>{readyPromiseResolve=resolve;readyPromiseReject=reject})}
|
|
2
|
+
;return moduleRtn}export default createOpenDI;
|
|
Binary file
|
package/opendi.d.ts
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
/**
 * Handle to a block of weights held in malloc'd WASM memory.
 * The README shows `initWeights()` returning an object with these members.
 * Call `free()` once the weights are no longer needed.
 * NOTE(review): whether `data` is refreshed after writes through `_ptr`
 * is not visible here — confirm against the implementation.
 */
export interface Weights {
|
|
2
|
+
/** Raw WASM pointer — use with Session API */
|
|
3
|
+
_ptr: number
|
|
4
|
+
/** Copy of the weight data as Float64Array */
|
|
5
|
+
data: Float64Array
|
|
6
|
+
/** Number of weights */
|
|
7
|
+
length: number
|
|
8
|
+
/** Free the underlying malloc'd memory */
|
|
9
|
+
free(): void
|
|
10
|
+
}
|
|
11
|
+
|
|
12
|
+
/**
 * Pointers produced by a dense-layer forward pass.
 * The README reads `outputPtr` and `cachePtr` from the value returned by
 * `session.denseForward(...)` and passes `cachePtr` on to `denseBackward`.
 */
export interface ForwardResult {
|
|
13
|
+
/** WASM pointer to output array */
|
|
14
|
+
outputPtr: number
|
|
15
|
+
/** WASM pointer to cache (for backward pass), or 0 if softmax/none */
|
|
16
|
+
cachePtr: number
|
|
17
|
+
}
|
|
18
|
+
|
|
19
|
+
/**
 * Gradient pointers produced by a dense-layer backward pass.
 * The README reads `dWeightsPtr` from the value returned by
 * `session.denseBackward(...)` and feeds it to `sgdUpdate`.
 */
export interface BackwardResult {
|
|
20
|
+
/** WASM pointer to weight gradients (n × p) */
|
|
21
|
+
dWeightsPtr: number
|
|
22
|
+
/** WASM pointer to input gradients (m × n) */
|
|
23
|
+
dInputPtr: number
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
/**
 * Activation selector: one of the named activations, or a number.
 * The README passes the string form (e.g. 'sigmoid') to the dense-layer calls.
 * NOTE(review): the numeric form is presumably a raw activation id from the
 * C library — confirm the accepted values against the implementation.
 */
export type Activation = 'none' | 'relu' | 'sigmoid' | 'softmax' | number
|
|
27
|
+
|
|
28
|
+
export interface Session {
|
|
29
|
+
writeArray(arr: ArrayLike<number>): number
|
|
30
|
+
writeI32Array(arr: ArrayLike<number>): number
|
|
31
|
+
readArray(ptr: number, n: number): Float64Array
|
|
32
|
+
/** Copy a JS array into a WASM pointer (e.g. malloc'd weights) */
|
|
33
|
+
copyToPtr(destPtr: number, srcArray: ArrayLike<number>): void
|
|
34
|
+
|
|
35
|
+
softmax(inputPtr: number, n: number): number
|
|
36
|
+
vecadd(aPtr: number, bPtr: number, n: number): number
|
|
37
|
+
vecscale(arrPtr: number, scalar: number, n: number): number
|
|
38
|
+
matmul(aPtr: number, bPtr: number, m: number, n: number, p: number): number
|
|
39
|
+
matadd(aPtr: number, bPtr: number, m: number, n: number): number
|
|
40
|
+
matscale(aPtr: number, s: number, m: number, n: number): number
|
|
41
|
+
mattranspose(aPtr: number, m: number, n: number): number
|
|
42
|
+
|
|
43
|
+
reluBackward(doutPtr: number, inputPtr: number, n: number): number
|
|
44
|
+
sigmoidBackward(doutPtr: number, outputPtr: number, n: number): number
|
|
45
|
+
softmaxBackward(doutPtr: number, outputPtr: number, n: number): number
|
|
46
|
+
matmulBackwardA(doutPtr: number, bPtr: number, m: number, n: number, p: number): number
|
|
47
|
+
matmulBackwardB(aPtr: number, doutPtr: number, m: number, n: number, p: number): number
|
|
48
|
+
|
|
49
|
+
sgdUpdate(weightsPtr: number, gradsPtr: number, lr: number, n: number): number
|
|
50
|
+
|
|
51
|
+
normalize(ptr: number, n: number): number
|
|
52
|
+
batchRelu(ptr: number, n: number): number
|
|
53
|
+
batchSigmoid(ptr: number, n: number): number
|
|
54
|
+
batchSoftmax(ptr: number, rows: number, cols: number): number
|
|
55
|
+
batchNormalize(ptr: number, nSamples: number, nFeatures: number): number
|
|
56
|
+
|
|
57
|
+
mseBackward(predPtr: number, targetsPtr: number, n: number): number
|
|
58
|
+
crossEntropyBackward(predPtr: number, targetsPtr: number, nSamples: number, nClasses: number): number
|
|
59
|
+
|
|
60
|
+
randomUniform(min: number, max: number, n: number): number
|
|
61
|
+
randomNormal(mean: number, std: number, n: number): number
|
|
62
|
+
|
|
63
|
+
denseForward(inputPtr: number, weightsPtr: number, m: number, n: number, p: number, activation: Activation): ForwardResult
|
|
64
|
+
denseBackward(doutPtr: number, inputPtr: number, weightsPtr: number, cachePtr: number, m: number, n: number, p: number, activation: Activation): BackwardResult
|
|
65
|
+
|
|
66
|
+
mseLoss(predPtr: number, targetsPtr: number, n: number): number
|
|
67
|
+
crossEntropy(predPtr: number, targetsPtr: number, n: number): number
|
|
68
|
+
accuracy(predPtr: number, labelsPtr: number, nSamples: number, nClasses: number): number
|
|
69
|
+
|
|
70
|
+
clear(): void
|
|
71
|
+
destroy(): void
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
export interface OpenDI {
|
|
75
|
+
// Primitives
|
|
76
|
+
add(...nums: number[]): number
|
|
77
|
+
subtract(...nums: number[]): number
|
|
78
|
+
multiply(...nums: number[]): number
|
|
79
|
+
divide(...nums: number[]): number
|
|
80
|
+
abs(x: number): number
|
|
81
|
+
min(...nums: number[]): number
|
|
82
|
+
max(...nums: number[]): number
|
|
83
|
+
round(mode: 'floor' | 'ceil', x: number): number
|
|
84
|
+
pow(base: number, exp: number): number
|
|
85
|
+
|
|
86
|
+
// Activations
|
|
87
|
+
relu(x: number): number
|
|
88
|
+
sigmoid(x: number): number
|
|
89
|
+
softmax(arr: ArrayLike<number>): Float64Array
|
|
90
|
+
|
|
91
|
+
// Vectors
|
|
92
|
+
vecadd(a: ArrayLike<number>, b: ArrayLike<number>): Float64Array
|
|
93
|
+
vecscale(arr: ArrayLike<number>, scalar: number): Float64Array
|
|
94
|
+
veccross(a: ArrayLike<number>, b: ArrayLike<number>): Float64Array
|
|
95
|
+
vecdot(a: ArrayLike<number>, b: ArrayLike<number>): number
|
|
96
|
+
vecnorm(arr: ArrayLike<number>): number
|
|
97
|
+
|
|
98
|
+
// Matrices
|
|
99
|
+
matmul(a: ArrayLike<number>, b: ArrayLike<number>, m: number, n: number, p: number): Float64Array
|
|
100
|
+
matadd(a: ArrayLike<number>, b: ArrayLike<number>, m: number, n: number): Float64Array
|
|
101
|
+
matscale(a: ArrayLike<number>, s: number, m: number, n: number): Float64Array
|
|
102
|
+
mattranspose(a: ArrayLike<number>, m: number, n: number): Float64Array
|
|
103
|
+
|
|
104
|
+
// Loss
|
|
105
|
+
mseLoss(pred: ArrayLike<number>, targets: ArrayLike<number>): number
|
|
106
|
+
crossEntropy(pred: ArrayLike<number>, targets: ArrayLike<number>): number
|
|
107
|
+
|
|
108
|
+
// Backward
|
|
109
|
+
reluBackward(dout: ArrayLike<number>, input: ArrayLike<number>): Float64Array
|
|
110
|
+
sigmoidBackward(dout: ArrayLike<number>, output: ArrayLike<number>): Float64Array
|
|
111
|
+
softmaxBackward(dout: ArrayLike<number>, output: ArrayLike<number>): Float64Array
|
|
112
|
+
matmulBackwardA(dout: ArrayLike<number>, b: ArrayLike<number>, m: number, n: number, p: number): Float64Array
|
|
113
|
+
matmulBackwardB(a: ArrayLike<number>, dout: ArrayLike<number>, m: number, n: number, p: number): Float64Array
|
|
114
|
+
|
|
115
|
+
// Optimizer
|
|
116
|
+
sgdUpdate(weights: ArrayLike<number>, grads: ArrayLike<number>, lr: number): Float64Array
|
|
117
|
+
|
|
118
|
+
// Random
|
|
119
|
+
randomSeed(seed: number): void
|
|
120
|
+
randomUniform(min: number, max: number, n: number): Float64Array
|
|
121
|
+
randomNormal(mean: number, std: number, n: number): Float64Array
|
|
122
|
+
|
|
123
|
+
// Statistics
|
|
124
|
+
normalize(arr: ArrayLike<number>): Float64Array
|
|
125
|
+
|
|
126
|
+
// Calculus
|
|
127
|
+
forwarddiff(f: (x: number) => number, x: number, h: number): number
|
|
128
|
+
backwarddiff(f: (x: number) => number, x: number, h: number): number
|
|
129
|
+
centralDifference(f: (x: number) => number, x: number, h: number): number
|
|
130
|
+
secondDerivative(f: (x: number) => number, x: number, h: number): number
|
|
131
|
+
centralNth(f: (x: number) => number, x: number, h: number, n: number): number
|
|
132
|
+
rombergIntegrate(f: (x: number) => number, a: number, b: number, eps: number, kMax: number): number
|
|
133
|
+
|
|
134
|
+
// Pipeline
|
|
135
|
+
batchRelu(input: ArrayLike<number>): Float64Array
|
|
136
|
+
batchSigmoid(input: ArrayLike<number>): Float64Array
|
|
137
|
+
batchSoftmax(input: ArrayLike<number>, rows: number, cols: number): Float64Array
|
|
138
|
+
batchNormalize(features: ArrayLike<number>, nSamples: number, nFeatures: number): Float64Array
|
|
139
|
+
mseBackward(pred: ArrayLike<number>, targets: ArrayLike<number>): Float64Array
|
|
140
|
+
crossEntropyBackward(pred: ArrayLike<number>, targets: ArrayLike<number>, nSamples: number, nClasses: number): Float64Array
|
|
141
|
+
accuracy(pred: ArrayLike<number>, labels: ArrayLike<number>, nSamples: number, nClasses: number): number
|
|
142
|
+
|
|
143
|
+
// Weights
|
|
144
|
+
initWeights(n: number, mean: number, std: number): Weights
|
|
145
|
+
|
|
146
|
+
// Session
|
|
147
|
+
createSession(arenaBytes?: number): Session
|
|
148
|
+
|
|
149
|
+
// Constants
|
|
150
|
+
readonly ACTIVATION_NONE: 0
|
|
151
|
+
readonly ACTIVATION_RELU: 1
|
|
152
|
+
readonly ACTIVATION_SIGMOID: 2
|
|
153
|
+
readonly ACTIVATION_SOFTMAX: 3
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
export function init(wasmUrl?: string): Promise<OpenDI>
|
|
157
|
+
export default init
|
package/package.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "opendi-js",
|
|
3
|
+
"version": "0.1.0",
|
|
4
|
+
"description": "WASM + JS bindings for the OpenDI math/ML library",
|
|
5
|
+
"type": "module",
|
|
6
|
+
"main": "./src/opendi.cjs",
|
|
7
|
+
"module": "./src/opendi.mjs",
|
|
8
|
+
"types": "./opendi.d.ts",
|
|
9
|
+
"exports": {
|
|
10
|
+
".": {
|
|
11
|
+
"import": "./src/opendi.mjs",
|
|
12
|
+
"require": "./src/opendi.cjs"
|
|
13
|
+
}
|
|
14
|
+
},
|
|
15
|
+
"scripts": {
|
|
16
|
+
"build": "bash scripts/build.sh"
|
|
17
|
+
},
|
|
18
|
+
"files": [
|
|
19
|
+
"src/opendi.mjs",
|
|
20
|
+
"src/opendi.cjs",
|
|
21
|
+
"build/",
|
|
22
|
+
"opendi.d.ts"
|
|
23
|
+
],
|
|
24
|
+
"keywords": [
|
|
25
|
+
"wasm",
|
|
26
|
+
"math",
|
|
27
|
+
"ml",
|
|
28
|
+
"linear-algebra",
|
|
29
|
+
"neural-network",
|
|
30
|
+
"opendi"
|
|
31
|
+
],
|
|
32
|
+
"license": "MIT",
|
|
33
|
+
"author": "itscool2b",
|
|
34
|
+
"repository": {
|
|
35
|
+
"type": "git",
|
|
36
|
+
"url": "git+https://github.com/itscool2b/OpenDI.git",
|
|
37
|
+
"directory": "opendi-js"
|
|
38
|
+
},
|
|
39
|
+
"homepage": "https://github.com/itscool2b/OpenDI/tree/master/opendi-js"
|
|
40
|
+
}
|
package/src/opendi.cjs
ADDED
package/src/opendi.mjs
ADDED
|
@@ -0,0 +1,613 @@
|
|
|
1
|
+
import createOpenDI from '../build/opendi.js'
|
|
2
|
+
|
|
3
|
+
// Activation enum values matching C ActivationType
const ACT = { none: 0, relu: 1, sigmoid: 2, softmax: 3 }

/**
 * Resolve an activation specifier to its numeric enum value.
 *
 * @param {string|number|null|undefined} str - Activation name (case-insensitive),
 *   a numeric enum value (returned unchanged), or a falsy value meaning 'none'.
 * @returns {number} Numeric activation code.
 * @throws {Error} If `str` is not a number and does not name a known activation.
 */
function actEnum(str) {
  if (typeof str === 'number') return str
  const name = str || 'none'
  const key = typeof name === 'string' ? name.toLowerCase() : undefined
  // Own-property check: names such as 'constructor' or 'toString' would otherwise
  // resolve through Object.prototype and be returned as if they were enum values.
  if (key === undefined || !Object.prototype.hasOwnProperty.call(ACT, key)) {
    throw new Error(`Unknown activation: ${String(str)}`)
  }
  return ACT[key]
}
|
|
12
|
+
|
|
13
|
+
// ── Heap helpers ──
|
|
14
|
+
|
|
15
|
+
/**
 * Copy an array-like of numbers into arena memory as 64-bit floats.
 *
 * @param {object} wasm - Module object exposing `_glue_arena_push` and `wasmMemory`.
 * @param {number} arena - Arena handle passed through to `_glue_arena_push`.
 * @param {ArrayLike<number>} arr - Values to write.
 * @returns {number} Pointer returned by `_glue_arena_push` for the written data.
 * @throws {Error} If `_glue_arena_push` returns 0.
 */
function writeF64(wasm, arena, arr) {
  const count = arr.length
  const ptr = wasm._glue_arena_push(arena, count * 8)
  if (ptr) {
    // The buffer is looked up only after the push, then viewed at the new pointer.
    new Float64Array(wasm.wasmMemory.buffer, ptr, count).set(arr)
    return ptr
  }
  throw new Error('Arena out of memory')
}
|
|
23
|
+
|
|
24
|
+
/**
 * Copy `n` 64-bit floats out of WASM memory, starting at byte offset `ptr`.
 *
 * The bytes are duplicated with `ArrayBuffer.prototype.slice`, so the result has
 * its own buffer and does not change when WASM memory is written afterwards.
 *
 * @param {object} wasm - Module object exposing `wasmMemory`.
 * @param {number} ptr - Byte offset of the first element.
 * @param {number} n - Number of elements to copy.
 * @returns {Float64Array} Independent copy of the requested range.
 * @throws {Error} If `ptr` is 0 (or otherwise falsy) while `n` is positive.
 */
function readF64(wasm, ptr, n) {
  // A null pointer would otherwise be read as ordinary data from address 0.
  if (!ptr && n > 0) throw new Error('Cannot read from a null WASM pointer')
  return new Float64Array(wasm.wasmMemory.buffer.slice(ptr, ptr + n * 8))
}
|
|
28
|
+
|
|
29
|
+
/**
 * Copy an array-like of numbers into arena memory as 32-bit signed integers.
 *
 * @param {object} wasm - Module object exposing `_glue_arena_push` and `wasmMemory`.
 * @param {number} arena - Arena handle passed through to `_glue_arena_push`.
 * @param {ArrayLike<number>} arr - Values to write.
 * @returns {number} Pointer returned by `_glue_arena_push` for the written data.
 * @throws {Error} If `_glue_arena_push` returns 0.
 */
function writeI32(wasm, arena, arr) {
  const count = arr.length
  const ptr = wasm._glue_arena_push(arena, count * 4)
  if (ptr) {
    // The buffer is looked up only after the push, then viewed at the new pointer.
    new Int32Array(wasm.wasmMemory.buffer, ptr, count).set(arr)
    return ptr
  }
  throw new Error('Arena out of memory')
}
|
|
37
|
+
|
|
38
|
+
/**
 * Copy `n` 32-bit signed integers out of WASM memory, starting at byte offset `ptr`.
 *
 * @param {object} wasm - Module object exposing `wasmMemory`.
 * @param {number} ptr - Byte offset of the first element.
 * @param {number} n - Number of elements to copy.
 * @returns {Int32Array} Copy of the range, backed by its own buffer.
 */
function readI32(wasm, ptr, n) {
  const end = ptr + n * 4
  const copy = wasm.wasmMemory.buffer.slice(ptr, end)
  return new Int32Array(copy)
}
|
|
41
|
+
|
|
42
|
+
// ── Temp arena pattern ──
|
|
43
|
+
|
|
44
|
+
/**
 * Run `fn` with a freshly created arena and destroy the arena afterwards.
 *
 * @param {object} wasm - Module object exposing `_glue_arena_create` / `_glue_arena_destroy`.
 * @param {number} bytes - Size passed to `_glue_arena_create`.
 * @param {(arena: number) => T} fn - Callback receiving the arena handle.
 * @returns {T} Whatever `fn` returns.
 * @throws {Error} If `_glue_arena_create` returns 0; `fn` is not called in that case.
 * @template T
 */
function withTempArena(wasm, bytes, fn) {
  const arena = wasm._glue_arena_create(bytes)
  // Fail here rather than passing a null handle to the callback and to destroy.
  if (!arena) throw new Error(`Failed to allocate a ${bytes}-byte temporary arena`)
  try {
    return fn(arena)
  } finally {
    // Runs on both normal return and exception.
    wasm._glue_arena_destroy(arena)
  }
}
|
|
52
|
+
|
|
53
|
+
// ── Function pointer helpers for calculus ──
|
|
54
|
+
// Cache function pointers to avoid removeFunction bugs in Emscripten.
// A table index is only meaningful for the module whose addFunction produced it,
// so the cache is keyed by module first: each unique JS function gets one WASM
// table slot per module instance.
const fpCache = new WeakMap()

/**
 * Call `cb` with a WASM function-table index for `jsFn`, registering it on first use.
 *
 * @param {object} wasm - Module object exposing `addFunction`.
 * @param {(x: number) => number} jsFn - Function registered with signature 'dd'.
 * @param {(fp: number) => T} cb - Receives the table index.
 * @returns {T} Whatever `cb` returns.
 * @template T
 */
function withFunctionPointer(wasm, jsFn, cb) {
  let perModule = fpCache.get(wasm)
  if (perModule === undefined) {
    perModule = new Map()
    fpCache.set(wasm, perModule)
  }
  let fp = perModule.get(jsFn)
  if (fp === undefined) {
    fp = wasm.addFunction(jsFn, 'dd')
    perModule.set(jsFn, fp)
  }
  return cb(fp)
}
|
|
66
|
+
|
|
67
|
+
// ── Main init ──
|
|
68
|
+
|
|
69
|
+
export async function init(wasmUrl) {
|
|
70
|
+
const opts = {}
|
|
71
|
+
if (wasmUrl) {
|
|
72
|
+
opts.locateFile = () => wasmUrl
|
|
73
|
+
}
|
|
74
|
+
const wasm = await createOpenDI(opts)
|
|
75
|
+
|
|
76
|
+
const api = {}
|
|
77
|
+
|
|
78
|
+
// ── Primitives (variadic → array-based via glue) ──
|
|
79
|
+
|
|
80
|
+
api.add = (...nums) => {
|
|
81
|
+
const arr = nums.flat()
|
|
82
|
+
return withTempArena(wasm, arr.length * 8 + 64, (arena) => {
|
|
83
|
+
const ptr = writeF64(wasm, arena, arr)
|
|
84
|
+
return wasm._glue_add(ptr, arr.length)
|
|
85
|
+
})
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
api.subtract = (...nums) => {
|
|
89
|
+
const arr = nums.flat()
|
|
90
|
+
return withTempArena(wasm, arr.length * 8 + 64, (arena) => {
|
|
91
|
+
const ptr = writeF64(wasm, arena, arr)
|
|
92
|
+
return wasm._glue_subtract(ptr, arr.length)
|
|
93
|
+
})
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
api.multiply = (...nums) => {
|
|
97
|
+
const arr = nums.flat()
|
|
98
|
+
return withTempArena(wasm, arr.length * 8 + 64, (arena) => {
|
|
99
|
+
const ptr = writeF64(wasm, arena, arr)
|
|
100
|
+
return wasm._glue_multiply(ptr, arr.length)
|
|
101
|
+
})
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
api.divide = (...nums) => {
|
|
105
|
+
const arr = nums.flat()
|
|
106
|
+
return withTempArena(wasm, arr.length * 8 + 64, (arena) => {
|
|
107
|
+
const ptr = writeF64(wasm, arena, arr)
|
|
108
|
+
return wasm._glue_divide(ptr, arr.length)
|
|
109
|
+
})
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
api.abs = (x) => wasm._absolute(x)
|
|
113
|
+
|
|
114
|
+
api.min = (...nums) => {
|
|
115
|
+
const arr = nums.flat()
|
|
116
|
+
return withTempArena(wasm, arr.length * 8 + 64, (arena) => {
|
|
117
|
+
const ptr = writeF64(wasm, arena, arr)
|
|
118
|
+
return wasm._glue_min(ptr, arr.length)
|
|
119
|
+
})
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
api.max = (...nums) => {
|
|
123
|
+
const arr = nums.flat()
|
|
124
|
+
return withTempArena(wasm, arr.length * 8 + 64, (arena) => {
|
|
125
|
+
const ptr = writeF64(wasm, arena, arr)
|
|
126
|
+
return wasm._glue_max(ptr, arr.length)
|
|
127
|
+
})
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
api.round = (mode, x) => {
|
|
131
|
+
return withTempArena(wasm, 64, (arena) => {
|
|
132
|
+
const strBytes = wasm.lengthBytesUTF8(mode) + 1
|
|
133
|
+
const strPtr = wasm._glue_arena_push(arena, strBytes)
|
|
134
|
+
wasm.stringToUTF8(mode, strPtr, strBytes)
|
|
135
|
+
return wasm._roundval(strPtr, x)
|
|
136
|
+
})
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
api.pow = (base, exp) => wasm._exponents(base, exp)
|
|
140
|
+
|
|
141
|
+
// ── Activations ──
|
|
142
|
+
|
|
143
|
+
api.relu = (x) => wasm._relu(x)
|
|
144
|
+
|
|
145
|
+
api.sigmoid = (x) => wasm._sigmoid(x)
|
|
146
|
+
|
|
147
|
+
api.softmax = (arr) => {
|
|
148
|
+
const input = Array.isArray(arr) ? arr : Array.from(arr)
|
|
149
|
+
const n = input.length
|
|
150
|
+
return withTempArena(wasm, n * 8 * 3 + 256, (arena) => {
|
|
151
|
+
const inPtr = writeF64(wasm, arena, input)
|
|
152
|
+
const outPtr = wasm._softmax(arena, inPtr, n)
|
|
153
|
+
return readF64(wasm, outPtr, n)
|
|
154
|
+
})
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
// ── Vectors ──
|
|
158
|
+
|
|
159
|
+
api.vecadd = (a, b) => {
|
|
160
|
+
const n = a.length
|
|
161
|
+
return withTempArena(wasm, n * 8 * 3 + 256, (arena) => {
|
|
162
|
+
const aPtr = writeF64(wasm, arena, a)
|
|
163
|
+
const bPtr = writeF64(wasm, arena, b)
|
|
164
|
+
const outPtr = wasm._vecadd(arena, aPtr, bPtr, n)
|
|
165
|
+
return readF64(wasm, outPtr, n)
|
|
166
|
+
})
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
api.vecscale = (arr, scalar) => {
|
|
170
|
+
const n = arr.length
|
|
171
|
+
return withTempArena(wasm, n * 8 * 2 + 256, (arena) => {
|
|
172
|
+
const ptr = writeF64(wasm, arena, arr)
|
|
173
|
+
const outPtr = wasm._vecscale(arena, ptr, scalar, n)
|
|
174
|
+
return readF64(wasm, outPtr, n)
|
|
175
|
+
})
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
api.veccross = (a, b) => {
|
|
179
|
+
return withTempArena(wasm, 6 * 8 + 256, (arena) => {
|
|
180
|
+
const aPtr = writeF64(wasm, arena, a)
|
|
181
|
+
const bPtr = writeF64(wasm, arena, b)
|
|
182
|
+
const outPtr = wasm._veccross(arena, aPtr, bPtr)
|
|
183
|
+
return readF64(wasm, outPtr, 3)
|
|
184
|
+
})
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
api.vecdot = (a, b) => {
|
|
188
|
+
const n = a.length
|
|
189
|
+
return withTempArena(wasm, n * 8 * 2 + 256, (arena) => {
|
|
190
|
+
const aPtr = writeF64(wasm, arena, a)
|
|
191
|
+
const bPtr = writeF64(wasm, arena, b)
|
|
192
|
+
return wasm._vecdot(aPtr, bPtr, n)
|
|
193
|
+
})
|
|
194
|
+
}
|
|
195
|
+
|
|
196
|
+
api.vecnorm = (arr) => {
|
|
197
|
+
const n = arr.length
|
|
198
|
+
return withTempArena(wasm, n * 8 + 256, (arena) => {
|
|
199
|
+
const ptr = writeF64(wasm, arena, arr)
|
|
200
|
+
return wasm._vecnorm(ptr, n)
|
|
201
|
+
})
|
|
202
|
+
}
|
|
203
|
+
|
|
204
|
+
// ── Matrices ──
|
|
205
|
+
|
|
206
|
+
api.matmul = (a, b, m, n, p) => {
|
|
207
|
+
return withTempArena(wasm, (m * n + n * p + m * p) * 8 + 256, (arena) => {
|
|
208
|
+
const aPtr = writeF64(wasm, arena, a)
|
|
209
|
+
const bPtr = writeF64(wasm, arena, b)
|
|
210
|
+
const outPtr = wasm._matmul(arena, aPtr, bPtr, m, n, p)
|
|
211
|
+
return readF64(wasm, outPtr, m * p)
|
|
212
|
+
})
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
api.matadd = (a, b, m, n) => {
|
|
216
|
+
return withTempArena(wasm, m * n * 3 * 8 + 256, (arena) => {
|
|
217
|
+
const aPtr = writeF64(wasm, arena, a)
|
|
218
|
+
const bPtr = writeF64(wasm, arena, b)
|
|
219
|
+
const outPtr = wasm._matadd(arena, aPtr, bPtr, m, n)
|
|
220
|
+
return readF64(wasm, outPtr, m * n)
|
|
221
|
+
})
|
|
222
|
+
}
|
|
223
|
+
|
|
224
|
+
api.matscale = (a, s, m, n) => {
|
|
225
|
+
return withTempArena(wasm, m * n * 2 * 8 + 256, (arena) => {
|
|
226
|
+
const aPtr = writeF64(wasm, arena, a)
|
|
227
|
+
const outPtr = wasm._matscale(arena, aPtr, s, m, n)
|
|
228
|
+
return readF64(wasm, outPtr, m * n)
|
|
229
|
+
})
|
|
230
|
+
}
|
|
231
|
+
|
|
232
|
+
api.mattranspose = (a, m, n) => {
|
|
233
|
+
return withTempArena(wasm, m * n * 2 * 8 + 256, (arena) => {
|
|
234
|
+
const aPtr = writeF64(wasm, arena, a)
|
|
235
|
+
const outPtr = wasm._mattranspose(arena, aPtr, m, n)
|
|
236
|
+
return readF64(wasm, outPtr, n * m)
|
|
237
|
+
})
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
// ── Loss ──
|
|
241
|
+
|
|
242
|
+
api.mseLoss = (pred, targets) => {
|
|
243
|
+
const n = pred.length
|
|
244
|
+
return withTempArena(wasm, n * 8 * 2 + 256, (arena) => {
|
|
245
|
+
const pPtr = writeF64(wasm, arena, pred)
|
|
246
|
+
const tPtr = writeF64(wasm, arena, targets)
|
|
247
|
+
return wasm._mse_loss(pPtr, tPtr, n)
|
|
248
|
+
})
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
api.crossEntropy = (pred, targets) => {
|
|
252
|
+
const n = pred.length
|
|
253
|
+
return withTempArena(wasm, n * 8 * 2 + 256, (arena) => {
|
|
254
|
+
const pPtr = writeF64(wasm, arena, pred)
|
|
255
|
+
const tPtr = writeF64(wasm, arena, targets)
|
|
256
|
+
return wasm._cross_entropy(pPtr, tPtr, n)
|
|
257
|
+
})
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
// ── Backward ──
|
|
261
|
+
|
|
262
|
+
api.reluBackward = (dout, input) => {
|
|
263
|
+
const n = dout.length
|
|
264
|
+
return withTempArena(wasm, n * 8 * 3 + 256, (arena) => {
|
|
265
|
+
const dPtr = writeF64(wasm, arena, dout)
|
|
266
|
+
const iPtr = writeF64(wasm, arena, input)
|
|
267
|
+
const outPtr = wasm._relu_backward(arena, dPtr, iPtr, n)
|
|
268
|
+
return readF64(wasm, outPtr, n)
|
|
269
|
+
})
|
|
270
|
+
}
|
|
271
|
+
|
|
272
|
+
api.sigmoidBackward = (dout, output) => {
|
|
273
|
+
const n = dout.length
|
|
274
|
+
return withTempArena(wasm, n * 8 * 3 + 256, (arena) => {
|
|
275
|
+
const dPtr = writeF64(wasm, arena, dout)
|
|
276
|
+
const oPtr = writeF64(wasm, arena, output)
|
|
277
|
+
const outPtr = wasm._sigmoid_backward(arena, dPtr, oPtr, n)
|
|
278
|
+
return readF64(wasm, outPtr, n)
|
|
279
|
+
})
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
api.softmaxBackward = (dout, output) => {
|
|
283
|
+
const n = dout.length
|
|
284
|
+
return withTempArena(wasm, n * 8 * 3 + 256, (arena) => {
|
|
285
|
+
const dPtr = writeF64(wasm, arena, dout)
|
|
286
|
+
const oPtr = writeF64(wasm, arena, output)
|
|
287
|
+
const outPtr = wasm._softmax_backward(arena, dPtr, oPtr, n)
|
|
288
|
+
return readF64(wasm, outPtr, n)
|
|
289
|
+
})
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
api.matmulBackwardA = (dout, b, m, n, p) => {
|
|
293
|
+
return withTempArena(wasm, (m * p + n * p + m * n) * 8 + 256, (arena) => {
|
|
294
|
+
const dPtr = writeF64(wasm, arena, dout)
|
|
295
|
+
const bPtr = writeF64(wasm, arena, b)
|
|
296
|
+
const outPtr = wasm._matmul_backward_a(arena, dPtr, bPtr, m, n, p)
|
|
297
|
+
return readF64(wasm, outPtr, m * n)
|
|
298
|
+
})
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
api.matmulBackwardB = (a, dout, m, n, p) => {
|
|
302
|
+
return withTempArena(wasm, (m * n + m * p + n * p) * 8 + 256, (arena) => {
|
|
303
|
+
const aPtr = writeF64(wasm, arena, a)
|
|
304
|
+
const dPtr = writeF64(wasm, arena, dout)
|
|
305
|
+
const outPtr = wasm._matmul_backward_b(arena, aPtr, dPtr, m, n, p)
|
|
306
|
+
return readF64(wasm, outPtr, n * p)
|
|
307
|
+
})
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
// ── Optimizer ──
|
|
311
|
+
|
|
312
|
+
api.sgdUpdate = (weights, grads, lr) => {
|
|
313
|
+
const n = weights.length
|
|
314
|
+
return withTempArena(wasm, n * 8 * 3 + 256, (arena) => {
|
|
315
|
+
const wPtr = writeF64(wasm, arena, weights)
|
|
316
|
+
const gPtr = writeF64(wasm, arena, grads)
|
|
317
|
+
const outPtr = wasm._sgd_update(arena, wPtr, gPtr, lr, n)
|
|
318
|
+
return readF64(wasm, outPtr, n)
|
|
319
|
+
})
|
|
320
|
+
}
|
|
321
|
+
|
|
322
|
+
// ── Random ──
|
|
323
|
+
|
|
324
|
+
api.randomSeed = (seed) => wasm._random_seed(seed)
|
|
325
|
+
|
|
326
|
+
api.randomUniform = (min, max, n) => {
|
|
327
|
+
return withTempArena(wasm, n * 8 + 256, (arena) => {
|
|
328
|
+
const outPtr = wasm._random_uniform(arena, min, max, n)
|
|
329
|
+
return readF64(wasm, outPtr, n)
|
|
330
|
+
})
|
|
331
|
+
}
|
|
332
|
+
|
|
333
|
+
api.randomNormal = (mean, std, n) => {
|
|
334
|
+
return withTempArena(wasm, n * 8 + 256, (arena) => {
|
|
335
|
+
const outPtr = wasm._random_normal(arena, mean, std, n)
|
|
336
|
+
return readF64(wasm, outPtr, n)
|
|
337
|
+
})
|
|
338
|
+
}
|
|
339
|
+
|
|
340
|
+
// ── Statistics ──
|
|
341
|
+
|
|
342
|
+
api.normalize = (arr) => {
|
|
343
|
+
const n = arr.length
|
|
344
|
+
return withTempArena(wasm, n * 8 * 2 + 256, (arena) => {
|
|
345
|
+
const ptr = writeF64(wasm, arena, arr)
|
|
346
|
+
const outPtr = wasm._normalize(arena, ptr, n)
|
|
347
|
+
return readF64(wasm, outPtr, n)
|
|
348
|
+
})
|
|
349
|
+
}
|
|
350
|
+
|
|
351
|
+
// ── Calculus (function pointer based) ──
|
|
352
|
+
|
|
353
|
+
api.forwarddiff = (f, x, h) => {
|
|
354
|
+
return withFunctionPointer(wasm, f, (fp) => {
|
|
355
|
+
return wasm._forwarddiff(fp, x, h)
|
|
356
|
+
})
|
|
357
|
+
}
|
|
358
|
+
|
|
359
|
+
api.backwarddiff = (f, x, h) => {
|
|
360
|
+
return withFunctionPointer(wasm, f, (fp) => {
|
|
361
|
+
return wasm._backwarddiff(fp, x, h)
|
|
362
|
+
})
|
|
363
|
+
}
|
|
364
|
+
|
|
365
|
+
api.centralDifference = (f, x, h) => {
|
|
366
|
+
return withFunctionPointer(wasm, f, (fp) => {
|
|
367
|
+
return wasm._central_difference(fp, x, h)
|
|
368
|
+
})
|
|
369
|
+
}
|
|
370
|
+
|
|
371
|
+
api.secondDerivative = (f, x, h) => {
|
|
372
|
+
return withFunctionPointer(wasm, f, (fp) => {
|
|
373
|
+
return wasm._secondderivative(fp, x, h)
|
|
374
|
+
})
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
api.centralNth = (f, x, h, n) => {
|
|
378
|
+
return withFunctionPointer(wasm, f, (fp) => {
|
|
379
|
+
return wasm._centralnth(fp, x, h, n)
|
|
380
|
+
})
|
|
381
|
+
}
|
|
382
|
+
|
|
383
|
+
api.rombergIntegrate = (f, a, b, eps, kMax) => {
|
|
384
|
+
return withFunctionPointer(wasm, f, (fp) => {
|
|
385
|
+
return wasm._romberg_integrate(fp, a, b, eps, kMax)
|
|
386
|
+
})
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
// ── Pipeline (batch operations) ──
|
|
390
|
+
|
|
391
|
+
api.batchRelu = (input) => {
|
|
392
|
+
const n = input.length
|
|
393
|
+
return withTempArena(wasm, n * 8 * 2 + 256, (arena) => {
|
|
394
|
+
const ptr = writeF64(wasm, arena, input)
|
|
395
|
+
const outPtr = wasm._batch_relu(arena, ptr, n)
|
|
396
|
+
return readF64(wasm, outPtr, n)
|
|
397
|
+
})
|
|
398
|
+
}
|
|
399
|
+
|
|
400
|
+
api.batchSigmoid = (input) => {
|
|
401
|
+
const n = input.length
|
|
402
|
+
return withTempArena(wasm, n * 8 * 2 + 256, (arena) => {
|
|
403
|
+
const ptr = writeF64(wasm, arena, input)
|
|
404
|
+
const outPtr = wasm._batch_sigmoid(arena, ptr, n)
|
|
405
|
+
return readF64(wasm, outPtr, n)
|
|
406
|
+
})
|
|
407
|
+
}
|
|
408
|
+
|
|
409
|
+
api.batchSoftmax = (input, rows, cols) => {
|
|
410
|
+
return withTempArena(wasm, rows * cols * 8 * 2 + 256, (arena) => {
|
|
411
|
+
const ptr = writeF64(wasm, arena, input)
|
|
412
|
+
const outPtr = wasm._batch_softmax(arena, ptr, rows, cols)
|
|
413
|
+
return readF64(wasm, outPtr, rows * cols)
|
|
414
|
+
})
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
api.batchNormalize = (features, nSamples, nFeatures) => {
|
|
418
|
+
return withTempArena(wasm, nSamples * nFeatures * 8 * 2 + 256, (arena) => {
|
|
419
|
+
const ptr = writeF64(wasm, arena, features)
|
|
420
|
+
const outPtr = wasm._batch_normalize(arena, ptr, nSamples, nFeatures)
|
|
421
|
+
return readF64(wasm, outPtr, nSamples * nFeatures)
|
|
422
|
+
})
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
api.mseBackward = (pred, targets) => {
|
|
426
|
+
const n = pred.length
|
|
427
|
+
return withTempArena(wasm, n * 8 * 3 + 256, (arena) => {
|
|
428
|
+
const pPtr = writeF64(wasm, arena, pred)
|
|
429
|
+
const tPtr = writeF64(wasm, arena, targets)
|
|
430
|
+
const outPtr = wasm._mse_backward(arena, pPtr, tPtr, n)
|
|
431
|
+
return readF64(wasm, outPtr, n)
|
|
432
|
+
})
|
|
433
|
+
}
|
|
434
|
+
|
|
435
|
+
api.crossEntropyBackward = (pred, targets, nSamples, nClasses) => {
|
|
436
|
+
const total = nSamples * nClasses
|
|
437
|
+
return withTempArena(wasm, total * 8 * 3 + 256, (arena) => {
|
|
438
|
+
const pPtr = writeF64(wasm, arena, pred)
|
|
439
|
+
const tPtr = writeF64(wasm, arena, targets)
|
|
440
|
+
const outPtr = wasm._cross_entropy_backward(arena, pPtr, tPtr, nSamples, nClasses)
|
|
441
|
+
return readF64(wasm, outPtr, total)
|
|
442
|
+
})
|
|
443
|
+
}
|
|
444
|
+
|
|
445
|
+
api.accuracy = (pred, labels, nSamples, nClasses) => {
|
|
446
|
+
return withTempArena(wasm, nSamples * nClasses * 8 + nSamples * 4 + 256, (arena) => {
|
|
447
|
+
const pPtr = writeF64(wasm, arena, pred)
|
|
448
|
+
const lPtr = writeI32(wasm, arena, labels)
|
|
449
|
+
return wasm._accuracy(pPtr, lPtr, nSamples, nClasses)
|
|
450
|
+
})
|
|
451
|
+
}
|
|
452
|
+
|
|
453
|
+
// ── Weights (uses malloc, not arena) ──
|
|
454
|
+
|
|
455
|
+
api.initWeights = (n, mean, std) => {
|
|
456
|
+
const ptr = wasm._init_weights(n, mean, std)
|
|
457
|
+
const data = readF64(wasm, ptr, n)
|
|
458
|
+
return {
|
|
459
|
+
_ptr: ptr,
|
|
460
|
+
data,
|
|
461
|
+
length: n,
|
|
462
|
+
free() { wasm._glue_free(ptr) }
|
|
463
|
+
}
|
|
464
|
+
}
|
|
465
|
+
|
|
466
|
+
// ── Session API (pointer-level control for training loops) ──
|
|
467
|
+
|
|
468
|
+
api.createSession = (arenaBytes) => {
|
|
469
|
+
const arena = wasm._glue_arena_create(arenaBytes || 65536)
|
|
470
|
+
const session = {}
|
|
471
|
+
|
|
472
|
+
session.writeArray = (arr) => writeF64(wasm, arena, arr)
|
|
473
|
+
session.writeI32Array = (arr) => writeI32(wasm, arena, arr)
|
|
474
|
+
session.readArray = (ptr, n) => readF64(wasm, ptr, n)
|
|
475
|
+
|
|
476
|
+
session.copyToPtr = (destPtr, srcArray) => {
|
|
477
|
+
const heap = new Float64Array(wasm.wasmMemory.buffer, destPtr, srcArray.length)
|
|
478
|
+
heap.set(srcArray)
|
|
479
|
+
}
|
|
480
|
+
|
|
481
|
+
session.softmax = (inputPtr, n) => {
|
|
482
|
+
return wasm._softmax(arena, inputPtr, n)
|
|
483
|
+
}
|
|
484
|
+
|
|
485
|
+
session.vecadd = (aPtr, bPtr, n) => {
|
|
486
|
+
return wasm._vecadd(arena, aPtr, bPtr, n)
|
|
487
|
+
}
|
|
488
|
+
|
|
489
|
+
session.vecscale = (arrPtr, scalar, n) => {
|
|
490
|
+
return wasm._vecscale(arena, arrPtr, scalar, n)
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
session.matmul = (aPtr, bPtr, m, n, p) => {
|
|
494
|
+
return wasm._matmul(arena, aPtr, bPtr, m, n, p)
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
session.matadd = (aPtr, bPtr, m, n) => {
|
|
498
|
+
return wasm._matadd(arena, aPtr, bPtr, m, n)
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
session.matscale = (aPtr, s, m, n) => {
|
|
502
|
+
return wasm._matscale(arena, aPtr, s, m, n)
|
|
503
|
+
}
|
|
504
|
+
|
|
505
|
+
session.mattranspose = (aPtr, m, n) => {
|
|
506
|
+
return wasm._mattranspose(arena, aPtr, m, n)
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
session.reluBackward = (doutPtr, inputPtr, n) => {
|
|
510
|
+
return wasm._relu_backward(arena, doutPtr, inputPtr, n)
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
session.sigmoidBackward = (doutPtr, outputPtr, n) => {
|
|
514
|
+
return wasm._sigmoid_backward(arena, doutPtr, outputPtr, n)
|
|
515
|
+
}
|
|
516
|
+
|
|
517
|
+
session.softmaxBackward = (doutPtr, outputPtr, n) => {
|
|
518
|
+
return wasm._softmax_backward(arena, doutPtr, outputPtr, n)
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
session.matmulBackwardA = (doutPtr, bPtr, m, n, p) => {
|
|
522
|
+
return wasm._matmul_backward_a(arena, doutPtr, bPtr, m, n, p)
|
|
523
|
+
}
|
|
524
|
+
|
|
525
|
+
session.matmulBackwardB = (aPtr, doutPtr, m, n, p) => {
|
|
526
|
+
return wasm._matmul_backward_b(arena, aPtr, doutPtr, m, n, p)
|
|
527
|
+
}
|
|
528
|
+
|
|
529
|
+
session.sgdUpdate = (weightsPtr, gradsPtr, lr, n) => {
|
|
530
|
+
return wasm._sgd_update(arena, weightsPtr, gradsPtr, lr, n)
|
|
531
|
+
}
|
|
532
|
+
|
|
533
|
+
session.normalize = (ptr, n) => {
|
|
534
|
+
return wasm._normalize(arena, ptr, n)
|
|
535
|
+
}
|
|
536
|
+
|
|
537
|
+
session.batchRelu = (ptr, n) => {
|
|
538
|
+
return wasm._batch_relu(arena, ptr, n)
|
|
539
|
+
}
|
|
540
|
+
|
|
541
|
+
session.batchSigmoid = (ptr, n) => {
|
|
542
|
+
return wasm._batch_sigmoid(arena, ptr, n)
|
|
543
|
+
}
|
|
544
|
+
|
|
545
|
+
session.batchSoftmax = (ptr, rows, cols) => {
|
|
546
|
+
return wasm._batch_softmax(arena, ptr, rows, cols)
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
session.batchNormalize = (ptr, nSamples, nFeatures) => {
|
|
550
|
+
return wasm._batch_normalize(arena, ptr, nSamples, nFeatures)
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
session.mseBackward = (predPtr, targetsPtr, n) => {
|
|
554
|
+
return wasm._mse_backward(arena, predPtr, targetsPtr, n)
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
session.crossEntropyBackward = (predPtr, targetsPtr, nSamples, nClasses) => {
|
|
558
|
+
return wasm._cross_entropy_backward(arena, predPtr, targetsPtr, nSamples, nClasses)
|
|
559
|
+
}
|
|
560
|
+
|
|
561
|
+
session.randomUniform = (min, max, n) => {
|
|
562
|
+
return wasm._random_uniform(arena, min, max, n)
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
session.randomNormal = (mean, std, n) => {
|
|
566
|
+
return wasm._random_normal(arena, mean, std, n)
|
|
567
|
+
}
|
|
568
|
+
|
|
569
|
+
session.denseForward = (inputPtr, weightsPtr, m, n, p, activation) => {
|
|
570
|
+
const act = actEnum(activation)
|
|
571
|
+
const outPtr = wasm._glue_dense_forward(arena, inputPtr, weightsPtr, m, n, p, act)
|
|
572
|
+
const cachePtr = wasm._glue_get_cache_ptr()
|
|
573
|
+
return { outputPtr: outPtr, cachePtr }
|
|
574
|
+
}
|
|
575
|
+
|
|
576
|
+
session.denseBackward = (doutPtr, inputPtr, weightsPtr, cachePtr, m, n, p, activation) => {
|
|
577
|
+
const act = actEnum(activation)
|
|
578
|
+
wasm._glue_dense_backward(arena, doutPtr, inputPtr, weightsPtr, cachePtr, m, n, p, act)
|
|
579
|
+
return {
|
|
580
|
+
dWeightsPtr: wasm._glue_get_dw_ptr(),
|
|
581
|
+
dInputPtr: wasm._glue_get_di_ptr()
|
|
582
|
+
}
|
|
583
|
+
}
|
|
584
|
+
|
|
585
|
+
session.mseLoss = (predPtr, targetsPtr, n) => {
|
|
586
|
+
return wasm._mse_loss(predPtr, targetsPtr, n)
|
|
587
|
+
}
|
|
588
|
+
|
|
589
|
+
session.crossEntropy = (predPtr, targetsPtr, n) => {
|
|
590
|
+
return wasm._cross_entropy(predPtr, targetsPtr, n)
|
|
591
|
+
}
|
|
592
|
+
|
|
593
|
+
session.accuracy = (predPtr, labelsPtr, nSamples, nClasses) => {
|
|
594
|
+
return wasm._accuracy(predPtr, labelsPtr, nSamples, nClasses)
|
|
595
|
+
}
|
|
596
|
+
|
|
597
|
+
session.clear = () => wasm._glue_arena_clear(arena)
|
|
598
|
+
|
|
599
|
+
session.destroy = () => wasm._glue_arena_destroy(arena)
|
|
600
|
+
|
|
601
|
+
return session
|
|
602
|
+
}
|
|
603
|
+
|
|
604
|
+
// ── Expose constants ──
|
|
605
|
+
api.ACTIVATION_NONE = 0
|
|
606
|
+
api.ACTIVATION_RELU = 1
|
|
607
|
+
api.ACTIVATION_SIGMOID = 2
|
|
608
|
+
api.ACTIVATION_SOFTMAX = 3
|
|
609
|
+
|
|
610
|
+
return api
|
|
611
|
+
}
|
|
612
|
+
|
|
613
|
+
export default init
|