@wlearn/lightgbm 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +9 -0
- package/LICENSE +21 -0
- package/README.md +49 -0
- package/package.json +35 -0
- package/src/booster.js +229 -0
- package/src/dataset.js +103 -0
- package/src/index.js +4 -0
- package/src/model.js +424 -0
- package/src/wasm.js +27 -0
- package/wasm/BUILD_INFO +6 -0
- package/wasm/lightgbm.cjs +0 -0
package/CHANGELOG.md
ADDED
package/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 StatSim
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
package/README.md
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# @wlearn/lightgbm
|
|
2
|
+
|
|
3
|
+
LightGBM WASM port for wlearn. Gradient boosting for classification and regression,
|
|
4
|
+
running in the browser and in Node.js via WebAssembly.
|
|
5
|
+
|
|
6
|
+
## Installation
|
|
7
|
+
|
|
8
|
+
```bash
|
|
9
|
+
npm install @wlearn/lightgbm
|
|
10
|
+
```
|
|
11
|
+
|
|
12
|
+
## Usage
|
|
13
|
+
|
|
14
|
+
```javascript
|
|
15
|
+
import { LGBModel } from '@wlearn/lightgbm'
|
|
16
|
+
|
|
17
|
+
// Create and train
|
|
18
|
+
const model = await LGBModel.create({
|
|
19
|
+
objective: 'binary',
|
|
20
|
+
learning_rate: 0.05,
|
|
21
|
+
num_leaves: 31,
|
|
22
|
+
numRound: 100
|
|
23
|
+
})
|
|
24
|
+
model.fit(X, y)
|
|
25
|
+
|
|
26
|
+
// Predict
|
|
27
|
+
const predictions = model.predict(X_test)
|
|
28
|
+
const probabilities = model.predictProba(X_test)
|
|
29
|
+
const accuracy = model.score(X_test, y_test)
|
|
30
|
+
|
|
31
|
+
// Save and load
|
|
32
|
+
const bundle = model.save()
|
|
33
|
+
const loaded = await LGBModel.load(bundle)
|
|
34
|
+
|
|
35
|
+
// Clean up
|
|
36
|
+
model.dispose()
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
## Supported objectives
|
|
40
|
+
|
|
41
|
+
- `binary` -- binary classification
|
|
42
|
+
- `multiclass` -- multiclass classification (softmax)
|
|
43
|
+
- `multiclassova` -- multiclass one-vs-all
|
|
44
|
+
- `cross_entropy` -- cross-entropy classification
|
|
45
|
+
- `regression` -- regression (default)
|
|
46
|
+
|
|
47
|
+
## License
|
|
48
|
+
|
|
49
|
+
MIT (upstream LightGBM is MIT-licensed)
|
package/package.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "@wlearn/lightgbm",
|
|
3
|
+
"version": "0.1.0",
|
|
4
|
+
"description": "LightGBM WASM port for wlearn",
|
|
5
|
+
"type": "module",
|
|
6
|
+
"main": "src/index.js",
|
|
7
|
+
"exports": {
|
|
8
|
+
".": "./src/index.js"
|
|
9
|
+
},
|
|
10
|
+
"files": [
|
|
11
|
+
"src/",
|
|
12
|
+
"wasm/",
|
|
13
|
+
"LICENSE",
|
|
14
|
+
"README.md",
|
|
15
|
+
"CHANGELOG.md"
|
|
16
|
+
],
|
|
17
|
+
"sideEffects": false,
|
|
18
|
+
"publishConfig": {
|
|
19
|
+
"access": "public"
|
|
20
|
+
},
|
|
21
|
+
"scripts": {
|
|
22
|
+
"test": "node --experimental-vm-modules test/test.js",
|
|
23
|
+
"build": "bash scripts/build-wasm.sh",
|
|
24
|
+
"verify": "bash scripts/verify-exports.sh"
|
|
25
|
+
},
|
|
26
|
+
"dependencies": {
|
|
27
|
+
"@wlearn/types": "0.1.0",
|
|
28
|
+
"@wlearn/core": "0.1.0"
|
|
29
|
+
},
|
|
30
|
+
"license": "MIT",
|
|
31
|
+
"repository": {
|
|
32
|
+
"type": "git",
|
|
33
|
+
"url": "https://github.com/wlearn-org/lightgbm-wasm.git"
|
|
34
|
+
}
|
|
35
|
+
}
|
package/src/booster.js
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
import { getWasm } from './wasm.js'
|
|
2
|
+
|
|
3
|
+
// Safety net for Boosters that are garbage-collected without dispose():
// the held value's ptr[0] still holds the native handle unless dispose() cleared it.
// Null on runtimes that lack FinalizationRegistry.
const registry = typeof FinalizationRegistry === 'undefined'
  ? null
  : new FinalizationRegistry((held) => {
    const leakedHandle = held.ptr[0]
    if (!leakedHandle) return
    console.warn('@wlearn/lightgbm: Booster was not disposed -- calling free() automatically.')
    held.freeFn(leakedHandle)
  })
|
|
12
|
+
|
|
13
|
+
/**
 * Copy `str` into the WASM heap as a NUL-terminated UTF-8 C string, hand its
 * pointer to `fn`, and free the allocation whether `fn` returns or throws.
 * @returns {*} whatever `fn` returns
 */
function withCString(wasm, str, fn) {
  const encoded = new TextEncoder().encode(str + '\0')
  const cstr = wasm._malloc(encoded.length)
  wasm.HEAPU8.set(encoded, cstr)
  let result
  try {
    result = fn(cstr)
  } finally {
    wasm._free(cstr)
  }
  return result
}
|
|
23
|
+
|
|
24
|
+
// Read the native wrapper's last error message (a C string) into a JS string.
function getLastError(wasm) {
  const messagePtr = wasm._wl_lgb_get_last_error()
  return wasm.UTF8ToString(messagePtr)
}
|
|
27
|
+
|
|
28
|
+
// Internal sentinel for loadModel path
const LOAD_SENTINEL = Symbol('load')

// LightGBM C API prediction types (C_API_PREDICT_*) whose output is wider than
// one value per class and row.
const PREDICT_LEAF_INDEX = 2
const PREDICT_CONTRIB = 3

/**
 * Run `fn(outPtr)` with a zero-initialized temporary 4-byte out-parameter in the
 * WASM heap and read it back as i32. The slot is freed even if `fn` throws.
 * @returns {{ret: number, value: number}} fn's return code and the out value
 */
function callWithI32Out(wasm, fn) {
  const outPtr = wasm._malloc(4)
  try {
    wasm.setValue(outPtr, 0, 'i32')
    const ret = fn(outPtr)
    return { ret, value: wasm.getValue(outPtr, 'i32') }
  } finally {
    wasm._free(outPtr)
  }
}

/**
 * Owning wrapper around a native LightGBM booster handle in the WASM heap.
 *
 * Create it from a training Dataset handle with `new Booster(datasetHandle, paramsStr)`
 * or from saved model bytes with `Booster.loadModel(bytes)`. Call dispose() to free
 * the native booster; the FinalizationRegistry is only a fallback that warns.
 * Every method requires the WASM module to be loaded (getWasm()).
 */
export class Booster {
  #handle = null
  #freed = false
  #ptrRef = null
  // Upper bound on the boosting iterations held by the native model, used to size
  // leaf-index prediction output. Counts every successful update() call (a
  // "finished" update adds no tree, so this may overestimate), or the iteration
  // count reported by the native loader for loaded models.
  #numIterations = 0

  /**
   * @param {number|symbol} trainDataHandle - native Dataset handle (LOAD_SENTINEL internally)
   * @param {string|number} paramsStr - LightGBM "key=value ..." params (native handle internally)
   * @throws {Error} if the native booster cannot be created
   */
  constructor(trainDataHandle, paramsStr) {
    // Internal path: loadModel passes sentinel + handle
    if (trainDataHandle === LOAD_SENTINEL) {
      this.#handle = paramsStr // second arg holds the handle
      this.#track()
      return
    }

    const wasm = getWasm()
    const { ret, value } = callWithI32Out(wasm, (outPtr) =>
      withCString(wasm, paramsStr, (paramsPtr) =>
        wasm._wl_lgb_booster_create(trainDataHandle, paramsPtr, outPtr)
      )
    )
    if (ret !== 0) {
      throw new Error(`Booster creation failed: ${getLastError(wasm)}`)
    }
    this.#handle = value
    this.#track()
  }

  // Leak detection: register the handle so a forgotten dispose() still frees it.
  #track() {
    this.#ptrRef = [this.#handle]
    if (registry) {
      registry.register(this, {
        ptr: this.#ptrRef,
        freeFn: (h) => getWasm()._wl_lgb_booster_free(h)
      }, this)
    }
  }

  /**
   * Native booster handle.
   * @throws {Error} if the booster has been disposed
   */
  get handle() {
    if (this.#freed) throw new Error('Booster already disposed')
    return this.#handle
  }

  /**
   * Run one boosting iteration.
   * @returns {boolean} true when the native side reports training is finished
   * @throws {Error} if disposed or the native update fails
   */
  update() {
    // Resolve the handle first so a disposed booster throws before anything is allocated.
    const handle = this.handle
    const wasm = getWasm()
    const { ret, value: finished } = callWithI32Out(wasm, (finishedPtr) =>
      wasm._wl_lgb_booster_update(handle, finishedPtr)
    )
    if (ret !== 0) {
      throw new Error(`Booster update failed: ${getLastError(wasm)}`)
    }
    this.#numIterations++
    return finished !== 0
  }

  /**
   * @returns {number} the number of classes reported by the native booster
   * @throws {Error} if disposed or the native call fails
   */
  getNumClasses() {
    const handle = this.handle
    const wasm = getWasm()
    const { ret, value } = callWithI32Out(wasm, (outPtr) =>
      wasm._wl_lgb_booster_get_num_classes(handle, outPtr)
    )
    if (ret !== 0) {
      throw new Error(`Booster getNumClasses failed: ${getLastError(wasm)}`)
    }
    return value
  }

  // Number of output values per input row for a prediction type, following LightGBM's
  // NumPredictOneRow: classes; classes * iterations for leaf indices; classes *
  // (features + 1) for feature contributions.
  #valuesPerRow(ncol, predictType, numIteration) {
    const numClasses = this.getNumClasses()
    if (predictType === PREDICT_CONTRIB) return numClasses * (ncol + 1)
    if (predictType === PREDICT_LEAF_INDEX) {
      const iterations = numIteration > 0
        ? Math.min(numIteration, this.#numIterations)
        : this.#numIterations
      return numClasses * iterations
    }
    return numClasses
  }

  /**
   * Predict on a dense row-major float32 matrix.
   * @param {Float32Array|ArrayLike<number>} data - at least nrow * ncol values
   * @param {number} nrow
   * @param {number} ncol
   * @param {{predictType?: number, numIteration?: number}} [options] - LightGBM
   *   predict type (0 normal, 1 raw, 2 leaf index, 3 contributions) and iteration
   *   limit (<= 0 means all)
   * @returns {Float64Array} the values written by the native predictor
   * @throws {Error} if disposed, data is too short, or native prediction fails
   */
  predict(data, nrow, ncol, { predictType = 0, numIteration = 0 } = {}) {
    const handle = this.handle
    const wasm = getWasm()

    // The native call reads nrow * ncol floats; a shorter input would make it read
    // unrelated heap memory.
    const needed = nrow * ncol
    if (data.length < needed) {
      throw new Error(`Booster predict data length (${data.length}) is smaller than nrow * ncol (${needed})`)
    }

    // Size the output for what LightGBM will write instead of a fixed per-row guess.
    const capacity = Math.max(1, nrow * this.#valuesPerRow(ncol, predictType, numIteration))

    const dataPtr = wasm._malloc(data.length * 4)
    const outLenPtr = wasm._malloc(4)
    const outResultPtr = wasm._malloc(capacity * 8)
    try {
      // Copy float32 data to WASM heap
      wasm.HEAPF32.set(data, dataPtr / 4)

      const ret = withCString(wasm, '', (paramPtr) =>
        wasm._wl_lgb_booster_predict(
          handle, dataPtr, nrow, ncol,
          predictType, numIteration, paramPtr,
          outLenPtr, outResultPtr
        )
      )
      if (ret !== 0) {
        throw new Error(`Booster predict failed: ${getLastError(wasm)}`)
      }

      // Never read past the buffer we allocated, whatever length is reported.
      const outLen = Math.min(wasm.getValue(outLenPtr, 'i32'), capacity)
      const start = outResultPtr / 8
      // Read wasm.HEAPF64 only now, in case heap growth replaced the view during the call.
      return wasm.HEAPF64.slice(start, start + outLen)
    } finally {
      wasm._free(dataPtr)
      wasm._free(outLenPtr)
      wasm._free(outResultPtr)
    }
  }

  /**
   * Serialize the model in LightGBM's text format.
   * @returns {Uint8Array} model text bytes, without the trailing NUL
   * @throws {Error} if disposed or the native save fails
   */
  saveModel() {
    const handle = this.handle
    const wasm = getWasm()
    const outLenPtr = wasm._malloc(4)
    let bufPtr = 0
    try {
      // First pass: get required buffer length
      let ret = wasm._wl_lgb_booster_save_model(handle, 0, outLenPtr, 0)
      if (ret !== 0) {
        throw new Error(`Booster saveModel (size query) failed: ${getLastError(wasm)}`)
      }
      const bufLen = wasm.getValue(outLenPtr, 'i32')

      // Second pass: get actual model string
      bufPtr = wasm._malloc(bufLen)
      ret = wasm._wl_lgb_booster_save_model(handle, bufLen, outLenPtr, bufPtr)
      if (ret !== 0) {
        throw new Error(`Booster saveModel failed: ${getLastError(wasm)}`)
      }

      // The reported length includes the NUL terminator; clamp to our buffer and at 0.
      const textLen = Math.max(0, Math.min(wasm.getValue(outLenPtr, 'i32'), bufLen) - 1)
      return wasm.HEAPU8.slice(bufPtr, bufPtr + textLen)
    } finally {
      wasm._free(outLenPtr)
      if (bufPtr) wasm._free(bufPtr)
    }
  }

  /**
   * Create a Booster from model text produced by saveModel().
   * @param {Uint8Array|ArrayBuffer|ArrayLike<number>} modelBytes
   * @returns {Booster}
   * @throws {Error} if the native loader rejects the model
   */
  static loadModel(modelBytes) {
    const wasm = getWasm()
    const buf = modelBytes instanceof Uint8Array ? modelBytes : new Uint8Array(modelBytes)

    const bufPtr = wasm._malloc(buf.length + 1)
    const outIterPtr = wasm._malloc(4)
    const outPtr = wasm._malloc(4)
    let handle
    let numIterations
    try {
      // Copy model string to WASM heap (add null terminator)
      wasm.HEAPU8.set(buf, bufPtr)
      wasm.HEAPU8[bufPtr + buf.length] = 0

      const ret = wasm._wl_lgb_booster_load_model(bufPtr, outIterPtr, outPtr)
      if (ret !== 0) {
        throw new Error(`Booster loadModel failed: ${getLastError(wasm)}`)
      }
      handle = wasm.getValue(outPtr, 'i32')
      numIterations = wasm.getValue(outIterPtr, 'i32')
    } finally {
      wasm._free(bufPtr)
      wasm._free(outIterPtr)
      wasm._free(outPtr)
    }

    const booster = new Booster(LOAD_SENTINEL, handle)
    booster.#numIterations = numIterations
    return booster
  }

  /** Free the native booster. Safe to call more than once. */
  dispose() {
    if (this.#freed) return
    this.#freed = true

    getWasm()._wl_lgb_booster_free(this.#handle)

    if (this.#ptrRef) this.#ptrRef[0] = null
    if (registry) registry.unregister(this)

    this.#handle = null
  }
}
|
package/src/dataset.js
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
import { getWasm } from './wasm.js'
|
|
2
|
+
|
|
3
|
+
// Safety net for Datasets that are garbage-collected without dispose():
// ptr[0] keeps the native handle until dispose() clears it.
// Null on runtimes that lack FinalizationRegistry.
const registry = typeof FinalizationRegistry === 'undefined'
  ? null
  : new FinalizationRegistry(({ ptr: cell, freeFn: release }) => {
    const leakedHandle = cell[0]
    if (!leakedHandle) return
    console.warn('@wlearn/lightgbm: Dataset was not disposed -- calling free() automatically.')
    release(leakedHandle)
  })
|
|
12
|
+
|
|
13
|
+
// Marshal `str` as a NUL-terminated UTF-8 C string in the WASM heap for the
// duration of `fn(ptr)`; the heap buffer is released even if `fn` throws.
function withCString(wasm, str, fn) {
  const utf8 = new TextEncoder().encode(str + '\0')
  const heapPtr = wasm._malloc(utf8.length)
  wasm.HEAPU8.set(utf8, heapPtr)
  try {
    return fn(heapPtr)
  } finally {
    wasm._free(heapPtr)
  }
}
|
|
23
|
+
|
|
24
|
+
// Decode the native wrapper's most recent error message into a JS string.
function getLastError(wasm) {
  const errorCString = wasm._wl_lgb_get_last_error()
  const message = wasm.UTF8ToString(errorCString)
  return message
}
|
|
27
|
+
|
|
28
|
+
/**
 * Owning wrapper around a native LightGBM Dataset built from a dense row-major
 * float32 matrix that is copied into the WASM heap.
 *
 * Call dispose() to free the native dataset; the FinalizationRegistry is only a
 * fallback that warns. Every method requires the WASM module to be loaded (getWasm()).
 */
export class Dataset {
  #handle = null
  #freed = false
  #ptrRef = null

  /**
   * @param {Float32Array|ArrayLike<number>} data - at least nrow * ncol feature values
   * @param {number} nrow
   * @param {number} ncol
   * @param {string} [params=''] - LightGBM "key=value ..." parameter string
   * @throws {Error} if data is shorter than nrow * ncol or native creation fails
   */
  constructor(data, nrow, ncol, params = '') {
    const wasm = getWasm()

    // The native call reads nrow * ncol floats from the copied buffer; a shorter
    // input would make it read unrelated heap memory.
    const needed = nrow * ncol
    if (data.length < needed) {
      throw new Error(`Dataset data length (${data.length}) is smaller than nrow * ncol (${needed})`)
    }

    const dataPtr = wasm._malloc(data.length * 4)
    // Output handle pointer
    const outPtr = wasm._malloc(4)
    try {
      // Copy float32 data to WASM heap
      wasm.HEAPF32.set(data, dataPtr / 4)

      const ret = withCString(wasm, params, (paramsPtr) =>
        wasm._wl_lgb_dataset_create_from_mat(dataPtr, nrow, ncol, paramsPtr, outPtr)
      )
      if (ret !== 0) {
        throw new Error(`Dataset creation failed: ${getLastError(wasm)}`)
      }
      this.#handle = wasm.getValue(outPtr, 'i32')
    } finally {
      // NOTE(review): freeing the input right after creation assumes the native
      // dataset keeps its own copy of what it needs (the original code relied on this too).
      wasm._free(dataPtr)
      wasm._free(outPtr)
    }

    this.#track()
  }

  // Leak detection: register the handle so a forgotten dispose() still frees it.
  #track() {
    this.#ptrRef = [this.#handle]
    if (registry) {
      registry.register(this, {
        ptr: this.#ptrRef,
        freeFn: (h) => getWasm()._wl_lgb_dataset_free(h)
      }, this)
    }
  }

  /**
   * Native dataset handle.
   * @throws {Error} if the dataset has been disposed
   */
  get handle() {
    if (this.#freed) throw new Error('Dataset already disposed')
    return this.#handle
  }

  /**
   * Set the training labels (LightGBM field "label", stored as float32).
   * @param {Float32Array|ArrayLike<number>} labels
   * @throws {Error} if disposed or the native call fails
   */
  setLabel(labels) {
    // Resolve the handle first so a disposed dataset throws before anything is allocated.
    const handle = this.handle
    const wasm = getWasm()
    const arr = labels instanceof Float32Array ? labels : new Float32Array(labels)
    const ptr = wasm._malloc(arr.length * 4)
    let ret
    try {
      wasm.HEAPF32.set(arr, ptr / 4)
      // C_API_DTYPE_FLOAT32 = 0
      ret = withCString(wasm, 'label', (fieldPtr) =>
        wasm._wl_lgb_dataset_set_field(handle, fieldPtr, ptr, arr.length, 0)
      )
    } finally {
      wasm._free(ptr)
    }

    if (ret !== 0) {
      throw new Error(`Dataset setLabel failed: ${getLastError(wasm)}`)
    }
  }

  /** Free the native dataset. Safe to call more than once. */
  dispose() {
    if (this.#freed) return
    this.#freed = true

    getWasm()._wl_lgb_dataset_free(this.#handle)

    if (this.#ptrRef) this.#ptrRef[0] = null
    if (registry) registry.unregister(this)

    this.#handle = null
  }
}
|
package/src/index.js
ADDED
package/src/model.js
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
import { loadLGB, getWasm } from './wasm.js'
|
|
2
|
+
import { Dataset } from './dataset.js'
|
|
3
|
+
import { Booster } from './booster.js'
|
|
4
|
+
import {
|
|
5
|
+
normalizeY,
|
|
6
|
+
encodeBundle, decodeBundle,
|
|
7
|
+
register,
|
|
8
|
+
DisposedError, NotFittedError
|
|
9
|
+
} from '@wlearn/core'
|
|
10
|
+
|
|
11
|
+
// Safety net for LGBModels that are garbage-collected without dispose():
// ref[0] holds the model's Booster until dispose()/refit clears it, and freeFn
// releases it. Null on runtimes that lack FinalizationRegistry.
const leakRegistry = typeof FinalizationRegistry === 'undefined'
  ? null
  : new FinalizationRegistry(({ ref, freeFn }) => {
    const leaked = ref[0]
    if (!leaked) return
    console.warn('@wlearn/lightgbm: LGBModel was not disposed -- calling free() automatically.')
    freeFn(leaked)
  })
|
|
20
|
+
|
|
21
|
+
// --- Objective classification ---

const CLASSIFIER_OBJECTIVES = new Set([
  'binary', 'multiclass', 'multiclassova', 'cross_entropy'
])

const PROBA_OBJECTIVES = new Set([
  'binary', 'multiclass', 'multiclassova'
])

// Classifier objectives that yield one positive-class probability per row, so
// they can only tell two label values apart.
const BINARY_OBJECTIVES = new Set(['binary', 'cross_entropy'])

// LightGBM params that are wlearn-only (not passed to Booster)
const WLEARN_PARAMS = new Set(['numRound', 'coerce'])

// Boosting rounds used when `numRound` is not set.
const DEFAULT_NUM_ROUND = 100

// --- Internal sentinel for load path ---
const LOAD_SENTINEL = Symbol('load')

// --- LGBModel ---

/**
 * wlearn estimator backed by a LightGBM booster compiled to WebAssembly.
 *
 * Create with `await LGBModel.create(params)` (loads the WASM module) or
 * `await LGBModel.load(bundle)`. `params` are LightGBM parameters plus the
 * wlearn-only keys in WLEARN_PARAMS (`numRound`: boosting rounds, default 100).
 * Classifier objectives require integer labels; they are remapped to 0..k-1 for
 * training and mapped back by predict(). Call dispose() to free native memory.
 */
export class LGBModel {
  #booster = null
  #freed = false
  #boosterRef = null
  #params = {}
  #fitted = false
  #nrClass = 0
  #classes = null

  /**
   * Internal: use LGBModel.create() or LGBModel.load().
   * @param {object|symbol} handle - params object, or LOAD_SENTINEL on the load path
   * @param {Booster} [params] - loaded booster (load path only)
   * @param {{params?: object, nrClass?: number, classes?: number[]}} [extra] - bundle metadata (load path only)
   */
  constructor(handle, params, extra) {
    if (handle === LOAD_SENTINEL) {
      // Load path
      this.#booster = params
      this.#params = { ...(extra.params || {}) }
      this.#nrClass = extra.nrClass || 0
      this.#classes = extra.classes ? new Int32Array(extra.classes) : null
      this.#fitted = true
      this.#track()
    } else {
      // Normal construction from create(). Copy so setParams() never mutates the caller's object.
      this.#params = { ...(handle || {}) }
    }
  }

  /**
   * Load the WASM module (once) and return an unfitted model.
   * @param {object} [params]
   * @returns {Promise<LGBModel>}
   */
  static async create(params = {}) {
    await loadLGB()
    return new LGBModel(params)
  }

  // --- Estimator interface ---

  /**
   * Train a new booster on X / y, replacing any previously trained one.
   * @param {number[][]|{data: ArrayLike<number>, rows: number, cols: number}} X
   * @param {*} y - labels accepted by normalizeY
   * @returns {LGBModel} this
   * @throws {DisposedError} if the model was disposed
   * @throws {Error} on invalid input or native failure; the model is then left unfitted
   */
  fit(X, y) {
    this.#ensureNotDisposed()

    // Drop the previous booster up front so that, if anything below throws, the
    // model reads as unfitted instead of pointing at a disposed booster.
    this.#releaseBooster()

    const { data: xData, rows, cols } = this.#normalizeX(X)
    const obj = this.#objective()
    const yTrain = this.#prepareLabels(normalizeY(y), obj)

    if (yTrain.length !== rows) {
      throw new Error(`y length (${yTrain.length}) does not match X rows (${rows})`)
    }

    const numRound = this.#params.numRound ?? DEFAULT_NUM_ROUND
    const paramStr = this.#buildParamString(obj)

    const ds = new Dataset(xData, rows, cols, paramStr)
    let booster = null
    try {
      ds.setLabel(yTrain)
      booster = new Booster(ds.handle, paramStr)
      for (let i = 0; i < numRound; i++) {
        // update() returns true once no further split can be made; stop there
        // instead of spending the remaining rounds.
        if (booster.update()) break
      }
    } catch (err) {
      if (booster) booster.dispose()
      throw err
    } finally {
      // NOTE(review): the Dataset is freed while the booster lives on, as in the
      // original code; this assumes predict/save do not need the training data.
      ds.dispose()
    }

    this.#booster = booster
    this.#fitted = true
    this.#track()
    return this
  }

  /**
   * Predict targets. Regressors return the booster's raw output; classifiers
   * return the original class labels.
   * @returns {Float64Array}
   */
  predict(X) {
    this.#ensureFitted()
    const { data: xData, rows, cols } = this.#normalizeX(X)
    const rawPreds = this.#booster.predict(xData, rows, cols)
    const obj = this.#objective()

    if (!CLASSIFIER_OBJECTIVES.has(obj)) {
      return rawPreds
    }

    const result = new Float64Array(rows)
    if (obj === 'multiclass' || obj === 'multiclassova') {
      // Raw is rows * nrClass probabilities, argmax
      const nc = this.#nrClass
      for (let i = 0; i < rows; i++) {
        const offset = i * nc
        let best = 0
        for (let c = 1; c < nc; c++) {
          if (rawPreds[offset + c] > rawPreds[offset + best]) best = c
        }
        result[i] = this.#classes[best]
      }
    } else {
      // binary / cross_entropy: raw is P(class=1), threshold at 0.5
      for (let i = 0; i < rows; i++) {
        result[i] = this.#binaryLabel(rawPreds[i] > 0.5 ? 1 : 0)
      }
    }
    return result
  }

  /**
   * Class probabilities, row-major: rows * 2 for binary, rows * nrClass for
   * multiclass / multiclassova.
   * @returns {Float64Array}
   * @throws {Error} for objectives without probability output
   */
  predictProba(X) {
    this.#ensureFitted()
    const obj = this.#objective()

    if (!PROBA_OBJECTIVES.has(obj)) {
      throw new Error(`predictProba requires classification objective, got "${obj}"`)
    }

    const { data: xData, rows, cols } = this.#normalizeX(X)
    const rawPreds = this.#booster.predict(xData, rows, cols)

    if (obj === 'binary') {
      // LightGBM returns P(class=1). Expand to rows * 2: [P(class=0), P(class=1)]
      const result = new Float64Array(rows * 2)
      for (let i = 0; i < rows; i++) {
        const p1 = rawPreds[i]
        result[i * 2] = 1 - p1
        result[i * 2 + 1] = p1
      }
      return result
    }

    // multiclass / multiclassova: already rows * nrClass
    return new Float64Array(rawPreds)
  }

  /**
   * Accuracy for classifiers, R^2 for regressors (0 when y is constant).
   * @throws {Error} if y and X have different lengths
   */
  score(X, y) {
    const preds = this.predict(X)
    const yArr = normalizeY(y)
    if (yArr.length !== preds.length) {
      throw new Error(`y length (${yArr.length}) does not match X rows (${preds.length})`)
    }

    if (!this.#isClassifier()) {
      // R-squared
      let ssRes = 0
      let ssTot = 0
      let yMean = 0
      for (let i = 0; i < yArr.length; i++) yMean += yArr[i]
      yMean /= yArr.length
      for (let i = 0; i < yArr.length; i++) {
        ssRes += (yArr[i] - preds[i]) ** 2
        ssTot += (yArr[i] - yMean) ** 2
      }
      return ssTot === 0 ? 0 : 1 - ssRes / ssTot
    }

    // Accuracy
    let correct = 0
    for (let i = 0; i < preds.length; i++) {
      if (preds[i] === yArr[i]) correct++
    }
    return correct / preds.length
  }

  // --- Model I/O ---

  /** Serialize the fitted model into a wlearn bundle. */
  save() {
    this.#ensureFitted()
    const modelBytes = this.#booster.saveModel()
    const typeId = this.#isClassifier()
      ? 'wlearn.lightgbm.classifier@1'
      : 'wlearn.lightgbm.regressor@1'
    return encodeBundle(
      {
        typeId,
        params: this.getParams(),
        metadata: {
          nrClass: this.#nrClass,
          classes: this.#classes ? Array.from(this.#classes) : [],
          objective: this.#objective()
        }
      },
      [{ id: 'model', data: modelBytes }]
    )
  }

  /** Restore a model from bytes produced by save(). */
  static async load(bytes) {
    const { manifest, toc, blobs } = decodeBundle(bytes)
    return LGBModel._fromBundle(manifest, toc, blobs)
  }

  /** Internal loader shared by load() and the @wlearn/core registry. */
  static async _fromBundle(manifest, toc, blobs) {
    await loadLGB()

    const entry = toc.find(e => e.id === 'model')
    if (!entry) throw new Error('Bundle missing "model" artifact')
    const raw = blobs.subarray(entry.offset, entry.offset + entry.length)

    const booster = Booster.loadModel(raw)

    const meta = manifest.metadata || {}
    return new LGBModel(LOAD_SENTINEL, booster, {
      params: manifest.params || {},
      nrClass: meta.nrClass || 0,
      classes: meta.classes || null
    })
  }

  /** Free native resources. Safe to call more than once. */
  dispose() {
    if (this.#freed) return
    this.#freed = true
    this.#releaseBooster()
  }

  // --- Params ---

  getParams() {
    return { ...this.#params }
  }

  setParams(p) {
    Object.assign(this.#params, p)
    return this
  }

  static defaultSearchSpace() {
    // NOTE(review): per the LightGBM parameter docs, `subsample` (alias of
    // bagging_fraction) only takes effect when `bagging_freq` > 0, which this
    // space never sets -- TODO decide whether to add it.
    return {
      objective: { type: 'categorical', values: ['binary', 'regression'] },
      max_depth: { type: 'int_uniform', low: 3, high: 12 },
      learning_rate: { type: 'log_uniform', low: 0.01, high: 0.3 },
      numRound: { type: 'int_uniform', low: 50, high: 500 },
      subsample: { type: 'uniform', low: 0.5, high: 1.0 },
      colsample_bytree: { type: 'uniform', low: 0.5, high: 1.0 },
      min_child_weight: { type: 'log_uniform', low: 1, high: 10 },
      reg_lambda: { type: 'log_uniform', low: 1e-3, high: 10 },
      reg_alpha: { type: 'log_uniform', low: 1e-3, high: 10 },
      num_leaves: { type: 'int_uniform', low: 15, high: 127 }
    }
  }

  // --- Inspection ---

  get nrClass() {
    return this.#nrClass
  }

  get classes() {
    return this.#classes ? Int32Array.from(this.#classes) : new Int32Array(0)
  }

  get isFitted() {
    return this.#fitted && !this.#freed
  }

  get capabilities() {
    const obj = this.#objective()
    const isCls = CLASSIFIER_OBJECTIVES.has(obj)
    return {
      classifier: isCls,
      regressor: !isCls,
      predictProba: PROBA_OBJECTIVES.has(obj),
      decisionFunction: false,
      sampleWeight: false,
      csr: false,
      earlyStopping: false
    }
  }

  get probaDim() {
    if (!this.isFitted) return 0
    const obj = this.#objective()
    if (obj === 'binary') return 2
    if (obj === 'multiclass' || obj === 'multiclassova') return this.#nrClass
    return 0
  }

  // --- Private helpers ---

  // Effective objective; LightGBM's default when none is set.
  #objective() {
    return this.#params.objective || 'regression'
  }

  // Register the current booster so it is freed if the model is GC'd undisposed.
  #track() {
    this.#boosterRef = [this.#booster]
    if (leakRegistry) {
      leakRegistry.register(this, {
        ref: this.#boosterRef,
        // Best-effort cleanup from a finalizer; an error here has nowhere useful to go.
        freeFn: (b) => { try { b.dispose() } catch {} }
      }, this)
    }
  }

  // Dispose and unregister the current booster (if any) and mark the model unfitted.
  #releaseBooster() {
    if (this.#booster) this.#booster.dispose()
    this.#booster = null
    this.#fitted = false
    if (this.#boosterRef) this.#boosterRef[0] = null
    if (leakRegistry) leakRegistry.unregister(this)
  }

  // For classifiers: validate integer labels, record the sorted classes, and return
  // labels remapped to 0..k-1. For regressors: return y as a Float32Array.
  #prepareLabels(yNorm, obj) {
    if (!CLASSIFIER_OBJECTIVES.has(obj)) {
      this.#classes = null
      this.#nrClass = 0
      return yNorm instanceof Float32Array ? yNorm : new Float32Array(yNorm)
    }

    const unique = new Set()
    for (let i = 0; i < yNorm.length; i++) {
      const v = yNorm[i]
      // Number.isInteger rejects NaN and also +/-Infinity, which a floor() check lets through.
      if (!Number.isInteger(v)) {
        throw new Error(`Classifier labels must be integers, got ${v} at index ${i}`)
      }
      unique.add(v)
    }
    const sorted = [...unique].sort((a, b) => a - b)

    // A single positive-class probability cannot represent a third class; without
    // this check every extra class would be silently predicted as class 0 or 1.
    if (BINARY_OBJECTIVES.has(obj) && sorted.length > 2) {
      throw new Error(`Objective "${obj}" supports at most 2 classes, got ${sorted.length}`)
    }

    this.#classes = new Int32Array(sorted)
    this.#nrClass = sorted.length

    // Remap to 0-based contiguous
    const classIndex = new Map(sorted.map((label, idx) => [label, idx]))
    const yTrain = new Float32Array(yNorm.length)
    for (let i = 0; i < yNorm.length; i++) yTrain[i] = classIndex.get(yNorm[i])
    return yTrain
  }

  // Build LightGBM's "key1=value1 key2=value2" string: wlearn-only keys and
  // undefined values are dropped, then defaults are filled in.
  #buildParamString(obj) {
    const lgbParams = {}
    for (const [key, val] of Object.entries(this.#params)) {
      if (WLEARN_PARAMS.has(key) || val === undefined) continue
      lgbParams[key] = val
    }
    // Always send the objective that predict() will interpret the output with.
    lgbParams.objective = obj
    if (!('verbosity' in lgbParams)) lgbParams.verbosity = -1

    // Auto-set num_class for multiclass
    if ((obj === 'multiclass' || obj === 'multiclassova') &&
        !('num_class' in lgbParams) && this.#nrClass > 0) {
      lgbParams.num_class = this.#nrClass
    }

    return Object.entries(lgbParams)
      .map(([k, v]) => `${k}=${v}`)
      .join(' ')
  }

  // Original label for a 0/1 binary decision; a single-class fit maps both to that class.
  #binaryLabel(index) {
    return this.#classes[Math.min(index, this.#classes.length - 1)]
  }

  #normalizeX(X) {
    // Fast path: typed matrix { data, rows, cols }
    if (X && typeof X === 'object' && !Array.isArray(X) && X.data) {
      const { data, rows, cols } = X
      // The native side reads exactly rows * cols values.
      if (data.length !== rows * cols) {
        throw new Error(`X.data length (${data.length}) does not match rows * cols (${rows * cols})`)
      }
      return { data: data instanceof Float32Array ? data : new Float32Array(data), rows, cols }
    }

    // Slow path: number[][]
    if (Array.isArray(X) && Array.isArray(X[0])) {
      const rows = X.length
      const cols = X[0].length
      const data = new Float32Array(rows * cols)
      for (let i = 0; i < rows; i++) {
        const row = X[i]
        // Ragged rows would otherwise be silently padded with NaN or truncated.
        if (row?.length !== cols) {
          throw new Error(`X row ${i} has ${row?.length ?? 0} values, expected ${cols}`)
        }
        data.set(row, i * cols)
      }
      return { data, rows, cols }
    }

    throw new Error('X must be number[][] or { data: TypedArray, rows, cols }')
  }

  #ensureNotDisposed() {
    if (this.#freed) throw new DisposedError('LGBModel has been disposed.')
  }

  #ensureFitted() {
    if (this.#freed) throw new DisposedError('LGBModel has been disposed.')
    if (!this.#fitted) throw new NotFittedError('LGBModel is not fitted. Call fit() first.')
  }

  #isClassifier() {
    return CLASSIFIER_OBJECTIVES.has(this.#objective())
  }
}
|
|
420
|
+
|
|
421
|
+
// --- Register loaders with @wlearn/core ---
// Both type ids decode through the same path; the params saved in the bundle
// decide whether the restored model behaves as a classifier or a regressor.
for (const typeId of ['wlearn.lightgbm.classifier@1', 'wlearn.lightgbm.regressor@1']) {
  register(typeId, async (m, t, b) => LGBModel._fromBundle(m, t, b))
}
|
package/src/wasm.js
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
// WASM loader -- loads the LightGBM WASM module (singleton, lazy init)
|
|
2
|
+
|
|
3
|
+
import { createRequire } from 'module'
|
|
4
|
+
|
|
5
|
+
// Resolved Emscripten module (null until a load has succeeded).
let wasmModule = null
// In-flight load shared by concurrent callers; cleared again if that load fails.
let loading = null

/**
 * Load and instantiate the LightGBM WASM module once; later calls reuse it.
 * Concurrent callers share one in-flight load. A failed load is not cached, so a
 * later call retries instead of replaying the same rejection forever.
 * @param {object} [options] - Emscripten module options; only used by the call
 *   that actually starts a load
 * @returns {Promise<object>} the instantiated Emscripten module
 */
export async function loadLGB(options = {}) {
  if (wasmModule) return wasmModule
  if (loading) return loading

  const attempt = (async () => {
    // SINGLE_FILE=1: .wasm is embedded in the .cjs file, no locateFile needed
    // Emscripten output is CJS, use createRequire for ESM compatibility
    const require = createRequire(import.meta.url)
    const createLightGBM = require('../wasm/lightgbm.cjs')
    wasmModule = await createLightGBM(options)
    return wasmModule
  })()
  loading = attempt

  // Forget a failed attempt (unless a newer one already replaced it) so the next
  // call starts over. The promise returned by catch() is intentionally not awaited.
  void attempt.catch(() => {
    if (loading === attempt) loading = null
  })

  return attempt
}
|
|
23
|
+
|
|
24
|
+
/**
 * Synchronous accessor for the module loaded by loadLGB().
 * @returns {object} the Emscripten module
 * @throws {Error} if loadLGB() has not finished successfully yet
 */
export function getWasm() {
  if (wasmModule) return wasmModule
  throw new Error('WASM not loaded -- call loadLGB() first')
}
|
package/wasm/BUILD_INFO
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
upstream: LightGBM v4.6.0
|
|
2
|
+
upstream_commit: d02a01ac6f51d36c9e62388243bcb75c3b1b1774
|
|
3
|
+
build_date: 2026-02-27T13:05:21Z
|
|
4
|
+
emscripten: emcc (Emscripten gcc/clang-like replacement + linker emulating GNU ld) 5.0.2 (dc80f645ee70178c11666de0c3860d9e064d50e4)
|
|
5
|
+
build_flags: -O2 -fexceptions SINGLE_FILE=1
|
|
6
|
+
wasm_embedded: true
|
|
Binary file
|