@wlearn/ensemble 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +26 -0
- package/src/bagging.js +410 -0
- package/src/index.js +6 -0
- package/src/oof.js +96 -0
- package/src/selection.js +127 -0
- package/src/stacking.js +372 -0
- package/src/voting.js +311 -0
- package/src/weights.js +143 -0
package/src/stacking.js
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
import {
|
|
2
|
+
encodeBundle, decodeBundle, register, load as registryLoad,
|
|
3
|
+
normalizeX, normalizeY, accuracy, r2Score,
|
|
4
|
+
stratifiedKFold, kFold,
|
|
5
|
+
ValidationError, NotFittedError, DisposedError,
|
|
6
|
+
lift
|
|
7
|
+
} from '@wlearn/core'
|
|
8
|
+
|
|
9
|
+
// Type identifiers written into the bundle manifest by save() and registered
// with the @wlearn/core loader registry by StackingEnsemble._register().
const TYPE_ID_CLS = 'wlearn.ensemble.stacking.classifier@1'
const TYPE_ID_REG = 'wlearn.ensemble.stacking.regressor@1'
// Guard flag so StackingEnsemble._register() registers the loaders only once.
let _registered = false
|
|
12
|
+
|
|
13
|
+
/**
 * Stacked generalization ensemble for classification or regression.
 *
 * fit() trains every base estimator with K-fold cross-validation to collect
 * out-of-fold (OOF) predictions. For classification these are class
 * probabilities; for regression they are raw predictions. Those columns,
 * optionally followed by the original features (`passthrough`), form the
 * training matrix of the final (meta) estimator. Every base estimator is then
 * refit on the full data and used to build the same meta features at
 * prediction time.
 *
 * Estimator specs are `[name, Class, params]` tuples. `Class.create(params)`
 * must resolve to an instance exposing `fit`, `predict` (plus `predictProba`
 * for classification), `save` and `dispose`.
 */
export class StackingEnsemble {
  #baseSpecs // [name, Class, params][]
  #metaSpec // [name, Class, params]
  #baseModels // fitted base model instances (on full data)
  #metaModel // fitted meta-model instance
  #cv // number of folds for the OOF pass
  #task // 'classification' | 'regression'
  #passthrough // configured passthrough flag (fitted layout is in #nMetaCols)
  #seed // fold-shuffling seed
  #classes // sorted distinct training labels (classification only)
  #nClasses
  #nMetaCols // width of the matrix the meta-model was fitted on
  #fitted = false
  #disposed = false

  // Register the bundle loaders when the class is evaluated rather than on
  // first construction. Otherwise @wlearn/core's load() could not decode a
  // stacking bundle (e.g. one nested inside another ensemble) in a fresh process.
  static {
    StackingEnsemble._register()
  }

  /**
   * @param {object} [params]
   * @param {Array} [params.estimators=[]] - base estimator specs `[name, Class, params]`.
   * @param {Array} [params.finalEstimator] - meta estimator spec `[name, Class, params]`.
   * @param {number} [params.cv=5] - number of folds used to produce OOF predictions.
   * @param {string} [params.task='classification'] - 'classification' or 'regression'.
   * @param {boolean} [params.passthrough=false] - append the original features to the meta features.
   * @param {number} [params.seed=42] - shuffle seed for fold generation.
   */
  constructor(params = {}) {
    this.#baseSpecs = params.estimators || []
    this.#metaSpec = params.finalEstimator || null
    this.#cv = params.cv || 5
    this.#task = params.task || 'classification'
    this.#passthrough = params.passthrough || false
    this.#seed = params.seed ?? 42
    this.#baseModels = null
    this.#metaModel = null
    this.#classes = null
    this.#nClasses = 0
    this.#nMetaCols = 0
    StackingEnsemble._register()
  }

  /** Async factory, matching the `create()` protocol expected of estimators. */
  static async create(params = {}) {
    return new StackingEnsemble(params)
  }

  #ensureAlive() {
    if (this.#disposed) throw new DisposedError('StackingEnsemble has been disposed.')
  }

  #ensureFitted() {
    this.#ensureAlive()
    if (!this.#fitted) throw new NotFittedError('StackingEnsemble is not fitted. Call fit() first.')
  }

  /**
   * Fit the ensemble.
   *
   * New models are built into locals. They are swapped in, with the previous
   * ones disposed, only after every step has succeeded, so a failed refit
   * keeps the earlier fitted state and leaks nothing.
   *
   * @param {*} X - feature matrix accepted by normalizeX.
   * @param {*} y - targets accepted by normalizeY.
   * @returns {Promise<StackingEnsemble>} this
   * @throws {ValidationError} when the configuration cannot be trained.
   */
  async fit(X, y) {
    this.#ensureAlive()
    this.#validateConfig()

    const Xn = normalizeX(X)
    const yn = normalizeY(y)
    const n = Xn.rows
    const isCls = this.#task === 'classification'

    // Sorted distinct labels define the probability column order.
    const classes = isCls
      ? new Int32Array([...new Set(yn)].sort((a, b) => a - b))
      : null
    const nClasses = isCls ? classes.length : 0
    const colsPerModel = isCls ? nClasses : 1
    const nBase = this.#baseSpecs.length
    const oofCols = nBase * colsPerModel

    const folds = isCls
      ? stratifiedKFold(yn, this.#cv, { shuffle: true, seed: this.#seed })
      : kFold(n, this.#cv, { shuffle: true, seed: this.#seed })

    // Step 1: out-of-fold predictions of every base estimator (row-major n x oofCols).
    const oofData = new Float64Array(n * oofCols)
    for (let b = 0; b < nBase; b++) {
      const [, EstClass, params] = this.#baseSpecs[b]
      for (const { train, test } of folds) {
        const Xtest = _subsetX(Xn, test)
        const model = await EstClass.create(params || {})
        try {
          // fit() may be async (e.g. a nested ensemble); awaiting a sync result is harmless.
          await model.fit(_subsetX(Xn, train), _subsetY(yn, train))
          if (isCls) {
            const proba = await model.predictProba(Xtest)
            const colMap = StackingEnsemble.#probaColumnMap(model, classes)
            StackingEnsemble.#writeProba(oofData, oofCols, b * colsPerModel, proba, colMap, test.length, test)
          } else {
            const preds = await model.predict(Xtest)
            for (let i = 0; i < test.length; i++) oofData[test[i] * oofCols + b] = preds[i]
          }
        } finally {
          model.dispose()
        }
      }
    }

    // Step 2: meta-feature matrix, optionally with the original features appended.
    const metaX = this.#passthrough
      ? StackingEnsemble.#appendColumns(oofData, oofCols, Xn)
      : { data: oofData, rows: n, cols: oofCols }

    // Steps 3-4: refit base estimators on all rows, then train the meta-model on the OOF features.
    const baseModels = []
    let metaModel = null
    try {
      for (const [, EstClass, params] of this.#baseSpecs) {
        const model = await EstClass.create(params || {})
        baseModels.push(model) // track before fit so a model whose fit throws is still disposed
        await model.fit(Xn, yn)
      }
      const [, MetaClass, metaParams] = this.#metaSpec
      metaModel = await MetaClass.create(metaParams || {})
      await metaModel.fit(metaX, yn)
    } catch (err) {
      for (const m of baseModels) m.dispose()
      if (metaModel) metaModel.dispose()
      throw err
    }

    this.#releaseModels()
    this.#baseModels = baseModels
    this.#metaModel = metaModel
    this.#classes = classes
    this.#nClasses = nClasses
    this.#nMetaCols = metaX.cols
    this.#fitted = true
    return this
  }

  /**
   * Predict with the meta-model on meta features built from X.
   * The result goes through `lift` because the meta features may be a promise
   * when a base model predicts asynchronously.
   */
  predict(X) {
    this.#ensureFitted()
    const metaX = this.#buildMetaFeatures(X)
    return lift(metaX, mx => this.#metaModel.predict(mx))
  }

  /**
   * Class probabilities from the meta-model (classification only).
   * @throws {ValidationError} for regression, or when the meta-model lacks predictProba.
   */
  predictProba(X) {
    this.#ensureFitted()
    if (this.#task !== 'classification') {
      throw new ValidationError('predictProba is only available for classification')
    }
    if (typeof this.#metaModel.predictProba !== 'function') {
      throw new ValidationError('Meta-model does not support predictProba')
    }
    const metaX = this.#buildMetaFeatures(X)
    return lift(metaX, mx => this.#metaModel.predictProba(mx))
  }

  /** Accuracy (classification) or R^2 (regression) of predict(X) against y. */
  score(X, y) {
    this.#ensureFitted()
    const preds = this.predict(X)
    const yn = normalizeY(y)
    const scorer = this.#task === 'classification' ? accuracy : r2Score
    return lift(preds, p => scorer(yn, p))
  }

  /** Serialize the ensemble: manifest params plus one artifact per fitted sub-model, keyed by name. */
  save() {
    this.#ensureFitted()
    const typeId = this.#task === 'classification' ? TYPE_ID_CLS : TYPE_ID_REG
    const manifest = {
      typeId,
      params: {
        task: this.#task,
        cv: this.#cv,
        passthrough: this.#passthrough,
        seed: this.#seed,
        estimatorNames: this.#baseSpecs.map(s => s[0]),
        metaName: this.#metaSpec[0],
        classes: this.#classes ? [...this.#classes] : null,
        nMetaCols: this.#nMetaCols,
      },
    }
    const artifacts = this.#baseModels.map((model, i) => ({
      id: this.#baseSpecs[i][0],
      data: model.save(),
      mediaType: 'application/x-wlearn-bundle',
    }))
    artifacts.push({
      id: this.#metaSpec[0],
      data: this.#metaModel.save(),
      mediaType: 'application/x-wlearn-bundle',
    })
    return encodeBundle(manifest, artifacts)
  }

  /** Restore an ensemble from bytes produced by save(). */
  static async load(bytes) {
    const { manifest, toc, blobs } = decodeBundle(bytes)
    return StackingEnsemble._loadFromParts(manifest, toc, blobs)
  }

  /** Release all fitted sub-models. Safe to call more than once. */
  dispose() {
    if (this.#disposed) return
    this.#disposed = true
    this.#releaseModels()
  }

  getParams() {
    return {
      task: this.#task,
      cv: this.#cv,
      passthrough: this.#passthrough,
      seed: this.#seed,
      estimatorNames: this.#baseSpecs.map(s => s[0]),
      metaName: this.#metaSpec ? this.#metaSpec[0] : null,
    }
  }

  /** Update cv / passthrough / seed; the changes take effect on the next fit(). */
  setParams(p) {
    this.#ensureAlive()
    if (p.cv !== undefined) this.#cv = p.cv
    if (p.passthrough !== undefined) this.#passthrough = p.passthrough
    if (p.seed !== undefined) this.#seed = p.seed
    return this
  }

  get capabilities() {
    return {
      classifier: this.#task === 'classification',
      regressor: this.#task === 'regression',
      predictProba: this.#task === 'classification',
      decisionFunction: false,
      sampleWeight: false,
      csr: false,
      earlyStopping: false,
    }
  }

  get isFitted() { return this.#fitted }
  get classes() { return this.#classes }

  // --- Private helpers ---

  /** Reject configurations that fit() cannot train or that save()/load() could not round-trip. */
  #validateConfig() {
    if (!this.#metaSpec) {
      throw new ValidationError('StackingEnsemble requires a finalEstimator')
    }
    if (this.#task !== 'classification' && this.#task !== 'regression') {
      throw new ValidationError(`task must be "classification" or "regression", got "${this.#task}"`)
    }
    if (this.#baseSpecs.length === 0) {
      throw new ValidationError('StackingEnsemble requires at least one estimator')
    }
    // Saved artifacts are looked up by estimator name, so names must be unique.
    const seen = new Set()
    for (const [name, EstClass] of [...this.#baseSpecs, this.#metaSpec]) {
      if (typeof EstClass?.create !== 'function') {
        throw new ValidationError(`Estimator "${name}" has no create() factory`)
      }
      if (seen.has(name)) {
        throw new ValidationError(`Duplicate estimator name "${name}"`)
      }
      seen.add(name)
    }
  }

  /** Dispose and forget the current fitted sub-models, if any. */
  #releaseModels() {
    if (this.#baseModels) for (const m of this.#baseModels) m.dispose()
    if (this.#metaModel) this.#metaModel.dispose()
    this.#baseModels = null
    this.#metaModel = null
  }

  /**
   * Build the meta-model input for X.
   *
   * The columns are each base model's probabilities (classification) or
   * predictions (regression). The original features follow when the
   * ensemble was fitted with passthrough. Returns a promise only when a base
   * model predicts asynchronously.
   */
  #buildMetaFeatures(X) {
    const Xn = normalizeX(X)
    const n = Xn.rows
    const isCls = this.#task === 'classification'
    const models = this.#baseModels
    const colsPerModel = isCls ? this.#nClasses : 1
    const predCols = models.length * colsPerModel

    // Follow the layout recorded at fit time, not the live `passthrough`
    // param, which setParams() may have changed since.
    const passCols = this.#nMetaCols - predCols
    if (passCols > 0 && Xn.cols !== passCols) {
      throw new ValidationError(`Expected ${passCols} feature columns, got ${Xn.cols}`)
    }

    const rawOutputs = models.map(m => (isCls ? m.predictProba(Xn) : m.predict(Xn)))

    const assemble = (outputs) => {
      const predData = new Float64Array(n * predCols)
      for (let b = 0; b < outputs.length; b++) {
        if (isCls) {
          const colMap = StackingEnsemble.#probaColumnMap(models[b], this.#classes)
          StackingEnsemble.#writeProba(predData, predCols, b * colsPerModel, outputs[b], colMap, n)
        } else {
          const preds = outputs[b]
          for (let i = 0; i < n; i++) predData[i * predCols + b] = preds[i]
        }
      }
      return passCols > 0
        ? StackingEnsemble.#appendColumns(predData, predCols, Xn)
        : { data: predData, rows: n, cols: predCols }
    }

    const hasPromise = rawOutputs.some(out => out != null && typeof out.then === 'function')
    return hasPromise ? Promise.all(rawOutputs).then(assemble) : assemble(rawOutputs)
  }

  /**
   * Map each column of `model.predictProba` output to an index into `classes`,
   * or -1 to drop it.
   *
   * A fold-trained model may have seen only some labels (e.g. a class with
   * fewer samples than folds). Its probability matrix is then narrower than
   * `classes`. Falls back to the identity mapping when the model does not
   * expose an array-like `classes`.
   */
  static #probaColumnMap(model, classes) {
    const modelClasses = model.classes
    if (modelClasses == null || typeof modelClasses.length !== 'number') {
      return Array.from(classes, (_, c) => c)
    }
    const indexOf = new Map()
    classes.forEach((label, c) => indexOf.set(label, c))
    return Array.from(modelClasses, label => indexOf.get(Number(label)) ?? -1)
  }

  /**
   * Scatter a row-major probability matrix (width colMap.length) into `dest`.
   * Row i lands in destination row `rowIndex[i]`, or row i when `rowIndex` is
   * null. `dest` has stride `destCols`, and writes start at column `colOffset`.
   */
  static #writeProba(dest, destCols, colOffset, proba, colMap, rows, rowIndex = null) {
    const width = colMap.length
    for (let i = 0; i < rows; i++) {
      const rowStart = (rowIndex ? rowIndex[i] : i) * destCols + colOffset
      for (let k = 0; k < width; k++) {
        const c = colMap[k]
        if (c >= 0) dest[rowStart + c] = proba[i * width + k]
      }
    }
  }

  /** Concatenate a row-major n x leftCols matrix with Xn's columns on the right. */
  static #appendColumns(left, leftCols, Xn) {
    const n = Xn.rows
    const cols = leftCols + Xn.cols
    const data = new Float64Array(n * cols)
    for (let i = 0; i < n; i++) {
      data.set(left.subarray(i * leftCols, (i + 1) * leftCols), i * cols)
      data.set(Xn.data.subarray(i * Xn.cols, (i + 1) * Xn.cols), i * cols + leftCols)
    }
    return { data, rows: n, cols }
  }

  /** Register the classifier/regressor bundle loaders with @wlearn/core (once). */
  static _register() {
    if (_registered) return
    _registered = true
    const loader = (manifest, toc, blobs) => {
      return StackingEnsemble._loadFromParts(manifest, toc, blobs)
    }
    register(TYPE_ID_CLS, loader)
    register(TYPE_ID_REG, loader)
  }

  /** Rebuild a fitted ensemble from a decoded bundle; sub-models load via the registry. */
  static async _loadFromParts(manifest, toc, blobs) {
    const p = manifest.params
    const ens = new StackingEnsemble({
      task: p.task,
      cv: p.cv,
      passthrough: p.passthrough,
      seed: p.seed,
    })
    ens.#classes = p.classes ? new Int32Array(p.classes) : null
    ens.#nClasses = ens.#classes ? ens.#classes.length : 0
    ens.#nMetaCols = p.nMetaCols
    ens.#baseSpecs = p.estimatorNames.map(name => [name, null, null])
    ens.#metaSpec = [p.metaName, null, null]

    const blobFor = (name, role) => {
      const entry = toc.find(t => t.id === name)
      if (!entry) throw new ValidationError(`No artifact for ${role} estimator "${name}"`)
      return blobs.subarray(entry.offset, entry.offset + entry.length)
    }

    ens.#baseModels = []
    try {
      for (const name of p.estimatorNames) {
        ens.#baseModels.push(await registryLoad(blobFor(name, 'base')))
      }
      ens.#metaModel = await registryLoad(blobFor(p.metaName, 'meta'))
    } catch (err) {
      ens.dispose() // release any sub-models that were already loaded
      throw err
    }

    ens.#fitted = true
    return ens
  }
}
|
|
352
|
+
|
|
353
|
+
// --- Subset helpers ---
|
|
354
|
+
|
|
355
|
+
/**
 * Gather the rows listed in `indices` from a row-major dense matrix.
 *
 * @param {{data: Float64Array, rows: number, cols: number}} X - source matrix; `data` must support subarray().
 * @param {ArrayLike<number>} indices - source row numbers in output order (repeats allowed).
 * @returns {{data: Float64Array, rows: number, cols: number}} new matrix with `indices.length` rows.
 */
function _subsetX(X, indices) {
  const source = X.data
  const width = X.cols
  const count = indices.length
  const gathered = new Float64Array(count * width)
  for (let dst = 0; dst < count; dst++) {
    const start = indices[dst] * width
    gathered.set(source.subarray(start, start + width), dst * width)
  }
  return { data: gathered, rows: count, cols: width }
}
|
|
365
|
+
|
|
366
|
+
/**
 * Gather `y[indices[k]]` into a new container built with `y`'s own
 * constructor, so a Float64Array stays a Float64Array and an Array stays an Array.
 */
function _subsetY(y, indices) {
  const Container = y.constructor
  const picked = new Container(indices.length)
  let k = 0
  while (k < indices.length) {
    picked[k] = y[indices[k]]
    k += 1
  }
  return picked
}
|
package/src/voting.js
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
import {
|
|
2
|
+
encodeBundle, decodeBundle, register, load as registryLoad,
|
|
3
|
+
normalizeX, normalizeY, accuracy, r2Score,
|
|
4
|
+
ValidationError, NotFittedError, DisposedError,
|
|
5
|
+
lift
|
|
6
|
+
} from '@wlearn/core'
|
|
7
|
+
|
|
8
|
+
// Type identifiers written into the bundle manifest by save() and registered
// with the @wlearn/core loader registry by VotingEnsemble._register().
const TYPE_ID_CLS = 'wlearn.ensemble.voting.classifier@1'
const TYPE_ID_REG = 'wlearn.ensemble.voting.regressor@1'
// Guard flag so VotingEnsemble._register() registers the loaders only once.
let _registered = false
|
|
11
|
+
|
|
12
|
+
/**
 * Voting ensemble over independently fitted estimators.
 *
 * - Regression: weighted average of the members' predictions.
 * - Classification, `voting: 'soft'`: weighted average of predicted
 *   probabilities, with argmax for labels.
 * - Classification, `voting: 'hard'`: weighted majority of predicted labels.
 *
 * Weights default to equal. They are stored as given and rescaled to sum to 1
 * when applied, so e.g. `[2, 1]` yields a true weighted average.
 *
 * Estimator specs are `[name, Class, params]` tuples. `Class.create(params)`
 * must resolve to an instance exposing `fit`, `predict` (plus `predictProba`
 * for soft voting), `save` and `dispose`.
 */
export class VotingEnsemble {
  #specs // [name, Class, params][]
  #models // fitted instances
  #weights // as supplied (or default equal); normalized at use time
  #voting // 'soft' | 'hard'
  #task // 'classification' | 'regression'
  #classes // sorted distinct training labels (classification only)
  #fitted = false
  #disposed = false

  // Register the bundle loaders when the class is evaluated rather than on
  // first construction. Otherwise @wlearn/core's load() could not decode a
  // voting bundle (e.g. one nested inside another ensemble) in a fresh process.
  static {
    VotingEnsemble._register()
  }

  /**
   * @param {object} [params]
   * @param {Array} [params.estimators=[]] - estimator specs `[name, Class, params]`.
   * @param {ArrayLike<number>} [params.weights] - one weight per estimator (default: equal).
   * @param {string} [params.voting='soft'] - 'soft' or 'hard' (classification only).
   * @param {string} [params.task='classification'] - 'classification' or 'regression'.
   */
  constructor(params = {}) {
    this.#specs = params.estimators || []
    this.#weights = params.weights || null
    this.#voting = params.voting || 'soft'
    this.#task = params.task || 'classification'
    this.#models = null
    this.#classes = null
    VotingEnsemble._register()
  }

  /** Async factory, matching the `create()` protocol expected of estimators. */
  static async create(params = {}) {
    return new VotingEnsemble(params)
  }

  #ensureAlive() {
    if (this.#disposed) throw new DisposedError('VotingEnsemble has been disposed.')
  }

  #ensureFitted() {
    this.#ensureAlive()
    if (!this.#fitted) throw new NotFittedError('VotingEnsemble is not fitted. Call fit() first.')
  }

  /**
   * Fit every member on (X, y).
   *
   * New models are built into locals. The previous ones are disposed and
   * replaced only when all fits succeed, so a failed refit keeps the earlier
   * fitted state and leaks nothing.
   *
   * @returns {Promise<VotingEnsemble>} this
   * @throws {ValidationError} when the configuration cannot be trained.
   */
  async fit(X, y) {
    this.#ensureAlive()
    this.#validateConfig()
    const Xn = normalizeX(X)
    const yn = normalizeY(y)

    const classes = this.#task === 'classification'
      ? new Int32Array([...new Set(yn)].sort((a, b) => a - b))
      : null

    // Default equal weights
    if (!this.#weights) {
      this.#weights = new Float64Array(this.#specs.length).fill(1 / this.#specs.length)
    }

    const models = []
    try {
      for (const [, EstClass, params] of this.#specs) {
        const model = await EstClass.create(params || {})
        models.push(model) // track before fit so a model whose fit throws is still disposed
        // fit() may be async (e.g. a nested ensemble); awaiting a sync result is harmless.
        await model.fit(Xn, yn)
      }
    } catch (err) {
      for (const m of models) m.dispose()
      throw err
    }

    this.#releaseModels()
    this.#models = models
    this.#classes = classes
    this.#fitted = true
    return this
  }

  /**
   * Predict targets (regression) or labels (classification).
   * Returns a promise when any member predicts asynchronously.
   */
  predict(X) {
    this.#ensureFitted()
    const Xn = normalizeX(X)
    const n = Xn.rows

    if (this.#task === 'regression') {
      return this.#weightedAverage(Xn, n)
    }

    if (this.#voting === 'soft') {
      const classes = this.#classes
      return lift(this.predictProba(Xn), p => {
        const nc = classes.length
        const out = new Float64Array(n)
        for (let i = 0; i < n; i++) {
          let bestC = 0
          let bestV = -Infinity
          for (let c = 0; c < nc; c++) {
            const v = p[i * nc + c]
            if (v > bestV) { bestV = v; bestC = c }
          }
          out[i] = classes[bestC]
        }
        return out
      })
    }

    // Hard voting: weighted majority vote
    return this.#majorityVote(Xn, n)
  }

  /**
   * Weighted average of the members' class probabilities (row-major n x nClasses).
   * Rows sum to 1 whenever each member's probability rows do.
   * @throws {ValidationError} for regression or when voting is not 'soft'.
   */
  predictProba(X) {
    this.#ensureFitted()
    if (this.#task !== 'classification') {
      throw new ValidationError('predictProba is only available for classification')
    }
    if (this.#voting !== 'soft') {
      throw new ValidationError('predictProba requires voting="soft"')
    }

    const Xn = normalizeX(X)
    const n = Xn.rows
    const nc = this.#classes.length
    const weights = this.#normalizedWeights()

    return VotingEnsemble.#whenAll(this.#models.map(m => m.predictProba(Xn)), outputs => {
      const result = new Float64Array(n * nc)
      for (let m = 0; m < outputs.length; m++) {
        const proba = outputs[m]
        const w = weights[m]
        for (let i = 0; i < n * nc; i++) result[i] += w * proba[i]
      }
      return result
    })
  }

  /** Accuracy (classification) or R^2 (regression) of predict(X) against y. */
  score(X, y) {
    this.#ensureFitted()
    const preds = this.predict(X)
    const yn = normalizeY(y)
    const scorer = this.#task === 'classification' ? accuracy : r2Score
    return lift(preds, p => scorer(yn, p))
  }

  /** Serialize the ensemble: manifest params plus one artifact per member, keyed by name. */
  save() {
    this.#ensureFitted()
    const typeId = this.#task === 'classification' ? TYPE_ID_CLS : TYPE_ID_REG
    const manifest = {
      typeId,
      params: {
        task: this.#task,
        voting: this.#voting,
        weights: [...this.#weights],
        estimatorNames: this.#specs.map(s => s[0]),
        classes: this.#classes ? [...this.#classes] : null,
      },
    }
    const artifacts = this.#models.map((model, i) => ({
      id: this.#specs[i][0],
      data: model.save(),
      mediaType: 'application/x-wlearn-bundle',
    }))
    return encodeBundle(manifest, artifacts)
  }

  /** Restore an ensemble from bytes produced by save(). */
  static async load(bytes) {
    const { manifest, toc, blobs } = decodeBundle(bytes)
    return VotingEnsemble._loadFromParts(manifest, toc, blobs)
  }

  /** Release all fitted members. Safe to call more than once. */
  dispose() {
    if (this.#disposed) return
    this.#disposed = true
    this.#releaseModels()
  }

  getParams() {
    return {
      task: this.#task,
      voting: this.#voting,
      weights: this.#weights ? [...this.#weights] : null,
      estimatorNames: this.#specs.map(s => s[0]),
    }
  }

  setParams(p) {
    this.#ensureAlive()
    if (p.voting !== undefined) this.#voting = p.voting
    if (p.weights !== undefined) this.#weights = new Float64Array(p.weights)
    return this
  }

  get capabilities() {
    return {
      classifier: this.#task === 'classification',
      regressor: this.#task === 'regression',
      predictProba: this.#task === 'classification' && this.#voting === 'soft',
      decisionFunction: false,
      sampleWeight: false,
      csr: false,
      earlyStopping: false,
    }
  }

  get isFitted() { return this.#fitted }
  get classes() { return this.#classes }

  // --- Private helpers ---

  /** Reject configurations that fit() cannot train or that save()/load() could not round-trip. */
  #validateConfig() {
    if (this.#task !== 'classification' && this.#task !== 'regression') {
      throw new ValidationError(`task must be "classification" or "regression", got "${this.#task}"`)
    }
    if (this.#task === 'classification' && this.#voting !== 'soft' && this.#voting !== 'hard') {
      throw new ValidationError(`voting must be "soft" or "hard", got "${this.#voting}"`)
    }
    if (this.#specs.length === 0) {
      throw new ValidationError('VotingEnsemble requires at least one estimator')
    }
    // Saved artifacts are looked up by estimator name, so names must be unique.
    const seen = new Set()
    for (const [name, EstClass] of this.#specs) {
      if (typeof EstClass?.create !== 'function') {
        throw new ValidationError(`Estimator "${name}" has no create() factory`)
      }
      if (seen.has(name)) {
        throw new ValidationError(`Duplicate estimator name "${name}"`)
      }
      seen.add(name)
    }
    if (this.#weights) VotingEnsemble.#normalizeWeights(this.#weights, this.#specs.length)
  }

  /** Dispose and forget the current fitted members, if any. */
  #releaseModels() {
    if (this.#models) for (const m of this.#models) m.dispose()
    this.#models = null
  }

  /** Current weights rescaled to sum to 1, validated against the number of fitted members. */
  #normalizedWeights() {
    return VotingEnsemble.#normalizeWeights(this.#weights ?? [], this.#models.length)
  }

  /**
   * Return `weights / sum(weights)` as a Float64Array.
   * @throws {ValidationError} on a length mismatch or a non-positive / non-finite total.
   */
  static #normalizeWeights(weights, expected) {
    if (weights.length !== expected) {
      throw new ValidationError(`weights has length ${weights.length}, expected ${expected}`)
    }
    let total = 0
    for (let i = 0; i < weights.length; i++) total += weights[i]
    if (!Number.isFinite(total) || total <= 0) {
      throw new ValidationError('weights must sum to a positive, finite value')
    }
    return Float64Array.from(weights, w => w / total)
  }

  /** Apply `assemble` to the outputs, waiting on them first only if any is a promise. */
  static #whenAll(outputs, assemble) {
    const hasPromise = outputs.some(out => out != null && typeof out.then === 'function')
    return hasPromise ? Promise.all(outputs).then(assemble) : assemble(outputs)
  }

  /** Weighted average of the members' regression predictions. */
  #weightedAverage(Xn, n) {
    const weights = this.#normalizedWeights()
    return VotingEnsemble.#whenAll(this.#models.map(m => m.predict(Xn)), outputs => {
      const result = new Float64Array(n)
      for (let m = 0; m < outputs.length; m++) {
        const w = weights[m]
        const preds = outputs[m]
        for (let i = 0; i < n; i++) result[i] += w * preds[i]
      }
      return result
    })
  }

  /**
   * Weighted majority vote over predicted labels.
   * Labels not seen during fit are ignored, and ties go to the smallest label.
   */
  #majorityVote(Xn, n) {
    const weights = this.#normalizedWeights()
    const classes = this.#classes
    const nc = classes.length
    const classIndex = new Map()
    classes.forEach((label, c) => classIndex.set(label, c))

    return VotingEnsemble.#whenAll(this.#models.map(m => m.predict(Xn)), outputs => {
      const result = new Float64Array(n)
      const votes = new Float64Array(nc)
      for (let i = 0; i < n; i++) {
        votes.fill(0)
        for (let m = 0; m < outputs.length; m++) {
          const c = classIndex.get(outputs[m][i])
          if (c !== undefined) votes[c] += weights[m]
        }
        let bestC = 0
        let bestV = -Infinity
        for (let c = 0; c < nc; c++) {
          if (votes[c] > bestV) { bestV = votes[c]; bestC = c }
        }
        result[i] = classes[bestC]
      }
      return result
    })
  }

  /** Register the classifier/regressor bundle loaders with @wlearn/core (once). */
  static _register() {
    if (_registered) return
    _registered = true
    const loader = (manifest, toc, blobs) => {
      return VotingEnsemble._loadFromParts(manifest, toc, blobs)
    }
    register(TYPE_ID_CLS, loader)
    register(TYPE_ID_REG, loader)
  }

  /** Rebuild a fitted ensemble from a decoded bundle; members load via the registry. */
  static async _loadFromParts(manifest, toc, blobs) {
    const p = manifest.params
    const ens = new VotingEnsemble({
      task: p.task,
      voting: p.voting,
      weights: new Float64Array(p.weights),
    })
    ens.#classes = p.classes ? new Int32Array(p.classes) : null
    ens.#specs = p.estimatorNames.map(name => [name, null, null])
    ens.#models = []
    try {
      for (const name of p.estimatorNames) {
        const entry = toc.find(t => t.id === name)
        if (!entry) throw new ValidationError(`No artifact for estimator "${name}"`)
        const blob = blobs.subarray(entry.offset, entry.offset + entry.length)
        ens.#models.push(await registryLoad(blob))
      }
    } catch (err) {
      ens.dispose() // release any members that were already loaded
      throw err
    }
    ens.#fitted = true
    return ens
  }
}
|