@wlearn/automl 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "name": "@wlearn/automl",
3
+ "version": "0.1.0",
4
+ "description": "AutoML engine for wlearn: search space sampling, random search, successive halving, ensemble construction",
5
+ "type": "module",
6
+ "main": "src/index.js",
7
+ "exports": {
8
+ ".": "./src/index.js"
9
+ },
10
+ "files": [
11
+ "src/"
12
+ ],
13
+ "sideEffects": false,
14
+ "dependencies": {
15
+ "@wlearn/core": "0.1.0",
16
+ "@wlearn/ensemble": "0.1.0",
17
+ "@wlearn/types": "0.1.0"
18
+ },
19
+ "scripts": {
20
+ "test": "node --test test/*.js"
21
+ },
22
+ "publishConfig": {
23
+ "access": "public"
24
+ },
25
+ "author": "Anton Zemlyansky",
26
+ "license": "MIT"
27
+ }
@@ -0,0 +1,261 @@
1
+ import { normalizeX, normalizeY, ValidationError, Preprocessor } from '@wlearn/core'
2
+ import { getOofPredictions, caruanaSelect, VotingEnsemble, StackingEnsemble } from '@wlearn/ensemble'
3
+ import { RandomSearch } from './search.js'
4
+ import { SuccessiveHalvingSearch } from './halving.js'
5
+ import { PortfolioSearch } from './portfolio.js'
6
+ import { ProgressiveSearch } from './progressive.js'
7
+ import { detectTask } from './common.js'
8
+
9
/**
 * Compute a pairwise disagreement score between two prediction vectors.
 *
 * - Classification: `a` and `b` are flattened row-major score matrices with
 *   `a.length / n` classes per row. Returns the fraction of rows whose argmax
 *   differs (ties resolve to the lowest class index).
 * - Regression: returns 1 - |Pearson correlation|, clamped to [0, 1]. If either
 *   vector has (numerically) zero variance the correlation is undefined and
 *   the pair is treated as fully disagreeing (returns 1).
 *
 * @param {ArrayLike<number>} a - predictions of the first model
 * @param {ArrayLike<number>} b - predictions of the second model
 * @param {number} n - number of samples
 * @param {string} task - 'classification'; any other value is treated as regression
 * @returns {number} disagreement in [0, 1]; 0 when n is 0
 */
function _disagreementRate(a, b, n, task) {
  // No samples means nothing to disagree on (avoids 0 / 0 = NaN below).
  if (n <= 0) return 0

  if (task === 'classification') {
    const nClasses = a.length / n
    let disagree = 0
    for (let i = 0; i < n; i++) {
      const base = i * nClasses
      let bestA = 0
      let bestB = 0
      let bestVA = -Infinity
      let bestVB = -Infinity
      for (let c = 0; c < nClasses; c++) {
        const va = a[base + c]
        const vb = b[base + c]
        if (va > bestVA) { bestVA = va; bestA = c }
        if (vb > bestVB) { bestVB = vb; bestB = c }
      }
      if (bestA !== bestB) disagree++
    }
    return disagree / n
  }

  // Regression: two-pass (mean-centred) Pearson correlation. The one-pass
  // sum-of-squares form can yield slightly negative variances through
  // cancellation, which made Math.sqrt return NaN for near-constant inputs.
  let meanA = 0
  let meanB = 0
  for (let i = 0; i < n; i++) {
    meanA += a[i]
    meanB += b[i]
  }
  meanA /= n
  meanB /= n

  let varA = 0
  let varB = 0
  let cov = 0
  for (let i = 0; i < n; i++) {
    const da = a[i] - meanA
    const db = b[i] - meanB
    varA += da * da
    varB += db * db
    cov += da * db
  }

  const denom = Math.sqrt(varA * varB)
  if (denom < 1e-12) return 1
  const corr = cov / denom
  // Rounding can push |corr| marginally above 1; keep the documented range.
  return Math.min(1, Math.max(0, 1 - Math.abs(corr)))
}
41
+
42
/**
 * Greedy diversity filter over a rank-ordered pool of OOF prediction vectors.
 *
 * Index 0 (the top-ranked entry) is always retained. Every later entry is
 * retained only when its disagreement with each already-retained entry is at
 * least `minDisagreement`. If that leaves index 0 alone, index 1 is added so
 * an ensemble still has two members. Pools of two or fewer entries, or a
 * non-positive threshold, are returned unfiltered.
 *
 * @param {Array<ArrayLike<number>>} oofPreds - OOF predictions, best first
 * @param {ArrayLike<number>} yn - labels (only the length is read)
 * @param {string} task - forwarded to _disagreementRate
 * @param {number} minDisagreement - minimum pairwise disagreement to keep an entry
 * @returns {number[]} retained indices into oofPreds, in ascending order
 */
function _filterByDisagreement(oofPreds, yn, task, minDisagreement) {
  const nSamples = yn.length
  const poolSize = oofPreds.length
  if (minDisagreement <= 0 || poolSize <= 2) {
    return oofPreds.map((_, idx) => idx)
  }

  const retained = [0]
  for (let cand = 1; cand < poolSize; cand++) {
    // `!(x < t)` rather than `x >= t` so a NaN rate still counts as diverse.
    const isDiverse = retained.every(member =>
      !(_disagreementRate(oofPreds[cand], oofPreds[member], nSamples, task) < minDisagreement))
    if (isDiverse) retained.push(cand)
  }

  // poolSize >= 3 here, so the runner-up always exists as a fallback partner.
  if (retained.length === 1) retained.push(1)
  return retained
}
+ }
70
+
71
/**
 * Coerce every model description to a ModelSpec object.
 *
 * Array entries are read as [name, cls, params?] tuples and become
 * { name, cls, params }, with params defaulting to {} when missing or falsy.
 * Non-array entries are returned unchanged (same object).
 *
 * @param {Array} models - ModelSpec objects and/or tuples
 * @returns {object[]} ModelSpec objects, same order as input
 */
function _normalizeSpecs(models) {
  const toSpec = (entry) => {
    if (!Array.isArray(entry)) return entry
    const [name, cls, params] = entry
    return { name, cls, params: params || {} }
  }
  return models.map(toSpec)
}
82
+
83
/**
 * Instantiate the search driver for a strategy name.
 * 'portfolio', 'halving' and 'progressive' select the matching class; any
 * other value falls back to RandomSearch.
 */
function _createSearch(strategy, specs, searchOpts) {
  switch (strategy) {
    case 'portfolio':
      return new PortfolioSearch(specs, searchOpts)
    case 'halving':
      return new SuccessiveHalvingSearch(specs, searchOpts)
    case 'progressive':
      return new ProgressiveSearch(specs, searchOpts)
    default:
      return new RandomSearch(specs, searchOpts)
  }
}

/**
 * Assemble a diversity-aware candidate pool from best-first leaderboard entries.
 *
 * Inclusion order: the best entry of every model family (always included,
 * even beyond `limit`), then each family's second-best entry, then the
 * remaining entries in rank order, stopping once `limit` entries are pooled.
 * Entries are de-duplicated by their `id`.
 *
 * @param {object[]} ranked - ranked entries with `modelName` and `id`
 * @param {number} limit - soft cap on the pool size
 * @returns {object[]} pooled entries
 */
function _buildEnsemblePool(ranked, limit) {
  const familyBest = new Map()
  const familySecond = new Map()
  for (const entry of ranked) {
    if (!familyBest.has(entry.modelName)) {
      familyBest.set(entry.modelName, entry)
    } else if (!familySecond.has(entry.modelName)) {
      familySecond.set(entry.modelName, entry)
    }
  }

  const pool = [...familyBest.values()]
  const poolIds = new Set(pool.map(e => e.id))
  const addUpToLimit = (entries) => {
    for (const entry of entries) {
      if (pool.length >= limit) break
      if (!poolIds.has(entry.id)) {
        pool.push(entry)
        poolIds.add(entry.id)
      }
    }
  }
  addUpToLimit(familySecond.values())
  addUpToLimit(ranked)
  return pool
}

/**
 * High-level AutoML: hyperparameter search followed by either an ensemble
 * (Caruana selection over a diversity-filtered pool, combined by voting or,
 * with a meta-estimator, two-layer stacking) or a refit of the best candidate.
 *
 * @param {Array} models - ModelSpec[] or EstimatorSpec tuples [name, cls, params?]
 * @param {object|number[][]} X - feature matrix
 * @param {TypedArray|number[]} y - labels
 * @param {object} [opts]
 * @param {boolean} [opts.ensemble=true] - build an ensemble from the search results
 * @param {number} [opts.ensembleSize=20] - max Caruana ensemble size; the pool holds
 *   up to max(2 * ensembleSize, number of model families) entries
 * @param {boolean} [opts.refit=true] - without ensembling, refit the best candidate
 * @param {string} [opts.strategy='random'] - 'random' | 'halving' | 'portfolio' | 'progressive'
 * @param {number} [opts.minDisagreement=0.05] - pool diversity threshold (<= 0 disables)
 * @param {boolean|'auto'} [opts.stacking='auto'] - stacking mode; only applies when
 *   metaEstimator is given ('auto' requires >= 3 distinct selected model families)
 * @param {Array|object|null} [opts.metaEstimator=null] - meta-model tuple, spec or class
 * @param {boolean|object} [opts.preprocess=false] - fit a Preprocessor first (an object is its config)
 * @param {Function|null} [opts.onProgress=null] - progress callback
 *   Remaining options (task, scoring, cv, seed, ...) are forwarded to the search.
 * @returns {Promise<{ model: object, preprocessor: object|null, leaderboard: object[], bestParams: object, bestModelName: string, bestScore: number }>}
 * @throws {ValidationError} if no models are given or no candidate was evaluated successfully
 */
export async function autoFit(models, X, y, opts = {}) {
  const {
    ensemble = true,
    ensembleSize = 20,
    refit = true,
    strategy = 'random',
    minDisagreement = 0.05,
    stacking = 'auto',
    metaEstimator = null,
    preprocess = false,
    onProgress = null,
    ...searchOpts
  } = opts

  const specs = _normalizeSpecs(models)
  if (specs.length === 0) {
    throw new ValidationError('autoFit: at least one model is required')
  }

  // Optional preprocessing; everything downstream sees the transformed X.
  let preprocessor = null
  let Xfit = X
  if (preprocess) {
    const ppConfig = typeof preprocess === 'object' ? preprocess : {}
    preprocessor = new Preprocessor(ppConfig)
    Xfit = preprocessor.fitTransform(normalizeX(X), normalizeY(y))
  }

  const search = _createSearch(strategy, specs, { ...searchOpts, onProgress })
  const { leaderboard, bestResult } = await search.fit(Xfit, y)
  if (!bestResult) {
    throw new ValidationError('autoFit: search produced no successful candidates')
  }
  const ranked = leaderboard.ranked()

  let model = null

  if (ensemble) {
    if (onProgress) {
      onProgress({ phase: 'ensemble', message: 'building ensemble' })
    }

    const Xn = normalizeX(Xfit)
    const yn = normalizeY(y)
    const task = searchOpts.task || detectTask(yn)
    const scoring = searchOpts.scoring || (task === 'classification' ? 'accuracy' : 'r2')
    // `??` (not `||`) so explicit falsy values such as seed 0 are honoured,
    // matching the search classes, which spread user options over defaults.
    const cv = searchOpts.cv ?? 5
    const seed = searchOpts.seed ?? 42

    const pool = _buildEnsemblePool(ranked, ensembleSize * 2)
    const clsByName = new Map(specs.map(spec => [spec.name, spec.cls]))
    const estSpecs = pool.map((entry, i) =>
      [`${entry.modelName}_${i}`, clsByName.get(entry.modelName), entry.params])

    const { oofPreds } = await getOofPredictions(estSpecs, Xn, yn, { cv, seed, task })

    // Drop near-duplicate predictors, then run Caruana selection on the rest.
    const keptIdx = _filterByDisagreement(oofPreds, yn, task, minDisagreement)
    const { indices: selIndices, weights } = caruanaSelect(
      keptIdx.map(i => oofPreds[i]),
      yn,
      { maxSize: ensembleSize, scoring, task },
    )

    // Caruana's indices refer to the filtered pool; translate to pool positions.
    const selectedPoolIdx = Array.from(selIndices, i => keptIdx[i])
    const selectedSpecs = selectedPoolIdx.map(i => estSpecs[i])

    // Families come from the leaderboard model names. Splitting the generated
    // `${modelName}_${i}` label on '_' miscounted names containing underscores.
    const selectedFamilies = new Set(selectedPoolIdx.map(i => pool[i].modelName))
    const useStacking = Boolean(metaEstimator) &&
      (stacking === true || (stacking === 'auto' && selectedFamilies.size >= 3))

    if (useStacking) {
      // Two-layer stacking: L0 = selected base models, L1 = meta-model.
      const metaSpec = Array.isArray(metaEstimator)
        ? metaEstimator
        : ['meta', metaEstimator.cls || metaEstimator, metaEstimator.params || {}]
      model = await StackingEnsemble.create({
        estimators: selectedSpecs,
        finalEstimator: metaSpec,
        passthrough: true,
        task,
        cv,
        seed,
      })
    } else {
      model = await VotingEnsemble.create({
        estimators: selectedSpecs,
        weights,
        voting: task === 'classification' ? 'soft' : undefined,
        task,
      })
    }
    await model.fit(Xn, yn)
  } else if (refit) {
    model = await search.refitBest(Xfit, y)
  }

  return {
    model,
    preprocessor,
    leaderboard: ranked,
    bestParams: bestResult.params,
    bestModelName: bestResult.modelName,
    bestScore: bestResult.meanScore,
  }
}
261
+
package/src/common.js ADDED
@@ -0,0 +1,108 @@
1
+ import { makeLCG } from '@wlearn/core'
2
+
3
const { round } = Math

/**
 * Infer the learning task from a label vector.
 *
 * Int32Array labels are always classification. Otherwise a single value that
 * differs from its rounded value (including NaN) means regression; all-integral
 * labels are classification when there are at most 20 distinct values.
 *
 * @param {ArrayLike<number>} y - labels
 * @returns {'classification'|'regression'}
 */
export function detectTask(y) {
  if (y instanceof Int32Array) return 'classification'
  const seen = new Set()
  let i = 0
  while (i < y.length) {
    const value = y[i]
    if (round(value) !== value) return 'regression'
    seen.add(value)
    i++
  }
  return seen.size > 20 ? 'regression' : 'classification'
}
17
+
18
/**
 * Current time in milliseconds for measuring durations.
 * Uses performance.now() when a global `performance` object exists,
 * otherwise falls back to Date.now().
 *
 * @returns {number}
 */
export function now() {
  return typeof performance === 'undefined' ? Date.now() : performance.now()
}
25
+
26
/**
 * Deterministic serialization with object keys in sorted order.
 *
 * Numbers are written with Number#toString, strings and keys via
 * JSON.stringify, null/undefined/booleans via String. Arrays keep their order.
 * Any other value (bigint, symbol, function) falls back to String().
 * Intended for JSON-like hyperparameter objects.
 *
 * @param {*} obj
 * @returns {string}
 */
function stableStringify(obj) {
  if (obj == null) return String(obj)
  switch (typeof obj) {
    case 'number':
      return obj.toString()
    case 'string':
      return JSON.stringify(obj)
    case 'boolean':
      return String(obj)
    case 'object': {
      if (Array.isArray(obj)) {
        return `[${obj.map(stableStringify).join(',')}]`
      }
      const members = Object.keys(obj)
        .sort()
        .map(key => `${JSON.stringify(key)}:${stableStringify(obj[key])}`)
      return `{${members.join(',')}}`
    }
    default:
      return String(obj)
  }
}
45
+
46
/**
 * Build a candidate identifier of the form `<modelLabel>:<params>`.
 * Params are serialized with sorted keys, so equal configurations get equal
 * IDs regardless of key insertion order. Consumers recover the label with
 * `split(':')[0]`, so labels should not contain ':'.
 *
 * @param {string} modelLabel
 * @param {object} params
 * @returns {string}
 */
export function makeCandidateId(modelLabel, params) {
  const serialized = stableStringify(params)
  return `${modelLabel}:${serialized}`
}
52
+
53
/**
 * String hash modelled on FNV-1a (offset basis 0x811c9dc5, prime 0x01000193)
 * over UTF-16 code units, masked to 31 bits after every step. The empty string
 * returns the unmasked offset basis.
 *
 * NOTE: the multiply is a floating-point `*`, not a 32-bit integer multiply
 * (Math.imul). Intermediate products can exceed 2^53, so low bits are rounded
 * away before masking. The result is still deterministic. Changing it would
 * change every seed derived through seedFor(), so it is preserved as-is.
 *
 * @param {string} str
 * @returns {number} non-negative integer
 */
function hashString(str) {
  const OFFSET_BASIS = 0x811c9dc5
  const PRIME = 0x01000193
  let hash = OFFSET_BASIS
  for (let pos = 0; pos < str.length; pos += 1) {
    hash = ((hash ^ str.charCodeAt(pos)) * PRIME) & 0x7fffffff
  }
  return hash
}
64
+
65
/**
 * Deterministic seed for one (candidate, fold) evaluation.
 *
 * Linearly combines the base seed, a hash of the candidate ID and the fold
 * index, then applies one xor-shift-and-multiply round. Each stage is masked
 * to 31 bits, so the result is a non-negative integer. As with hashString,
 * products may exceed 2^53 and lose low bits; the arithmetic is kept as-is so
 * derived seeds stay reproducible.
 *
 * @param {string} candidateId
 * @param {number} foldIdx
 * @param {number} baseSeed
 * @returns {number}
 */
export function seedFor(candidateId, foldIdx, baseSeed) {
  const idHash = hashString(candidateId)
  const combined = (baseSeed * 2654435761 + idHash * 40503 + foldIdx * 65537) & 0x7fffffff
  const shifted = (combined >>> 16) ^ combined
  return (shifted * 0x45d9f3b) & 0x7fffffff
}
75
+
76
/**
 * Partial Fisher-Yates shuffle: randomizes only the first k positions of
 * `indices`, in place. Each position i is swapped with a position drawn from
 * [i, n), so with a uniform rng the prefix is a uniform sample without
 * replacement. O(k) swaps.
 *
 * @param {Int32Array|number[]} indices - mutated in place
 * @param {number} k - sample size; clamped to [0, indices.length]
 * @param {() => number} rng - expected to return floats in [0, 1)
 * @returns {Int32Array|number[]} the first min(k, n) entries: a subarray view
 *   for typed arrays, a sliced copy for plain arrays
 */
export function partialShuffle(indices, k, rng) {
  const n = indices.length
  // Clamp at 0: a negative k used to reach subarray/slice as a negative end
  // index, which counts from the end and returned n + k elements.
  const m = Math.max(0, Math.min(k, n))
  for (let i = 0; i < m; i++) {
    // Clamp to n - 1 so an rng that returns exactly 1 cannot index past the end.
    const j = Math.min(n - 1, i + ((rng() * (n - i)) | 0))
    const tmp = indices[i]
    indices[i] = indices[j]
    indices[j] = tmp
  }
  return indices.subarray ? indices.subarray(0, m) : indices.slice(0, m)
}
91
+
92
/**
 * Whether larger scores are better for the given scorer.
 *
 * Currently always true. The built-in scorers ('accuracy', 'r2', 'neg_mse',
 * 'neg_mae') are all oriented so that larger is better, with error metrics
 * negated. Custom scorer functions and unrecognised names are assumed to
 * follow the same convention.
 *
 * @param {string|Function} scoring - scorer name or function
 * @returns {boolean}
 */
export function scorerGreaterIsBetter(scoring) {
  return true
}
@@ -0,0 +1,209 @@
1
+ import { normalizeX, normalizeY, makeLCG, getScorer } from '@wlearn/core'
2
+ import { Leaderboard } from './leaderboard.js'
3
+ import { now, seedFor, partialShuffle } from './common.js'
4
+
5
+ const { ceil, min } = Math
6
+
7
/**
 * Gather rows of a row-major matrix `{ data, cols }` in the order given by
 * `indices`, copying them into a new Float64Array-backed matrix.
 *
 * @param {{data: Float64Array, cols: number}} X - source matrix (data must support subarray)
 * @param {ArrayLike<number>} indices - row indices to copy
 * @returns {{data: Float64Array, rows: number, cols: number}}
 */
function subsetX(X, indices) {
  const { data, cols } = X
  const nOut = indices.length
  const gathered = new Float64Array(nOut * cols)
  for (let r = 0, dst = 0; r < nOut; r++, dst += cols) {
    const src = indices[r] * cols
    gathered.set(data.subarray(src, src + cols), dst)
  }
  return { data: gathered, rows: nOut, cols }
}
20
+
21
/**
 * Gather labels at `indices` into a new container built with `y`'s own
 * constructor (so an Int32Array stays an Int32Array, etc.).
 *
 * @param {ArrayLike<number>} y - source labels
 * @param {ArrayLike<number>} indices - positions to copy
 * @returns {ArrayLike<number>}
 */
function subsetY(y, indices) {
  const Ctor = y.constructor
  const picked = new Ctor(indices.length)
  let k = indices.length
  while (k--) {
    picked[k] = y[indices[k]]
  }
  return picked
}
31
+
32
/**
 * Executor: evaluation engine and canonical leaderboard owner.
 *
 * Evaluates candidates across all CV folds, applies optional budgets
 * (training-row subsampling or a rounds parameter), records each successful
 * evaluation in a single Leaderboard instance, and drives search strategies
 * via runStrategy().
 */
export class Executor {
  #folds
  #scorerFn
  #X
  #y
  #timeLimitMs
  #seed
  #startTime
  #leaderboard
  #onProgress
  #failures = []

  /**
   * @param {object} opts
   * @param {Array<{train: Int32Array, test: Int32Array}>} opts.folds - CV folds
   * @param {string|Function} opts.scoring - scorer name or function (resolved via getScorer)
   * @param {object} opts.X - normalized feature matrix { data, rows, cols }
   * @param {TypedArray} opts.y - normalized labels
   * @param {number} [opts.timeLimitMs=0] - global time limit measured from construction (0 = no limit)
   * @param {number} [opts.seed=42] - base seed for reproducible subsampling
   * @param {Function} [opts.onProgress] - optional progress callback
   */
  constructor({ folds, scoring, X, y, timeLimitMs = 0, seed = 42, onProgress }) {
    this.#folds = folds
    this.#scorerFn = getScorer(scoring)
    this.#X = X
    this.#y = y
    this.#timeLimitMs = timeLimitMs
    this.#seed = seed
    this.#startTime = now()
    this.#leaderboard = new Leaderboard()
    this.#onProgress = onProgress || null
  }

  /** The Leaderboard that receives every successful evaluation. */
  get leaderboard() {
    return this.#leaderboard
  }

  /**
   * Tasks that runStrategy() skipped because evaluating or reporting them
   * threw, in the order they failed. Returns a copy.
   * @returns {Array<{candidateId: string, error: unknown}>}
   */
  get failures() {
    return [...this.#failures]
  }

  /** True once a positive time limit has elapsed since construction. */
  get isTimedOut() {
    if (this.#timeLimitMs <= 0) return false
    return (now() - this.#startTime) > this.#timeLimitMs
  }

  /**
   * Evaluate one candidate across all CV folds.
   *
   * A fresh model is created per fold, fitted on the (optionally subsampled)
   * training split, scored on the full test split, and disposed.
   *
   * @param {object} candidateEval
   * @param {string} candidateEval.candidateId - stable identifier (`<model>:<params>`)
   * @param {object} candidateEval.cls - estimator class with create/fit/predict/dispose
   * @param {object} candidateEval.params - hyperparameters
   * @param {object} [candidateEval.budget] - { type: 'subsample'|'rounds', value }
   * @returns {Promise<object>} CandidateResult
   *   { candidateId, meanScore, foldScores, stdScore, fitTimeMs, nTrainUsed, nTest }
   */
  async evaluateCandidate({ candidateId, cls, params, budget }) {
    const folds = this.#folds
    const scores = new Float64Array(folds.length)
    const t0 = now()
    let totalTrainUsed = 0

    // A rounds budget changes the params; a subsample budget changes the rows.
    const effectiveParams = this.#applyRoundsBudget(cls, params, budget)
    const subsample = budget && budget.type === 'subsample'

    for (let f = 0; f < folds.length; f++) {
      const { test } = folds[f]
      // Only the training split is subsampled; test rows are always complete.
      const train = subsample
        ? this.#subsampleTrain(folds[f].train, budget.value, candidateId, f)
        : folds[f].train

      totalTrainUsed += train.length

      const Xtrain = subsetX(this.#X, train)
      const ytrain = subsetY(this.#y, train)
      const Xtest = subsetX(this.#X, test)
      const ytest = subsetY(this.#y, test)

      const model = await cls.create(effectiveParams)
      try {
        // Awaited so estimators with an asynchronous fit() are fully trained
        // before predict() runs (and fit rejections are caught here);
        // awaiting a synchronous fit() is harmless.
        await model.fit(Xtrain, ytrain)
        const preds = await model.predict(Xtest)
        scores[f] = this.#scorerFn(ytest, preds)
      } finally {
        model.dispose()
      }
    }

    const fitTimeMs = now() - t0

    // The leaderboard stores the candidate's own params, not the budgeted ones.
    const entry = this.#leaderboard.add({
      modelName: candidateId.split(':')[0],
      params,
      scores,
      fitTimeMs,
    })

    return {
      candidateId,
      meanScore: entry.meanScore,
      foldScores: scores,
      stdScore: entry.stdScore,
      fitTimeMs,
      nTrainUsed: Math.round(totalTrainUsed / folds.length),
      nTest: folds[0].test.length,
    }
  }

  /**
   * Apply rounds budget by setting the model's rounds param if:
   * 1. Budget type is 'rounds'
   * 2. Model exposes budgetSpec().roundsParam
   * 3. Candidate params don't already set that param (candidate config wins)
   * Returns the original params object untouched otherwise.
   */
  #applyRoundsBudget(cls, params, budget) {
    if (!budget || budget.type !== 'rounds') return params
    const spec = cls.budgetSpec?.()
    if (!spec || !spec.roundsParam) return params
    if (params[spec.roundsParam] !== undefined) return params
    return { ...params, [spec.roundsParam]: budget.value }
  }

  /**
   * Subsample train indices using partial Fisher-Yates with a seed derived
   * from (candidateId, foldIdx, base seed). Keeps ceil(n * fraction) rows,
   * at least 1; returns `train` itself when that is the whole split.
   */
  #subsampleTrain(train, fraction, candidateId, foldIdx) {
    const k = Math.max(1, ceil(train.length * fraction))
    if (k >= train.length) return train
    // Copy to avoid mutating the original fold indices.
    const copy = new Int32Array(train)
    const seed = seedFor(candidateId, foldIdx, this.#seed)
    const rng = makeLCG(seed)
    return partialShuffle(copy, k, rng)
  }

  /**
   * Run a strategy until it reports done, returns a null task, or the time
   * limit elapses. Successful results are reported back to the strategy.
   * Tasks whose evaluation (or reporting/progress callback) throws are
   * skipped and recorded in `failures`.
   * Returns { leaderboard } only; callers decide "best".
   */
  async runStrategy(strategy) {
    let done = 0
    while (!strategy.isDone()) {
      if (this.isTimedOut) break
      const task = strategy.next()
      if (task === null) break
      try {
        const result = await this.evaluateCandidate(task)
        strategy.report(result)
        done++
        if (this.#onProgress) {
          const best = this.#leaderboard.best()
          this.#onProgress({
            phase: 'search',
            candidatesDone: done,
            bestScore: best ? best.meanScore : null,
            bestModel: best ? best.modelName : null,
            lastCandidate: {
              model: result.candidateId.split(':')[0],
              score: result.meanScore,
              timeMs: result.fitTimeMs,
            },
            elapsedMs: now() - this.#startTime,
          })
        }
      } catch (error) {
        done++
        // Deliberately best-effort: a failing candidate (invalid params,
        // create/fit errors, ...) must not abort the search, but the error
        // is retained so callers can inspect it.
        this.#failures.push({ candidateId: task.candidateId, error })
      }
    }
    return { leaderboard: this.#leaderboard }
  }
}
package/src/halving.js ADDED
@@ -0,0 +1,95 @@
1
+ import { stratifiedKFold, kFold, normalizeX, normalizeY,
2
+ ValidationError } from '@wlearn/core'
3
+ import { Executor } from './executor.js'
4
+ import { HalvingStrategy } from './strategy-halving.js'
5
+ import { detectTask, scorerGreaterIsBetter } from './common.js'
6
+
7
/**
 * Successive halving search: multi-round elimination tournament.
 * Evaluates many candidates on small subsamples, progressively
 * eliminates the worst and increases resource allocation.
 *
 * Candidate scheduling is delegated to HalvingStrategy and evaluation to
 * Executor; this class wires them together and keeps the results.
 */
export class SuccessiveHalvingSearch {
  #models
  #opts
  #leaderboard = null
  #bestResult = null
  #rounds = null

  /**
   * @param {Array<{name: string, cls: object}>} models - model specs (at least one)
   * @param {object} [opts]
   * @param {string|Function|null} [opts.scoring=null] - defaults to 'accuracy' / 'r2' by task
   * @param {number} [opts.cv=5] - number of CV folds
   * @param {number} [opts.seed=42] - base seed for folds, sampling and subsampling
   * @param {string|null} [opts.task=null] - detected from labels when null
   * @param {number} [opts.nIter=20] - forwarded to HalvingStrategy
   * @param {number} [opts.maxTimeMs=0] - global time limit (0 = none)
   * @param {number} [opts.factor=3] - forwarded to HalvingStrategy
   * @param {number} [opts.minResources=0] - NOTE(review): accepted but not
   *   forwarded to HalvingStrategy in fit(); confirm whether it should be
   * @param {Function|null} [opts.onProgress=null] - progress callback
   * @throws {ValidationError} if `models` is missing or empty
   */
  constructor(models, opts = {}) {
    if (!models || models.length === 0) {
      throw new ValidationError('SuccessiveHalvingSearch: at least one model is required')
    }
    this.#models = models
    this.#opts = {
      scoring: null,
      cv: 5,
      seed: 42,
      task: null,
      nIter: 20,
      maxTimeMs: 0,
      factor: 3,
      minResources: 0,
      onProgress: null,
      ...opts,
    }
  }

  /**
   * Run the halving tournament on (X, y).
   *
   * Folds are generated once on the full data (stratified for
   * classification) and shared by every round.
   *
   * @param {object|number[][]} X - feature matrix
   * @param {TypedArray|number[]} y - labels
   * @returns {Promise<{leaderboard: object, bestResult: object, rounds: *}>}
   *   bestResult is leaderboard.best(); rounds is HalvingStrategy#rounds
   */
  async fit(X, y) {
    const Xn = normalizeX(X)
    const yn = normalizeY(y)
    const n = Xn.rows
    const task = this.#opts.task || detectTask(yn)
    const scoring = this.#opts.scoring || (task === 'classification' ? 'accuracy' : 'r2')
    const greaterIsBetter = scorerGreaterIsBetter(scoring)
    const { cv, seed, nIter, maxTimeMs, factor, onProgress } = this.#opts

    // Generate base folds on full data.
    const folds = task === 'classification'
      ? stratifiedKFold(yn, cv, { shuffle: true, seed })
      : kFold(n, cv, { shuffle: true, seed })

    const executor = new Executor({
      folds,
      scoring,
      X: Xn,
      y: yn,
      timeLimitMs: maxTimeMs,
      seed,
      onProgress,
    })

    const strategy = new HalvingStrategy(this.#models, {
      nIter,
      seed,
      factor,
      nSamples: n,
      greaterIsBetter,
      cv,
    })

    const { leaderboard } = await executor.runStrategy(strategy)

    this.#leaderboard = leaderboard
    this.#bestResult = leaderboard.best()
    this.#rounds = strategy.rounds
    return { leaderboard, bestResult: this.#bestResult, rounds: this.#rounds }
  }

  /**
   * Create and fit a fresh instance of the best configuration found by fit().
   *
   * @param {object|number[][]} X - feature matrix
   * @param {TypedArray|number[]} y - labels
   * @returns {Promise<object>} the fitted estimator
   * @throws {ValidationError} if fit() has not produced a best result, or the
   *   best result's model name matches none of the configured specs
   */
  async refitBest(X, y) {
    if (!this.#bestResult) {
      throw new ValidationError('SuccessiveHalvingSearch: must call fit() first')
    }
    const best = this.#bestResult
    const spec = this.#models.find(m => m.name === best.modelName)
    if (!spec) {
      throw new ValidationError(
        `SuccessiveHalvingSearch: no model spec named "${best.modelName}"`)
    }
    // Normalize before creating the instance so invalid input cannot leave
    // a created-but-unfitted estimator behind.
    const Xn = normalizeX(X)
    const yn = normalizeY(y)
    const instance = await spec.cls.create(best.params)
    // Awaited so asynchronous fit() implementations complete before returning.
    await instance.fit(Xn, yn)
    return instance
  }

  get leaderboard() { return this.#leaderboard }
  get bestResult() { return this.#bestResult }
  get rounds() { return this.#rounds }
}