s3db.js 12.4.0 → 13.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,459 @@
1
+ /**
2
+ * Base Model Class
3
+ *
4
+ * Abstract base class for all ML models
5
+ * Provides common functionality for training, prediction, and persistence
6
+ */
7
+
8
import { createRequire } from 'node:module';

import {
  TrainingError,
  PredictionError,
  ModelNotTrainedError,
  DataValidationError,
  InsufficientDataError,
  TensorFlowDependencyError
} from '../ml.errors.js';
16
+
17
export class BaseModel {
  /**
   * Create a model instance. Only subclasses may be instantiated.
   *
   * @param {Object} [config={}] - Model configuration
   * @param {string} [config.name='unnamed'] - Name used in log lines and error context
   * @param {*} [config.resource] - Stored as-is on `this.config.resource`
   * @param {string[]} [config.features=[]] - Record fields used as numeric inputs, in order
   * @param {string} [config.target] - Record field used as the training target
   * @param {Object} [config.modelConfig] - Overrides merged over the training defaults
   *   (epochs 50, batchSize 32, learningRate 0.01, validationSplit 0.2)
   * @param {boolean} [config.verbose=false] - Log training progress via console.log
   * @throws {Error} If `BaseModel` itself is instantiated
   * @throws {TensorFlowDependencyError} If `@tensorflow/tfjs-node` cannot be loaded
   */
  constructor(config = {}) {
    if (this.constructor === BaseModel) {
      throw new Error('BaseModel is an abstract class and cannot be instantiated directly');
    }

    this.config = {
      name: config.name || 'unnamed',
      resource: config.resource,
      features: config.features || [],
      target: config.target,
      modelConfig: {
        epochs: 50,
        batchSize: 32,
        learningRate: 0.01,
        validationSplit: 0.2,
        ...config.modelConfig
      },
      verbose: config.verbose || false
    };

    // Model state
    this.model = null;
    this.isTrained = false;
    // Min/max per feature name and for the target; filled by _calculateNormalizer().
    this.normalizer = {
      features: {},
      target: {}
    };
    this.stats = {
      trainedAt: null,
      samples: 0,
      loss: null,
      accuracy: null,
      predictions: 0,
      errors: 0
    };

    // Fail fast at construction time if TensorFlow.js is unavailable.
    this._validateTensorFlow();
  }

  /**
   * Load TensorFlow.js (Node binding) into `this.tf`.
   * @private
   * @throws {TensorFlowDependencyError} If the package cannot be loaded
   */
  _validateTensorFlow() {
    try {
      // This file is an ES module, where a bare `require` is not defined (the
      // resulting ReferenceError would be misreported as "not installed").
      // createRequire gives a synchronous CommonJS loader, so the constructor
      // can stay synchronous.
      const load = createRequire(import.meta.url);
      this.tf = load('@tensorflow/tfjs-node');
    } catch (error) {
      throw new TensorFlowDependencyError(
        'TensorFlow.js is not installed. Run: pnpm add @tensorflow/tfjs-node',
        { originalError: error.message }
      );
    }
  }

  /**
   * Abstract method: Build the model architecture and assign it to `this.model`.
   * Must be implemented by subclasses.
   * @abstract
   */
  buildModel() {
    throw new Error('buildModel() must be implemented by subclass');
  }

  /**
   * Train the model with the provided records.
   *
   * Recomputes the min-max normalizer from `data`, builds the model on first
   * use, and fits it.
   *
   * @param {Array<Object>} data - Training records containing every configured
   *   feature and the target field
   * @returns {Promise<{loss: number, accuracy: (number|null), epochs: number, samples: number}>}
   * @throws {InsufficientDataError} If `data` is empty or smaller than the minimum
   *   (`modelConfig.batchSize`, or 10 when batchSize is falsy)
   * @throws {DataValidationError} If a record is malformed
   * @throws {TrainingError} For any other failure during preparation or fitting
   */
  async train(data) {
    let xs = null;
    let ys = null;

    try {
      if (!data || data.length === 0) {
        throw new InsufficientDataError('No training data provided', {
          model: this.config.name
        });
      }

      // Validate minimum samples
      const minSamples = this.config.modelConfig.batchSize || 10;
      if (data.length < minSamples) {
        throw new InsufficientDataError(
          `Insufficient training data: ${data.length} samples (minimum: ${minSamples})`,
          { model: this.config.name, samples: data.length, minimum: minSamples }
        );
      }

      // Prepare data (extract features and target, normalize, build tensors)
      ({ xs, ys } = this._prepareData(data));

      // Build model if not already built
      if (!this.model) {
        this.buildModel();
      }

      const history = await this.model.fit(xs, ys, {
        epochs: this.config.modelConfig.epochs,
        batchSize: this.config.modelConfig.batchSize,
        validationSplit: this.config.modelConfig.validationSplit,
        verbose: this.config.verbose ? 1 : 0,
        callbacks: {
          onEpochEnd: (epoch, logs) => {
            if (this.config.verbose && epoch % 10 === 0) {
              console.log(`[MLPlugin] ${this.config.name} - Epoch ${epoch}: loss=${logs.loss.toFixed(4)}`);
            }
          }
        }
      });

      const lossHistory = history.history.loss;

      // Update stats
      this.isTrained = true;
      this.stats.trainedAt = new Date().toISOString();
      this.stats.samples = data.length;
      this.stats.loss = lossHistory[lossHistory.length - 1];

      // Accuracy is only present for models compiled with an accuracy metric;
      // depending on the metric name it is reported as `acc` or `accuracy`.
      const accuracyHistory = history.history.acc || history.history.accuracy;
      if (accuracyHistory) {
        this.stats.accuracy = accuracyHistory[accuracyHistory.length - 1];
      }

      if (this.config.verbose) {
        console.log(`[MLPlugin] ${this.config.name} - Training completed:`, {
          samples: this.stats.samples,
          loss: this.stats.loss,
          accuracy: this.stats.accuracy
        });
      }

      return {
        loss: this.stats.loss,
        accuracy: this.stats.accuracy,
        epochs: this.config.modelConfig.epochs,
        samples: this.stats.samples
      };
    } catch (error) {
      this.stats.errors++;
      if (error instanceof InsufficientDataError || error instanceof DataValidationError) {
        throw error;
      }
      throw new TrainingError(`Training failed: ${error.message}`, {
        model: this.config.name,
        originalError: error.message
      });
    } finally {
      // Release tensor memory on every path, including a failed buildModel/fit.
      if (xs) xs.dispose();
      if (ys) ys.dispose();
    }
  }

  /**
   * Make a prediction with the trained model.
   *
   * @param {Object} input - Record containing every configured feature
   * @returns {Promise<{prediction: number, confidence: number}>} De-normalized
   *   prediction and the score from _calculateConfidence()
   * @throws {ModelNotTrainedError} If the model has not been trained
   * @throws {DataValidationError} If `input` is not an object or is missing/has
   *   non-numeric features
   * @throws {PredictionError} For any other failure
   */
  async predict(input) {
    if (!this.isTrained) {
      throw new ModelNotTrainedError(`Model "${this.config.name}" is not trained yet`, {
        model: this.config.name
      });
    }

    let inputTensor = null;
    let predictionTensor = null;

    try {
      this._validateInput(input);

      // Extract and normalize features
      const features = this._extractFeatures(input);
      const normalizedFeatures = this._normalizeFeatures(features);

      inputTensor = this.tf.tensor2d([normalizedFeatures]);
      predictionTensor = this.model.predict(inputTensor);
      const predictionArray = await predictionTensor.data();
      const rawValue = predictionArray[0];

      const prediction = this._denormalizePrediction(rawValue);

      this.stats.predictions++;

      return {
        prediction,
        confidence: this._calculateConfidence(rawValue)
      };
    } catch (error) {
      this.stats.errors++;
      if (error instanceof ModelNotTrainedError || error instanceof DataValidationError) {
        throw error;
      }
      throw new PredictionError(`Prediction failed: ${error.message}`, {
        model: this.config.name,
        input,
        originalError: error.message
      });
    } finally {
      // Release tensor memory even if predict()/data() failed.
      if (inputTensor) inputTensor.dispose();
      if (predictionTensor) predictionTensor.dispose();
    }
  }

  /**
   * Make predictions for multiple inputs, one at a time, in order.
   * @param {Array<Object>} inputs - Input records
   * @returns {Promise<Array<{prediction: number, confidence: number}>>}
   * @throws {ModelNotTrainedError} If the model has not been trained
   */
  async predictBatch(inputs) {
    if (!this.isTrained) {
      throw new ModelNotTrainedError(`Model "${this.config.name}" is not trained yet`, {
        model: this.config.name
      });
    }

    const predictions = [];
    for (const input of inputs) {
      predictions.push(await this.predict(input));
    }
    return predictions;
  }

  /**
   * Validate records, recompute the normalizer and build training tensors.
   * @private
   * @param {Array<Object>} data - Raw training records
   * @returns {{xs: Object, ys: Object}} Feature and target tensors (caller disposes)
   * @throws {DataValidationError} On malformed records
   */
  _prepareData(data) {
    const features = [];
    const targets = [];

    for (const record of data) {
      // `in` throws a TypeError on null/primitives; report it as a data problem.
      if (record === null || typeof record !== 'object') {
        throw new DataValidationError(
          `Training record must be an object, got ${record === null ? 'null' : typeof record}`,
          { model: this.config.name, record }
        );
      }

      const missingFeatures = this.config.features.filter(f => !(f in record));
      if (missingFeatures.length > 0) {
        throw new DataValidationError(
          `Missing features in training data: ${missingFeatures.join(', ')}`,
          { model: this.config.name, missingFeatures, record }
        );
      }

      if (!(this.config.target in record)) {
        throw new DataValidationError(
          `Missing target "${this.config.target}" in training data`,
          { model: this.config.name, target: this.config.target, record }
        );
      }

      features.push(this._extractFeatures(record));
      targets.push(record[this.config.target]);
    }

    // Calculate normalization parameters from this training set
    this._calculateNormalizer(features, targets);

    const normalizedFeatures = features.map(f => this._normalizeFeatures(f));
    const normalizedTargets = targets.map(t => this._normalizeTarget(t));

    const xs = this.tf.tensor2d(normalizedFeatures);
    try {
      return {
        xs,
        ys: this._prepareTargetTensor(normalizedTargets)
      };
    } catch (error) {
      // Don't leak the feature tensor if building the target tensor fails.
      xs.dispose();
      throw error;
    }
  }

  /**
   * Prepare target tensor (can be overridden by subclasses).
   * Default: one column, one row per target.
   * @protected
   * @param {Array} targets - Normalized target values
   * @returns {Object} Target tensor
   */
  _prepareTargetTensor(targets) {
    return this.tf.tensor2d(targets.map(t => [t]));
  }

  /**
   * Extract feature values from a record, in `config.features` order.
   * @private
   * @param {Object} record - Data record
   * @returns {number[]} Feature values
   * @throws {DataValidationError} If any feature value is not of type number
   */
  _extractFeatures(record) {
    return this.config.features.map(feature => {
      const value = record[feature];
      if (typeof value !== 'number') {
        throw new DataValidationError(
          `Feature "${feature}" must be a number, got ${typeof value}`,
          { model: this.config.name, feature, value, type: typeof value }
        );
      }
      return value;
    });
  }

  /**
   * Calculate min-max normalization parameters for every feature and the target.
   * @private
   * @param {number[][]} features - Feature rows (non-empty)
   * @param {Array} targets - Target values
   */
  _calculateNormalizer(features, targets) {
    // Single-pass min/max. `Math.min(...values)` passes every value as a call
    // argument and throws a RangeError for large datasets.
    const rangeOf = (values) => {
      let min = Infinity;
      let max = -Infinity;
      for (const raw of values) {
        const value = Number(raw); // same coercion Math.min/Math.max apply
        if (Number.isNaN(value)) {
          // Math.min/Math.max yield NaN if any argument is NaN; keep that behavior.
          return { min: NaN, max: NaN };
        }
        if (value < min) min = value;
        if (value > max) max = value;
      }
      return { min, max };
    };

    const numFeatures = features[0].length;
    for (let i = 0; i < numFeatures; i++) {
      const featureName = this.config.features[i];
      this.normalizer.features[featureName] = rangeOf(features.map(row => row[i]));
    }

    this.normalizer.target = rangeOf(targets);
  }

  /**
   * Normalize features using min-max scaling.
   * Constant features (min === max) map to 0.5.
   * @private
   */
  _normalizeFeatures(features) {
    return features.map((value, i) => {
      const featureName = this.config.features[i];
      const { min, max } = this.normalizer.features[featureName];
      if (max === min) return 0.5; // Handle constant features
      return (value - min) / (max - min);
    });
  }

  /**
   * Normalize a target value (min-max; constant targets map to 0.5).
   * @private
   */
  _normalizeTarget(target) {
    const { min, max } = this.normalizer.target;
    if (max === min) return 0.5;
    return (target - min) / (max - min);
  }

  /**
   * Map a normalized prediction back to the target's original scale.
   * @private
   */
  _denormalizePrediction(normalizedValue) {
    const { min, max } = this.normalizer.target;
    return normalizedValue * (max - min) + min;
  }

  /**
   * Calculate confidence score (can be overridden).
   * Default: 0.5 + distance of the raw value from 0.5, capped at 1.0.
   * @protected
   */
  _calculateConfidence(value) {
    const distanceFrom05 = Math.abs(value - 0.5);
    return Math.min(0.5 + distanceFrom05, 1.0);
  }

  /**
   * Validate that `input` is an object containing every configured feature.
   * @private
   * @throws {DataValidationError}
   */
  _validateInput(input) {
    // `in` throws a TypeError on null/primitives; report it as a data problem.
    if (input === null || typeof input !== 'object') {
      throw new DataValidationError(
        `Input must be an object, got ${input === null ? 'null' : typeof input}`,
        { model: this.config.name, input }
      );
    }

    const missingFeatures = this.config.features.filter(f => !(f in input));
    if (missingFeatures.length > 0) {
      throw new DataValidationError(
        `Missing features: ${missingFeatures.join(', ')}`,
        { model: this.config.name, missingFeatures, input }
      );
    }
  }

  /**
   * Export model to a plain object (for persistence).
   * @returns {Promise<Object|null>} Config, normalizer, stats, training flag and
   *   `model.toJSON()` output, or null when no model has been built
   */
  async export() {
    if (!this.model) {
      return null;
    }

    const modelJSON = await this.model.toJSON();

    return {
      config: this.config,
      normalizer: this.normalizer,
      stats: this.stats,
      isTrained: this.isTrained,
      model: modelJSON
    };
  }

  /**
   * Import model state produced by export().
   *
   * Restores config, normalizer, stats and the training flag. If `data.model` is
   * present, the architecture is rebuilt via buildModel(); weights are NOT
   * restored here — subclasses are expected to override this.
   *
   * @param {Object} data - Serialized model data
   */
  async import(data) {
    this.config = data.config;
    this.normalizer = data.normalizer;
    this.stats = data.stats;
    this.isTrained = data.isTrained;

    if (data.model) {
      // Release the previously built model before replacing it.
      if (this.model) {
        this.model.dispose();
        this.model = null;
      }
      // Note: Actual model reconstruction depends on the model type
      // This is a placeholder and should be overridden by subclasses
      this.buildModel();
    }
  }

  /**
   * Dispose the model, free its memory and mark the instance untrained.
   */
  dispose() {
    if (this.model) {
      this.model.dispose();
      this.model = null;
    }
    this.isTrained = false;
  }

  /**
   * Get model statistics.
   * @returns {Object} Stats merged with `isTrained` and the current config
   */
  getStats() {
    return {
      ...this.stats,
      isTrained: this.isTrained,
      config: this.config
    };
  }
}
458
+
459
// Default export mirrors the named `BaseModel` export, so both
// `import BaseModel from ...` and `import { BaseModel } from ...` work.
export default BaseModel;