@hazeljs/ml 0.2.0-beta.69 → 0.2.0-beta.70
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +121 -7
- package/dist/evaluation/metrics.service.d.ts +24 -0
- package/dist/evaluation/metrics.service.d.ts.map +1 -1
- package/dist/evaluation/metrics.service.js +111 -2
- package/dist/evaluation/metrics.service.test.js +174 -0
- package/dist/index.d.ts +1 -1
- package/dist/index.d.ts.map +1 -1
- package/dist/inference/batch.service.d.ts.map +1 -1
- package/dist/inference/batch.service.js +9 -4
- package/dist/inference/batch.service.test.js +25 -0
- package/dist/inference/predictor.service.test.js +29 -0
- package/dist/registry/model.registry.test.js +17 -0
- package/dist/training/pipeline.service.d.ts +8 -0
- package/dist/training/pipeline.service.d.ts.map +1 -1
- package/dist/training/pipeline.service.js +21 -5
- package/dist/training/pipeline.service.test.js +14 -0
- package/dist/training/trainer.service.test.js +10 -0
- package/package.json +2 -2
package/README.md
CHANGED
|
@@ -107,6 +107,89 @@ export class MLController {
|
|
|
107
107
|
}
|
|
108
108
|
```
|
|
109
109
|
|
|
110
|
+
## ML Decorators
|
|
111
|
+
|
|
112
|
+
The package uses three decorators to declare ML models and their behaviour. The registry and services discover them via reflection—no manual wiring needed.
|
|
113
|
+
|
|
114
|
+
### `@Model` (class)
|
|
115
|
+
|
|
116
|
+
Marks a class as an ML model and attaches **registry metadata**. Required so the model can be registered and looked up by name/version.
|
|
117
|
+
|
|
118
|
+
| Property | Type | Required | Description |
|
|
119
|
+
|---------------|----------|----------|-------------|
|
|
120
|
+
| `name` | string | Yes | Unique model id (e.g. `'sentiment-classifier'`). |
|
|
121
|
+
| `version` | string | Yes | Semver (e.g. `'1.0.0'`). |
|
|
122
|
+
| `framework` | string | Yes | `'tensorflow'` \| `'onnx'` \| `'custom'`. |
|
|
123
|
+
| `description` | string | No | Human-readable description. |
|
|
124
|
+
| `tags` | string[] | No | Tags for filtering (default: `[]`). |
|
|
125
|
+
|
|
126
|
+
**Example:** One model per class; use `@Injectable()` so the app can construct it.
|
|
127
|
+
|
|
128
|
+
```typescript
|
|
129
|
+
@Model({
|
|
130
|
+
name: 'spam-classifier',
|
|
131
|
+
version: '1.0.0',
|
|
132
|
+
framework: 'custom',
|
|
133
|
+
description: 'Binary spam/ham classifier',
|
|
134
|
+
tags: ['nlp', 'moderation'],
|
|
135
|
+
})
|
|
136
|
+
@Injectable()
|
|
137
|
+
export class SpamClassifier {
|
|
138
|
+
// ...
|
|
139
|
+
}
|
|
140
|
+
```
|
|
141
|
+
|
|
142
|
+
---
|
|
143
|
+
|
|
144
|
+
### `@Train` (method)
|
|
145
|
+
|
|
146
|
+
Marks the **single method** that trains this model. `TrainerService.train(modelName, data)` will call it. Optional config is for documentation or pipeline wiring.
|
|
147
|
+
|
|
148
|
+
| Option | Type | Default | Description |
|
|
149
|
+
|-------------|--------|-----------|-------------|
|
|
150
|
+
| `pipeline` | string | `'default'` | Name of a registered `PipelineService` pipeline to run before training. |
|
|
151
|
+
| `batchSize` | number | `32` | Hint for batching (your logic can ignore it). |
|
|
152
|
+
| `epochs` | number | `10` | Hint for epochs (your logic can ignore it). |
|
|
153
|
+
|
|
154
|
+
**Example:** Exactly one `@Train()` method per model; it receives training data and can return metrics.
|
|
155
|
+
|
|
156
|
+
```typescript
|
|
157
|
+
@Train({ pipeline: 'sentiment-preprocessing', epochs: 5 })
|
|
158
|
+
async train(data: { samples: Array<{ text: string; label: string }> }): Promise<TrainingResult> {
|
|
159
|
+
// Your training logic
|
|
160
|
+
return { accuracy: 0.95, loss: 0.05 };
|
|
161
|
+
}
|
|
162
|
+
```
|
|
163
|
+
|
|
164
|
+
---
|
|
165
|
+
|
|
166
|
+
### `@Predict` (method)
|
|
167
|
+
|
|
168
|
+
Marks the **single method** that runs inference. `PredictorService.predict(modelName, input)` will call it.
|
|
169
|
+
|
|
170
|
+
| Option | Type | Default | Description |
|
|
171
|
+
|-----------|---------|------------|-------------|
|
|
172
|
+
| `batch` | boolean | `false` | Hint that the method supports batch input (semantic only). |
|
|
173
|
+
| `endpoint`| string | `'/predict'` | Hint for route naming (semantic only). |
|
|
174
|
+
|
|
175
|
+
**Example:** Exactly one `@Predict()` method per model; it receives one input and returns a prediction object.
|
|
176
|
+
|
|
177
|
+
```typescript
|
|
178
|
+
@Predict({ batch: true, endpoint: '/predict' })
|
|
179
|
+
async predict(input: { text: string }): Promise<{ label: string; confidence: number }> {
|
|
180
|
+
// Your inference logic
|
|
181
|
+
return { label: 'ham', confidence: 0.92 };
|
|
182
|
+
}
|
|
183
|
+
```
|
|
184
|
+
|
|
185
|
+
---
|
|
186
|
+
|
|
187
|
+
### Rules
|
|
188
|
+
|
|
189
|
+
- **One model class** = one `@Model`, one `@Train` method, one `@Predict` method.
|
|
190
|
+
- **Order:** Put `@Model` on the class, then `@Train` and `@Predict` on the methods. Use `@Injectable()` from `@hazeljs/core` so the app can instantiate the model.
|
|
191
|
+
- **Discovery:** When you pass model classes to `MLModule.forRoot({ models: [...] })`, the bootstrap finds the `@Train` and `@Predict` methods and registers them with the registry.
|
|
192
|
+
|
|
110
193
|
## Model registration
|
|
111
194
|
|
|
112
195
|
Models are registered when passed to `MLModule.forRoot({ models: [...] })`. The bootstrap discovers `@Train` and `@Predict` methods via reflection.
|
|
@@ -134,11 +217,16 @@ import { PipelineService } from '@hazeljs/ml';
|
|
|
134
217
|
|
|
135
218
|
const pipeline = new PipelineService();
|
|
136
219
|
const steps = [
|
|
137
|
-
{ name: 'normalize',
|
|
138
|
-
{ name: 'filter',
|
|
220
|
+
{ name: 'normalize', transform: (d: unknown) => ({ ...(d as object), text: (d as { text: string }).text?.toLowerCase() }) },
|
|
221
|
+
{ name: 'filter', transform: (d: unknown) => (d as { text: string }).text?.length > 0 ? d : null },
|
|
139
222
|
];
|
|
223
|
+
// Inline steps (no registration required)
|
|
140
224
|
const processed = await pipeline.run(data, steps);
|
|
141
225
|
await model.train(processed);
|
|
226
|
+
|
|
227
|
+
// Or register a named pipeline
|
|
228
|
+
pipeline.registerPipeline('default', steps);
|
|
229
|
+
const processed2 = await pipeline.run('default', data);
|
|
142
230
|
```
|
|
143
231
|
|
|
144
232
|
## Batch predictions
|
|
@@ -149,7 +237,9 @@ import { BatchService } from '@hazeljs/ml';
|
|
|
149
237
|
const batchService = new BatchService(predictorService);
|
|
150
238
|
const results = await batchService.predictBatch('sentiment-classifier', items, {
|
|
151
239
|
batchSize: 32,
|
|
240
|
+
concurrency: 4,
|
|
152
241
|
});
|
|
242
|
+
// Results are returned in the same order as inputs
|
|
153
243
|
```
|
|
154
244
|
|
|
155
245
|
## Metrics and evaluation
|
|
@@ -157,9 +247,32 @@ const results = await batchService.predictBatch('sentiment-classifier', items, {
|
|
|
157
247
|
```typescript
|
|
158
248
|
import { MetricsService } from '@hazeljs/ml';
|
|
159
249
|
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
250
|
+
// Evaluate model on test data (inject MetricsService via MLModule - it receives PredictorService)
|
|
251
|
+
@Injectable()
|
|
252
|
+
class EvaluationService {
|
|
253
|
+
constructor(private metricsService: MetricsService) {}
|
|
254
|
+
|
|
255
|
+
async runEvaluation() {
|
|
256
|
+
const testData = [
|
|
257
|
+
{ text: 'great product', label: 'positive' },
|
|
258
|
+
{ text: 'terrible', label: 'negative' },
|
|
259
|
+
];
|
|
260
|
+
const evaluation = await this.metricsService.evaluate('sentiment-classifier', testData, {
|
|
261
|
+
metrics: ['accuracy', 'f1', 'precision', 'recall'],
|
|
262
|
+
labelKey: 'label', // key in test sample for ground truth
|
|
263
|
+
predictionKey: 'sentiment', // key in prediction result (if omitted or missing, auto-detects label, sentiment, class, prediction)
|
|
264
|
+
});
|
|
265
|
+
// evaluation.metrics: { accuracy, precision, recall, f1Score }
|
|
266
|
+
// Result is automatically recorded via recordEvaluation()
|
|
267
|
+
}
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
// Manual recording (metricsService is a MetricsService instance, e.g. injected as above)
|
|
271
|
+
metricsService.recordEvaluation({
|
|
272
|
+
modelName: 'my-model',
|
|
273
|
+
version: '1.0.0',
|
|
274
|
+
metrics: { accuracy: 0.95, loss: 0.05 },
|
|
275
|
+
evaluatedAt: new Date(),
|
|
163
276
|
});
|
|
164
277
|
```
|
|
165
278
|
|
|
@@ -174,9 +287,10 @@ const evaluation = await metricsService.evaluate(modelName, testData, {
|
|
|
174
287
|
| `BatchService` | Batch prediction with configurable batch size |
|
|
175
288
|
| `MetricsService` | Model evaluation and metrics tracking |
|
|
176
289
|
|
|
177
|
-
##
|
|
290
|
+
## Examples
|
|
178
291
|
|
|
179
|
-
|
|
292
|
+
- **[hazeljs-ml-starter](../../../hazeljs-ml-starter)** – Full app: sentiment, spam, intent classifiers, REST API, training pipeline, and metrics.
|
|
293
|
+
- **[example/src/ml](../../example/src/ml)** – Minimal runnable example of the three decorators (`npm run ml:decorators` from the example repo).
|
|
180
294
|
|
|
181
295
|
## Links
|
|
182
296
|
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import { ModelRegistry } from '../registry/model.registry';
|
|
2
|
+
import { PredictorService } from '../inference/predictor.service';
|
|
1
3
|
export interface ModelMetrics {
|
|
2
4
|
accuracy?: number;
|
|
3
5
|
precision?: number;
|
|
@@ -12,11 +14,21 @@ export interface EvaluationResult {
|
|
|
12
14
|
metrics: ModelMetrics;
|
|
13
15
|
evaluatedAt: Date;
|
|
14
16
|
}
|
|
17
|
+
export type EvaluateMetric = 'accuracy' | 'f1' | 'precision' | 'recall';
|
|
18
|
+
export interface EvaluateOptions {
|
|
19
|
+
metrics?: EvaluateMetric[];
|
|
20
|
+
labelKey?: string;
|
|
21
|
+
predictionKey?: string;
|
|
22
|
+
version?: string;
|
|
23
|
+
}
|
|
15
24
|
/**
|
|
16
25
|
* Metrics Service - Model evaluation and metrics
|
|
17
26
|
* Tracks model performance for A/B testing and monitoring
|
|
18
27
|
*/
|
|
19
28
|
export declare class MetricsService {
|
|
29
|
+
private readonly modelRegistry?;
|
|
30
|
+
private readonly predictorService?;
|
|
31
|
+
constructor(modelRegistry?: ModelRegistry | undefined, predictorService?: PredictorService | undefined);
|
|
20
32
|
private metrics;
|
|
21
33
|
recordEvaluation(result: EvaluationResult): void;
|
|
22
34
|
getMetrics(modelName: string, version?: string): EvaluationResult | undefined;
|
|
@@ -26,5 +38,17 @@ export declare class MetricsService {
|
|
|
26
38
|
b: EvaluationResult | undefined;
|
|
27
39
|
winner?: string;
|
|
28
40
|
};
|
|
41
|
+
/**
|
|
42
|
+
* Evaluate model on test data by running predictions and computing metrics.
|
|
43
|
+
* Requires PredictorService to be injected.
|
|
44
|
+
*
|
|
45
|
+
* @param modelName - Registered model name
|
|
46
|
+
* @param testData - Array of samples. Each sample must contain the model input and a label key.
|
|
47
|
+
* @param options - labelKey (default: 'label'), predictionKey (if omitted or absent from the prediction, tries 'label'|'sentiment'|'class'|'prediction', then the first value), metrics (default: all four), version
|
|
48
|
+
*/
|
|
49
|
+
evaluate(modelName: string, testData: Record<string, unknown>[], options?: EvaluateOptions): Promise<EvaluationResult>;
|
|
50
|
+
private extractPredictedLabel;
|
|
51
|
+
private computeAccuracy;
|
|
52
|
+
private computePrecisionRecallF1;
|
|
29
53
|
}
|
|
30
54
|
//# sourceMappingURL=metrics.service.d.ts.map
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"metrics.service.d.ts","sourceRoot":"","sources":["../../src/evaluation/metrics.service.ts"],"names":[],"mappings":"
|
|
1
|
+
{"version":3,"file":"metrics.service.d.ts","sourceRoot":"","sources":["../../src/evaluation/metrics.service.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,aAAa,EAAE,MAAM,4BAA4B,CAAC;AAC3D,OAAO,EAAE,gBAAgB,EAAE,MAAM,gCAAgC,CAAC;AAIlE,MAAM,WAAW,YAAY;IAC3B,QAAQ,CAAC,EAAE,MAAM,CAAC;IAClB,SAAS,CAAC,EAAE,MAAM,CAAC;IACnB,MAAM,CAAC,EAAE,MAAM,CAAC;IAChB,OAAO,CAAC,EAAE,MAAM,CAAC;IACjB,IAAI,CAAC,EAAE,MAAM,CAAC;IACd,CAAC,GAAG,EAAE,MAAM,GAAG,MAAM,GAAG,SAAS,CAAC;CACnC;AAED,MAAM,WAAW,gBAAgB;IAC/B,SAAS,EAAE,MAAM,CAAC;IAClB,OAAO,EAAE,MAAM,CAAC;IAChB,OAAO,EAAE,YAAY,CAAC;IACtB,WAAW,EAAE,IAAI,CAAC;CACnB;AAED,MAAM,MAAM,cAAc,GAAG,UAAU,GAAG,IAAI,GAAG,WAAW,GAAG,QAAQ,CAAC;AAExE,MAAM,WAAW,eAAe;IAC9B,OAAO,CAAC,EAAE,cAAc,EAAE,CAAC;IAC3B,QAAQ,CAAC,EAAE,MAAM,CAAC;IAClB,aAAa,CAAC,EAAE,MAAM,CAAC;IACvB,OAAO,CAAC,EAAE,MAAM,CAAC;CAClB;AAED;;;GAGG;AACH,qBACa,cAAc;IAEvB,OAAO,CAAC,QAAQ,CAAC,aAAa,CAAC;IAC/B,OAAO,CAAC,QAAQ,CAAC,gBAAgB,CAAC;gBADjB,aAAa,CAAC,EAAE,aAAa,YAAA,EAC7B,gBAAgB,CAAC,EAAE,gBAAgB,YAAA;IAGtD,OAAO,CAAC,OAAO,CAA8C;IAE7D,gBAAgB,CAAC,MAAM,EAAE,gBAAgB,GAAG,IAAI;IAQhD,UAAU,CAAC,SAAS,EAAE,MAAM,EAAE,OAAO,CAAC,EAAE,MAAM,GAAG,gBAAgB,GAAG,SAAS;IAQ7E,UAAU,CAAC,SAAS,EAAE,MAAM,GAAG,gBAAgB,EAAE;IAIjD,eAAe,CACb,SAAS,EAAE,MAAM,EACjB,QAAQ,EAAE,MAAM,EAChB,QAAQ,EAAE,MAAM,GACf;QACD,CAAC,EAAE,gBAAgB,GAAG,SAAS,CAAC;QAChC,CAAC,EAAE,gBAAgB,GAAG,SAAS,CAAC;QAChC,MAAM,CAAC,EAAE,MAAM,CAAC;KACjB;IAaD;;;;;;;OAOG;IACG,QAAQ,CACZ,SAAS,EAAE,MAAM,EACjB,QAAQ,EAAE,MAAM,CAAC,MAAM,EAAE,OAAO,CAAC,EAAE,EACnC,OAAO,GAAE,eAAoB,GAC5B,OAAO,CAAC,gBAAgB,CAAC;IA2D5B,OAAO,CAAC,qBAAqB;IAW7B,OAAO,CAAC,eAAe;IAQvB,OAAO,CAAC,wBAAwB;CAmCjC"}
|
|
@@ -5,19 +5,26 @@ var __decorate = (this && this.__decorate) || function (decorators, target, key,
|
|
|
5
5
|
else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;
|
|
6
6
|
return c > 3 && r && Object.defineProperty(target, key, r), r;
|
|
7
7
|
};
|
|
8
|
+
var __metadata = (this && this.__metadata) || function (k, v) {
|
|
9
|
+
if (typeof Reflect === "object" && typeof Reflect.metadata === "function") return Reflect.metadata(k, v);
|
|
10
|
+
};
|
|
8
11
|
var __importDefault = (this && this.__importDefault) || function (mod) {
|
|
9
12
|
return (mod && mod.__esModule) ? mod : { "default": mod };
|
|
10
13
|
};
|
|
11
14
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
12
15
|
exports.MetricsService = void 0;
|
|
13
16
|
const core_1 = require("@hazeljs/core");
|
|
17
|
+
const model_registry_1 = require("../registry/model.registry");
|
|
18
|
+
const predictor_service_1 = require("../inference/predictor.service");
|
|
14
19
|
const core_2 = __importDefault(require("@hazeljs/core"));
|
|
15
20
|
/**
|
|
16
21
|
* Metrics Service - Model evaluation and metrics
|
|
17
22
|
* Tracks model performance for A/B testing and monitoring
|
|
18
23
|
*/
|
|
19
24
|
let MetricsService = class MetricsService {
|
|
20
|
-
constructor() {
|
|
25
|
+
constructor(modelRegistry, predictorService) {
|
|
26
|
+
this.modelRegistry = modelRegistry;
|
|
27
|
+
this.predictorService = predictorService;
|
|
21
28
|
this.metrics = new Map();
|
|
22
29
|
}
|
|
23
30
|
recordEvaluation(result) {
|
|
@@ -47,8 +54,110 @@ let MetricsService = class MetricsService {
|
|
|
47
54
|
}
|
|
48
55
|
return { a, b, winner };
|
|
49
56
|
}
|
|
57
|
+
/**
|
|
58
|
+
* Evaluate model on test data by running predictions and computing metrics.
|
|
59
|
+
* Requires PredictorService to be injected.
|
|
60
|
+
*
|
|
61
|
+
* @param modelName - Registered model name
|
|
62
|
+
* @param testData - Array of samples. Each sample must contain the model input and a label key.
|
|
63
|
+
* @param options - labelKey (default: 'label'), predictionKey (if omitted or absent from the prediction, tries 'label'|'sentiment'|'class'|'prediction', then the first value), metrics (default: all four), version
|
|
64
|
+
*/
|
|
65
|
+
async evaluate(modelName, testData, options = {}) {
|
|
66
|
+
if (!this.predictorService) {
|
|
67
|
+
throw new Error('MetricsService.evaluate() requires PredictorService. Ensure MLModule is configured with PredictorService.');
|
|
68
|
+
}
|
|
69
|
+
const { metrics: requestedMetrics = ['accuracy', 'f1', 'precision', 'recall'], labelKey = 'label', predictionKey, version, } = options;
|
|
70
|
+
if (testData.length === 0) {
|
|
71
|
+
throw new Error('testData cannot be empty');
|
|
72
|
+
}
|
|
73
|
+
const predictions = [];
|
|
74
|
+
for (const sample of testData) {
|
|
75
|
+
const { [labelKey]: _label, ...input } = sample;
|
|
76
|
+
const pred = await this.predictorService.predict(modelName, input, version);
|
|
77
|
+
predictions.push(pred);
|
|
78
|
+
}
|
|
79
|
+
const labels = testData.map((s) => String(s[labelKey] ?? ''));
|
|
80
|
+
const predictedLabels = predictions.map((p) => this.extractPredictedLabel(p, predictionKey));
|
|
81
|
+
const computed = {};
|
|
82
|
+
if (requestedMetrics.includes('accuracy')) {
|
|
83
|
+
computed.accuracy = this.computeAccuracy(labels, predictedLabels);
|
|
84
|
+
}
|
|
85
|
+
if (requestedMetrics.includes('precision') ||
|
|
86
|
+
requestedMetrics.includes('recall') ||
|
|
87
|
+
requestedMetrics.includes('f1')) {
|
|
88
|
+
const { precision, recall, f1Score } = this.computePrecisionRecallF1(labels, predictedLabels);
|
|
89
|
+
if (requestedMetrics.includes('precision'))
|
|
90
|
+
computed.precision = precision;
|
|
91
|
+
if (requestedMetrics.includes('recall'))
|
|
92
|
+
computed.recall = recall;
|
|
93
|
+
if (requestedMetrics.includes('f1'))
|
|
94
|
+
computed.f1Score = f1Score;
|
|
95
|
+
}
|
|
96
|
+
const model = this.modelRegistry?.get(modelName, version);
|
|
97
|
+
const modelVersion = model?.metadata.version ?? version ?? 'unknown';
|
|
98
|
+
const result = {
|
|
99
|
+
modelName,
|
|
100
|
+
version: modelVersion,
|
|
101
|
+
metrics: computed,
|
|
102
|
+
evaluatedAt: new Date(),
|
|
103
|
+
};
|
|
104
|
+
this.recordEvaluation(result);
|
|
105
|
+
core_2.default.debug(`Evaluated ${modelName}@${modelVersion}`, computed);
|
|
106
|
+
return result;
|
|
107
|
+
}
|
|
108
|
+
extractPredictedLabel(prediction, key) {
|
|
109
|
+
if (key && prediction[key] !== undefined) {
|
|
110
|
+
return String(prediction[key]);
|
|
111
|
+
}
|
|
112
|
+
for (const k of ['label', 'sentiment', 'class', 'prediction']) {
|
|
113
|
+
if (prediction[k] !== undefined)
|
|
114
|
+
return String(prediction[k]);
|
|
115
|
+
}
|
|
116
|
+
const first = Object.values(prediction)[0];
|
|
117
|
+
return first !== undefined ? String(first) : '';
|
|
118
|
+
}
|
|
119
|
+
computeAccuracy(labels, predicted) {
|
|
120
|
+
let correct = 0;
|
|
121
|
+
for (let i = 0; i < labels.length; i++) {
|
|
122
|
+
if (labels[i] === predicted[i])
|
|
123
|
+
correct++;
|
|
124
|
+
}
|
|
125
|
+
return labels.length > 0 ? correct / labels.length : 0;
|
|
126
|
+
}
|
|
127
|
+
computePrecisionRecallF1(labels, predicted) {
|
|
128
|
+
const classes = [...new Set([...labels, ...predicted])].filter(Boolean);
|
|
129
|
+
if (classes.length === 0)
|
|
130
|
+
return { precision: 0, recall: 0, f1Score: 0 };
|
|
131
|
+
let totalPrecision = 0;
|
|
132
|
+
let totalRecall = 0;
|
|
133
|
+
let count = 0;
|
|
134
|
+
for (const cls of classes) {
|
|
135
|
+
let tp = 0, fp = 0, fn = 0;
|
|
136
|
+
for (let i = 0; i < labels.length; i++) {
|
|
137
|
+
const isPred = predicted[i] === cls;
|
|
138
|
+
const isActual = labels[i] === cls;
|
|
139
|
+
if (isPred && isActual)
|
|
140
|
+
tp++;
|
|
141
|
+
if (isPred && !isActual)
|
|
142
|
+
fp++;
|
|
143
|
+
if (!isPred && isActual)
|
|
144
|
+
fn++;
|
|
145
|
+
}
|
|
146
|
+
const precision = tp + fp > 0 ? tp / (tp + fp) : 0;
|
|
147
|
+
const recall = tp + fn > 0 ? tp / (tp + fn) : 0;
|
|
148
|
+
totalPrecision += precision;
|
|
149
|
+
totalRecall += recall;
|
|
150
|
+
count++;
|
|
151
|
+
}
|
|
152
|
+
const precision = count > 0 ? totalPrecision / count : 0;
|
|
153
|
+
const recall = count > 0 ? totalRecall / count : 0;
|
|
154
|
+
const f1Score = precision + recall > 0 ? (2 * precision * recall) / (precision + recall) : 0;
|
|
155
|
+
return { precision, recall, f1Score };
|
|
156
|
+
}
|
|
50
157
|
};
|
|
51
158
|
exports.MetricsService = MetricsService;
|
|
52
159
|
exports.MetricsService = MetricsService = __decorate([
|
|
53
|
-
(0, core_1.Service)()
|
|
160
|
+
(0, core_1.Service)(),
|
|
161
|
+
__metadata("design:paramtypes", [model_registry_1.ModelRegistry,
|
|
162
|
+
predictor_service_1.PredictorService])
|
|
54
163
|
], MetricsService);
|
|
@@ -1,6 +1,18 @@
|
|
|
1
1
|
"use strict";
|
|
2
|
+
var __decorate = (this && this.__decorate) || function (decorators, target, key, desc) {
|
|
3
|
+
var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d;
|
|
4
|
+
if (typeof Reflect === "object" && typeof Reflect.decorate === "function") r = Reflect.decorate(decorators, target, key, desc);
|
|
5
|
+
else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;
|
|
6
|
+
return c > 3 && r && Object.defineProperty(target, key, r), r;
|
|
7
|
+
};
|
|
8
|
+
var __metadata = (this && this.__metadata) || function (k, v) {
|
|
9
|
+
if (typeof Reflect === "object" && typeof Reflect.metadata === "function") return Reflect.metadata(k, v);
|
|
10
|
+
};
|
|
2
11
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
12
|
+
const model_registry_1 = require("../registry/model.registry");
|
|
13
|
+
const predictor_service_1 = require("../inference/predictor.service");
|
|
3
14
|
const metrics_service_1 = require("./metrics.service");
|
|
15
|
+
const decorators_1 = require("../decorators");
|
|
4
16
|
describe('MetricsService', () => {
|
|
5
17
|
let service;
|
|
6
18
|
beforeEach(() => {
|
|
@@ -76,4 +88,166 @@ describe('MetricsService', () => {
|
|
|
76
88
|
const { winner } = service.compareVersions('m', 'a', 'b');
|
|
77
89
|
expect(winner).toBeUndefined();
|
|
78
90
|
});
|
|
91
|
+
describe('evaluate', () => {
|
|
92
|
+
let EvalModel = class EvalModel {
|
|
93
|
+
train() { }
|
|
94
|
+
async predict(input) {
|
|
95
|
+
const sentiment = input.text.includes('good') ? 'positive' : 'negative';
|
|
96
|
+
return { sentiment };
|
|
97
|
+
}
|
|
98
|
+
};
|
|
99
|
+
__decorate([
|
|
100
|
+
(0, decorators_1.Train)(),
|
|
101
|
+
__metadata("design:type", Function),
|
|
102
|
+
__metadata("design:paramtypes", []),
|
|
103
|
+
__metadata("design:returntype", void 0)
|
|
104
|
+
], EvalModel.prototype, "train", null);
|
|
105
|
+
__decorate([
|
|
106
|
+
(0, decorators_1.Predict)(),
|
|
107
|
+
__metadata("design:type", Function),
|
|
108
|
+
__metadata("design:paramtypes", [Object]),
|
|
109
|
+
__metadata("design:returntype", Promise)
|
|
110
|
+
], EvalModel.prototype, "predict", null);
|
|
111
|
+
EvalModel = __decorate([
|
|
112
|
+
(0, decorators_1.Model)({ name: 'eval-model', version: '1.0.0', framework: 'custom' })
|
|
113
|
+
], EvalModel);
|
|
114
|
+
it('throws when PredictorService not injected', async () => {
|
|
115
|
+
const svc = new metrics_service_1.MetricsService();
|
|
116
|
+
await expect(svc.evaluate('any', [{ text: 'hello', label: 'neutral' }])).rejects.toThrow('MetricsService.evaluate() requires PredictorService');
|
|
117
|
+
});
|
|
118
|
+
it('throws when testData is empty', async () => {
|
|
119
|
+
const registry = new model_registry_1.ModelRegistry();
|
|
120
|
+
registry.register({
|
|
121
|
+
metadata: { name: 'eval-model', version: '1.0.0', framework: 'custom' },
|
|
122
|
+
instance: new EvalModel(),
|
|
123
|
+
predictMethod: 'predict',
|
|
124
|
+
});
|
|
125
|
+
const predictor = new predictor_service_1.PredictorService(registry);
|
|
126
|
+
const svc = new metrics_service_1.MetricsService(registry, predictor);
|
|
127
|
+
await expect(svc.evaluate('eval-model', [])).rejects.toThrow('testData cannot be empty');
|
|
128
|
+
});
|
|
129
|
+
it('computes accuracy, precision, recall, f1 from test data', async () => {
|
|
130
|
+
const registry = new model_registry_1.ModelRegistry();
|
|
131
|
+
registry.register({
|
|
132
|
+
metadata: { name: 'eval-model', version: '1.0.0', framework: 'custom' },
|
|
133
|
+
instance: new EvalModel(),
|
|
134
|
+
predictMethod: 'predict',
|
|
135
|
+
});
|
|
136
|
+
const predictor = new predictor_service_1.PredictorService(registry);
|
|
137
|
+
const svc = new metrics_service_1.MetricsService(registry, predictor);
|
|
138
|
+
const testData = [
|
|
139
|
+
{ text: 'good day', label: 'positive' },
|
|
140
|
+
{ text: 'bad day', label: 'negative' },
|
|
141
|
+
{ text: 'good weather', label: 'positive' },
|
|
142
|
+
{ text: 'bad weather', label: 'negative' },
|
|
143
|
+
];
|
|
144
|
+
const result = await svc.evaluate('eval-model', testData);
|
|
145
|
+
expect(result.modelName).toBe('eval-model');
|
|
146
|
+
expect(result.version).toBe('1.0.0');
|
|
147
|
+
expect(result.metrics.accuracy).toBe(1);
|
|
148
|
+
expect(result.metrics.precision).toBe(1);
|
|
149
|
+
expect(result.metrics.recall).toBe(1);
|
|
150
|
+
expect(result.metrics.f1Score).toBe(1);
|
|
151
|
+
});
|
|
152
|
+
it('supports custom labelKey and predictionKey', async () => {
|
|
153
|
+
let CustomModel = class CustomModel {
|
|
154
|
+
train() { }
|
|
155
|
+
async predict(input) {
|
|
156
|
+
return { outcome: input.x === 'a' ? 'yes' : 'no' };
|
|
157
|
+
}
|
|
158
|
+
};
|
|
159
|
+
__decorate([
|
|
160
|
+
(0, decorators_1.Train)(),
|
|
161
|
+
__metadata("design:type", Function),
|
|
162
|
+
__metadata("design:paramtypes", []),
|
|
163
|
+
__metadata("design:returntype", void 0)
|
|
164
|
+
], CustomModel.prototype, "train", null);
|
|
165
|
+
__decorate([
|
|
166
|
+
(0, decorators_1.Predict)(),
|
|
167
|
+
__metadata("design:type", Function),
|
|
168
|
+
__metadata("design:paramtypes", [Object]),
|
|
169
|
+
__metadata("design:returntype", Promise)
|
|
170
|
+
], CustomModel.prototype, "predict", null);
|
|
171
|
+
CustomModel = __decorate([
|
|
172
|
+
(0, decorators_1.Model)({ name: 'custom-model', version: '1.0.0', framework: 'custom' })
|
|
173
|
+
], CustomModel);
|
|
174
|
+
const registry = new model_registry_1.ModelRegistry();
|
|
175
|
+
registry.register({
|
|
176
|
+
metadata: { name: 'custom-model', version: '1.0.0', framework: 'custom' },
|
|
177
|
+
instance: new CustomModel(),
|
|
178
|
+
predictMethod: 'predict',
|
|
179
|
+
});
|
|
180
|
+
const predictor = new predictor_service_1.PredictorService(registry);
|
|
181
|
+
const svc = new metrics_service_1.MetricsService(registry, predictor);
|
|
182
|
+
const result = await svc.evaluate('custom-model', [
|
|
183
|
+
{ x: 'a', outcome: 'yes' },
|
|
184
|
+
{ x: 'b', outcome: 'no' },
|
|
185
|
+
], { labelKey: 'outcome', predictionKey: 'outcome' });
|
|
186
|
+
expect(result.metrics.accuracy).toBe(1);
|
|
187
|
+
});
|
|
188
|
+
it('extractPredictedLabel uses first value when no known key', async () => {
|
|
189
|
+
let FallbackModel = class FallbackModel {
|
|
190
|
+
train() { }
|
|
191
|
+
async predict() {
|
|
192
|
+
return { customKey: 'positive' };
|
|
193
|
+
}
|
|
194
|
+
};
|
|
195
|
+
__decorate([
|
|
196
|
+
(0, decorators_1.Train)(),
|
|
197
|
+
__metadata("design:type", Function),
|
|
198
|
+
__metadata("design:paramtypes", []),
|
|
199
|
+
__metadata("design:returntype", void 0)
|
|
200
|
+
], FallbackModel.prototype, "train", null);
|
|
201
|
+
__decorate([
|
|
202
|
+
(0, decorators_1.Predict)(),
|
|
203
|
+
__metadata("design:type", Function),
|
|
204
|
+
__metadata("design:paramtypes", []),
|
|
205
|
+
__metadata("design:returntype", Promise)
|
|
206
|
+
], FallbackModel.prototype, "predict", null);
|
|
207
|
+
FallbackModel = __decorate([
|
|
208
|
+
(0, decorators_1.Model)({ name: 'fallback-model', version: '1.0.0', framework: 'custom' })
|
|
209
|
+
], FallbackModel);
|
|
210
|
+
const registry = new model_registry_1.ModelRegistry();
|
|
211
|
+
registry.register({
|
|
212
|
+
metadata: { name: 'fallback-model', version: '1.0.0', framework: 'custom' },
|
|
213
|
+
instance: new FallbackModel(),
|
|
214
|
+
predictMethod: 'predict',
|
|
215
|
+
});
|
|
216
|
+
const predictor = new predictor_service_1.PredictorService(registry);
|
|
217
|
+
const svc = new metrics_service_1.MetricsService(registry, predictor);
|
|
218
|
+
const result = await svc.evaluate('fallback-model', [{ text: 'x', label: 'positive' }], {
|
|
219
|
+
metrics: ['accuracy'],
|
|
220
|
+
});
|
|
221
|
+
expect(result.metrics.accuracy).toBe(1);
|
|
222
|
+
});
|
|
223
|
+
it('evaluate with only accuracy metric skips precision/recall/f1', async () => {
|
|
224
|
+
const registry = new model_registry_1.ModelRegistry();
|
|
225
|
+
registry.register({
|
|
226
|
+
metadata: { name: 'eval-model', version: '1.0.0', framework: 'custom' },
|
|
227
|
+
instance: new EvalModel(),
|
|
228
|
+
predictMethod: 'predict',
|
|
229
|
+
});
|
|
230
|
+
const predictor = new predictor_service_1.PredictorService(registry);
|
|
231
|
+
const svc = new metrics_service_1.MetricsService(registry, predictor);
|
|
232
|
+
const result = await svc.evaluate('eval-model', [{ text: 'good', label: 'positive' }], {
|
|
233
|
+
metrics: ['accuracy'],
|
|
234
|
+
});
|
|
235
|
+
expect(result.metrics.accuracy).toBeDefined();
|
|
236
|
+
expect(result.metrics.precision).toBeUndefined();
|
|
237
|
+
expect(result.metrics.recall).toBeUndefined();
|
|
238
|
+
expect(result.metrics.f1Score).toBeUndefined();
|
|
239
|
+
});
|
|
240
|
+
it('evaluate when modelRegistry is undefined uses version unknown', async () => {
|
|
241
|
+
const registry = new model_registry_1.ModelRegistry();
|
|
242
|
+
registry.register({
|
|
243
|
+
metadata: { name: 'eval-model', version: '1.0.0', framework: 'custom' },
|
|
244
|
+
instance: new EvalModel(),
|
|
245
|
+
predictMethod: 'predict',
|
|
246
|
+
});
|
|
247
|
+
const predictor = new predictor_service_1.PredictorService(registry);
|
|
248
|
+
const svc = new metrics_service_1.MetricsService(undefined, predictor);
|
|
249
|
+
const result = await svc.evaluate('eval-model', [{ text: 'good', label: 'positive' }]);
|
|
250
|
+
expect(result.version).toBeDefined();
|
|
251
|
+
});
|
|
252
|
+
});
|
|
79
253
|
});
|
package/dist/index.d.ts
CHANGED
|
@@ -9,7 +9,7 @@ export { TrainerService } from './training/trainer.service';
|
|
|
9
9
|
export { PipelineService, type PipelineStep } from './training/pipeline.service';
|
|
10
10
|
export { PredictorService } from './inference/predictor.service';
|
|
11
11
|
export { BatchService, type BatchPredictionOptions } from './inference/batch.service';
|
|
12
|
-
export { MetricsService, type ModelMetrics, type EvaluationResult, } from './evaluation/metrics.service';
|
|
12
|
+
export { MetricsService, type ModelMetrics, type EvaluationResult, type EvaluateOptions, type EvaluateMetric, } from './evaluation/metrics.service';
|
|
13
13
|
export { registerMLModel } from './ml-model.base';
|
|
14
14
|
export { Injectable } from '@hazeljs/core';
|
|
15
15
|
export type { MLFramework, ModelMetadata, TrainingData, TrainingResult, PredictionResult, ModelVersion, } from './ml.types';
|
package/dist/index.d.ts.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,OAAO,kBAAkB,CAAC;AAG1B,OAAO,EAAE,QAAQ,EAAE,SAAS,EAAE,KAAK,eAAe,EAAE,MAAM,aAAa,CAAC;AAGxE,OAAO,EACL,KAAK,EACL,KAAK,EACL,OAAO,EACP,gBAAgB,EAChB,gBAAgB,EAChB,gBAAgB,EAChB,gBAAgB,EAChB,kBAAkB,EAClB,kBAAkB,EAClB,KAAK,YAAY,EACjB,KAAK,cAAc,GACpB,MAAM,cAAc,CAAC;AAGtB,OAAO,EAAE,aAAa,EAAE,KAAK,eAAe,EAAE,MAAM,2BAA2B,CAAC;AAGhF,OAAO,EAAE,cAAc,EAAE,MAAM,4BAA4B,CAAC;AAC5D,OAAO,EAAE,eAAe,EAAE,KAAK,YAAY,EAAE,MAAM,6BAA6B,CAAC;AACjF,OAAO,EAAE,gBAAgB,EAAE,MAAM,+BAA+B,CAAC;AACjE,OAAO,EAAE,YAAY,EAAE,KAAK,sBAAsB,EAAE,MAAM,2BAA2B,CAAC;AACtF,OAAO,EACL,cAAc,EACd,KAAK,YAAY,EACjB,KAAK,gBAAgB,
|
|
1
|
+
{"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../src/index.ts"],"names":[],"mappings":"AAAA;;GAEG;AAEH,OAAO,kBAAkB,CAAC;AAG1B,OAAO,EAAE,QAAQ,EAAE,SAAS,EAAE,KAAK,eAAe,EAAE,MAAM,aAAa,CAAC;AAGxE,OAAO,EACL,KAAK,EACL,KAAK,EACL,OAAO,EACP,gBAAgB,EAChB,gBAAgB,EAChB,gBAAgB,EAChB,gBAAgB,EAChB,kBAAkB,EAClB,kBAAkB,EAClB,KAAK,YAAY,EACjB,KAAK,cAAc,GACpB,MAAM,cAAc,CAAC;AAGtB,OAAO,EAAE,aAAa,EAAE,KAAK,eAAe,EAAE,MAAM,2BAA2B,CAAC;AAGhF,OAAO,EAAE,cAAc,EAAE,MAAM,4BAA4B,CAAC;AAC5D,OAAO,EAAE,eAAe,EAAE,KAAK,YAAY,EAAE,MAAM,6BAA6B,CAAC;AACjF,OAAO,EAAE,gBAAgB,EAAE,MAAM,+BAA+B,CAAC;AACjE,OAAO,EAAE,YAAY,EAAE,KAAK,sBAAsB,EAAE,MAAM,2BAA2B,CAAC;AACtF,OAAO,EACL,cAAc,EACd,KAAK,YAAY,EACjB,KAAK,gBAAgB,EACrB,KAAK,eAAe,EACpB,KAAK,cAAc,GACpB,MAAM,8BAA8B,CAAC;AAGtC,OAAO,EAAE,eAAe,EAAE,MAAM,iBAAiB,CAAC;AAGlD,OAAO,EAAE,UAAU,EAAE,MAAM,eAAe,CAAC;AAG3C,YAAY,EACV,WAAW,EACX,aAAa,EACb,YAAY,EACZ,cAAc,EACd,gBAAgB,EAChB,YAAY,GACb,MAAM,YAAY,CAAC"}
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"batch.service.d.ts","sourceRoot":"","sources":["../../src/inference/batch.service.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,gBAAgB,EAAE,MAAM,qBAAqB,CAAC;AACvD,OAAO,EAAE,gBAAgB,EAAE,MAAM,aAAa,CAAC;AAG/C,MAAM,WAAW,sBAAsB;IACrC,SAAS,CAAC,EAAE,MAAM,CAAC;IACnB,WAAW,CAAC,EAAE,MAAM,CAAC;CACtB;AAED;;;GAGG;AACH,qBACa,YAAY;IACX,OAAO,CAAC,QAAQ,CAAC,gBAAgB;gBAAhB,gBAAgB,EAAE,gBAAgB;IAEzD,YAAY,CAAC,CAAC,GAAG,OAAO,EAC5B,SAAS,EAAE,MAAM,EACjB,MAAM,EAAE,OAAO,EAAE,EACjB,OAAO,GAAE,sBAA2B,EACpC,OAAO,CAAC,EAAE,MAAM,GACf,OAAO,CAAC,gBAAgB,CAAC,CAAC,CAAC,EAAE,CAAC;
|
|
1
|
+
{"version":3,"file":"batch.service.d.ts","sourceRoot":"","sources":["../../src/inference/batch.service.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,gBAAgB,EAAE,MAAM,qBAAqB,CAAC;AACvD,OAAO,EAAE,gBAAgB,EAAE,MAAM,aAAa,CAAC;AAG/C,MAAM,WAAW,sBAAsB;IACrC,SAAS,CAAC,EAAE,MAAM,CAAC;IACnB,WAAW,CAAC,EAAE,MAAM,CAAC;CACtB;AAED;;;GAGG;AACH,qBACa,YAAY;IACX,OAAO,CAAC,QAAQ,CAAC,gBAAgB;gBAAhB,gBAAgB,EAAE,gBAAgB;IAEzD,YAAY,CAAC,CAAC,GAAG,OAAO,EAC5B,SAAS,EAAE,MAAM,EACjB,MAAM,EAAE,OAAO,EAAE,EACjB,OAAO,GAAE,sBAA2B,EACpC,OAAO,CAAC,EAAE,MAAM,GACf,OAAO,CAAC,gBAAgB,CAAC,CAAC,CAAC,EAAE,CAAC;CA+BlC"}
|
|
@@ -27,15 +27,20 @@ let BatchService = class BatchService {
|
|
|
27
27
|
async predictBatch(modelName, inputs, options = {}, version) {
|
|
28
28
|
const { batchSize = 32, concurrency = 4 } = options;
|
|
29
29
|
core_2.default.debug(`Batch prediction: ${inputs.length} inputs, batchSize=${batchSize}`);
|
|
30
|
-
const results =
|
|
30
|
+
const results = new Array(inputs.length);
|
|
31
31
|
const batches = [];
|
|
32
32
|
for (let i = 0; i < inputs.length; i += batchSize) {
|
|
33
|
-
|
|
33
|
+
const batch = inputs.slice(i, i + batchSize).map((input, j) => ({ input, idx: i + j }));
|
|
34
|
+
batches.push(batch);
|
|
34
35
|
}
|
|
35
36
|
for (let i = 0; i < batches.length; i += concurrency) {
|
|
36
37
|
const batchGroup = batches.slice(i, i + concurrency);
|
|
37
|
-
const batchResults = await Promise.all(batchGroup.flatMap((batch) => batch.map((input) => this.predictorService
|
|
38
|
-
|
|
38
|
+
const batchResults = await Promise.all(batchGroup.flatMap((batch) => batch.map(({ input, idx }) => this.predictorService
|
|
39
|
+
.predict(modelName, input, version)
|
|
40
|
+
.then((r) => ({ idx, r })))));
|
|
41
|
+
for (const { idx, r } of batchResults) {
|
|
42
|
+
results[idx] = r;
|
|
43
|
+
}
|
|
39
44
|
}
|
|
40
45
|
return results;
|
|
41
46
|
}
|
|
@@ -55,7 +55,32 @@ describe('BatchService', () => {
|
|
|
55
55
|
});
|
|
56
56
|
expect(results).toHaveLength(5);
|
|
57
57
|
});
|
|
58
|
+
it('preserves result order matching input order', async () => {
|
|
59
|
+
const inputs = [10, 20, 30, 40, 50];
|
|
60
|
+
const results = await batchService.predictBatch('batch-model', inputs, {
|
|
61
|
+
batchSize: 2,
|
|
62
|
+
concurrency: 2,
|
|
63
|
+
});
|
|
64
|
+
expect(results).toEqual([
|
|
65
|
+
{ value: 20 },
|
|
66
|
+
{ value: 40 },
|
|
67
|
+
{ value: 60 },
|
|
68
|
+
{ value: 80 },
|
|
69
|
+
{ value: 100 },
|
|
70
|
+
]);
|
|
71
|
+
});
|
|
58
72
|
it('throws when model not found', async () => {
|
|
59
73
|
await expect(batchService.predictBatch('unknown', [1])).rejects.toThrow('Model not found: unknown');
|
|
60
74
|
});
|
|
75
|
+
it('uses default batchSize and concurrency when options empty', async () => {
|
|
76
|
+
const results = await batchService.predictBatch('batch-model', [1, 2], {});
|
|
77
|
+
expect(results).toHaveLength(2);
|
|
78
|
+
expect(results).toEqual([{ value: 2 }, { value: 4 }]);
|
|
79
|
+
});
|
|
80
|
+
it('uses custom concurrency with default batchSize', async () => {
|
|
81
|
+
const results = await batchService.predictBatch('batch-model', [1, 2, 3], {
|
|
82
|
+
concurrency: 1,
|
|
83
|
+
});
|
|
84
|
+
expect(results).toEqual([{ value: 2 }, { value: 4 }, { value: 6 }]);
|
|
85
|
+
});
|
|
61
86
|
});
|
|
@@ -83,4 +83,33 @@ describe('PredictorService', () => {
|
|
|
83
83
|
const instance = new TestModel();
|
|
84
84
|
expect(predictor.discoverPredictMethod(instance)).toBe('predict');
|
|
85
85
|
});
|
|
86
|
+
it('throws when predict method is not a function', async () => {
|
|
87
|
+
const fakeInstance = {
|
|
88
|
+
predict: 'not-a-function',
|
|
89
|
+
};
|
|
90
|
+
registry.register({
|
|
91
|
+
metadata: { name: 'broken-model', version: '1.0.0', framework: 'custom' },
|
|
92
|
+
instance: fakeInstance,
|
|
93
|
+
trainMethod: undefined,
|
|
94
|
+
predictMethod: 'predict',
|
|
95
|
+
});
|
|
96
|
+
await expect(predictor.predict('broken-model', {})).rejects.toThrow('Prediction method predict not found on model');
|
|
97
|
+
});
|
|
98
|
+
it('discoverPredictMethod returns undefined when no @Predict method', () => {
|
|
99
|
+
let NoPredictClass = class NoPredictClass {
|
|
100
|
+
train() { }
|
|
101
|
+
someMethod() { }
|
|
102
|
+
};
|
|
103
|
+
__decorate([
|
|
104
|
+
(0, decorators_1.Train)(),
|
|
105
|
+
__metadata("design:type", Function),
|
|
106
|
+
__metadata("design:paramtypes", []),
|
|
107
|
+
__metadata("design:returntype", void 0)
|
|
108
|
+
], NoPredictClass.prototype, "train", null);
|
|
109
|
+
NoPredictClass = __decorate([
|
|
110
|
+
(0, decorators_1.Model)({ name: 'no-predict-method', version: '1.0.0', framework: 'custom' })
|
|
111
|
+
], NoPredictClass);
|
|
112
|
+
const instance = new NoPredictClass();
|
|
113
|
+
expect(predictor.discoverPredictMethod(instance)).toBeUndefined();
|
|
114
|
+
});
|
|
86
115
|
});
|
|
@@ -49,6 +49,9 @@ describe('ModelRegistry', () => {
|
|
|
49
49
|
expect(list.map((m) => m.name)).toContain('a');
|
|
50
50
|
expect(list.map((m) => m.name)).toContain('b');
|
|
51
51
|
});
|
|
52
|
+
it('getVersions returns empty array for unknown model', () => {
|
|
53
|
+
expect(registry.getVersions('unknown')).toEqual([]);
|
|
54
|
+
});
|
|
52
55
|
it('getVersions returns version history', () => {
|
|
53
56
|
registry.register({
|
|
54
57
|
metadata: { name: 'model', version: '1.0.0', framework: 'tensorflow' },
|
|
@@ -73,4 +76,18 @@ describe('ModelRegistry', () => {
|
|
|
73
76
|
expect(registry.get('model', '1.0.0')).toBeUndefined();
|
|
74
77
|
expect(registry.unregister('model', '1.0.0')).toBe(false);
|
|
75
78
|
});
|
|
79
|
+
it('unregister updates versions list when deleting', () => {
|
|
80
|
+
registry.register({
|
|
81
|
+
metadata: { name: 'm', version: '1.0.0', framework: 'tensorflow' },
|
|
82
|
+
instance: {},
|
|
83
|
+
});
|
|
84
|
+
registry.register({
|
|
85
|
+
metadata: { name: 'm', version: '2.0.0', framework: 'tensorflow' },
|
|
86
|
+
instance: {},
|
|
87
|
+
});
|
|
88
|
+
registry.unregister('m', '1.0.0');
|
|
89
|
+
const versions = registry.getVersions('m');
|
|
90
|
+
expect(versions).toHaveLength(1);
|
|
91
|
+
expect(versions[0].version).toBe('2.0.0');
|
|
92
|
+
});
|
|
76
93
|
});
|
|
@@ -10,7 +10,15 @@ export interface PipelineStep {
|
|
|
10
10
|
export declare class PipelineService {
|
|
11
11
|
private pipelines;
|
|
12
12
|
registerPipeline(name: string, steps: PipelineStep[]): void;
|
|
13
|
+
/**
|
|
14
|
+
* Run a registered pipeline by name.
|
|
15
|
+
*/
|
|
13
16
|
run(name: string, data: TrainingData): Promise<TrainingData>;
|
|
17
|
+
/**
|
|
18
|
+
* Run an ad-hoc pipeline with inline steps (no registration required).
|
|
19
|
+
*/
|
|
20
|
+
run(data: TrainingData, steps: PipelineStep[]): Promise<TrainingData>;
|
|
21
|
+
private executeSteps;
|
|
14
22
|
getPipeline(name: string): PipelineStep[] | undefined;
|
|
15
23
|
listPipelines(): string[];
|
|
16
24
|
}
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"pipeline.service.d.ts","sourceRoot":"","sources":["../../src/training/pipeline.service.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,YAAY,EAAE,MAAM,aAAa,CAAC;AAG3C,MAAM,WAAW,YAAY;IAC3B,IAAI,EAAE,MAAM,CAAC;IACb,SAAS,EAAE,CAAC,IAAI,EAAE,OAAO,KAAK,OAAO,CAAC,OAAO,CAAC,GAAG,OAAO,CAAC;CAC1D;AAED;;;GAGG;AACH,qBACa,eAAe;IAC1B,OAAO,CAAC,SAAS,CAA0C;IAE3D,gBAAgB,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,YAAY,EAAE,GAAG,IAAI;
|
|
1
|
+
{"version":3,"file":"pipeline.service.d.ts","sourceRoot":"","sources":["../../src/training/pipeline.service.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,YAAY,EAAE,MAAM,aAAa,CAAC;AAG3C,MAAM,WAAW,YAAY;IAC3B,IAAI,EAAE,MAAM,CAAC;IACb,SAAS,EAAE,CAAC,IAAI,EAAE,OAAO,KAAK,OAAO,CAAC,OAAO,CAAC,GAAG,OAAO,CAAC;CAC1D;AAED;;;GAGG;AACH,qBACa,eAAe;IAC1B,OAAO,CAAC,SAAS,CAA0C;IAE3D,gBAAgB,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,YAAY,EAAE,GAAG,IAAI;IAK3D;;OAEG;IACG,GAAG,CAAC,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,YAAY,GAAG,OAAO,CAAC,YAAY,CAAC;IAElE;;OAEG;IACG,GAAG,CAAC,IAAI,EAAE,YAAY,EAAE,KAAK,EAAE,YAAY,EAAE,GAAG,OAAO,CAAC,YAAY,CAAC;YA2B7D,YAAY;IAa1B,WAAW,CAAC,IAAI,EAAE,MAAM,GAAG,YAAY,EAAE,GAAG,SAAS;IAIrD,aAAa,IAAI,MAAM,EAAE;CAG1B"}
|
|
@@ -24,14 +24,30 @@ let PipelineService = class PipelineService {
|
|
|
24
24
|
this.pipelines.set(name, steps);
|
|
25
25
|
core_2.default.debug(`Registered pipeline: ${name} with ${steps.length} steps`);
|
|
26
26
|
}
|
|
27
|
-
async run(
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
27
|
+
async run(nameOrData, dataOrSteps) {
|
|
28
|
+
let data;
|
|
29
|
+
let steps;
|
|
30
|
+
let label;
|
|
31
|
+
if (typeof nameOrData === 'string') {
|
|
32
|
+
label = nameOrData;
|
|
33
|
+
data = dataOrSteps;
|
|
34
|
+
const registered = this.pipelines.get(nameOrData);
|
|
35
|
+
if (!registered) {
|
|
36
|
+
throw new Error(`Pipeline not found: ${nameOrData}`);
|
|
37
|
+
}
|
|
38
|
+
steps = registered;
|
|
31
39
|
}
|
|
40
|
+
else {
|
|
41
|
+
label = 'inline';
|
|
42
|
+
data = nameOrData;
|
|
43
|
+
steps = dataOrSteps;
|
|
44
|
+
}
|
|
45
|
+
return this.executeSteps(data, steps, label);
|
|
46
|
+
}
|
|
47
|
+
async executeSteps(data, steps, label) {
|
|
32
48
|
let result = data;
|
|
33
49
|
for (const step of steps) {
|
|
34
|
-
core_2.default.debug(`Pipeline ${
|
|
50
|
+
core_2.default.debug(`Pipeline ${label}: executing step ${step.name}`);
|
|
35
51
|
result = await Promise.resolve(step.transform(result));
|
|
36
52
|
}
|
|
37
53
|
return result;
|
|
@@ -35,4 +35,18 @@ describe('PipelineService', () => {
|
|
|
35
35
|
service.registerPipeline('b', []);
|
|
36
36
|
expect(service.listPipelines()).toEqual(['a', 'b']);
|
|
37
37
|
});
|
|
38
|
+
it('run with inline steps (no registration)', async () => {
|
|
39
|
+
const steps = [
|
|
40
|
+
{
|
|
41
|
+
name: 'lower',
|
|
42
|
+
transform: (d) => ({
|
|
43
|
+
...d,
|
|
44
|
+
x: d.x?.toLowerCase() ?? '',
|
|
45
|
+
}),
|
|
46
|
+
},
|
|
47
|
+
{ name: 'add', transform: (d) => ({ ...d, y: 1 }) },
|
|
48
|
+
];
|
|
49
|
+
const result = await service.run({ x: 'HELLO' }, steps);
|
|
50
|
+
expect(result).toEqual({ x: 'hello', y: 1 });
|
|
51
|
+
});
|
|
38
52
|
});
|
|
@@ -86,4 +86,14 @@ describe('TrainerService', () => {
|
|
|
86
86
|
}
|
|
87
87
|
expect(trainer.discoverTrainMethod(new PlainClass())).toBeUndefined();
|
|
88
88
|
});
|
|
89
|
+
it('throws when train method is not a function', async () => {
|
|
90
|
+
const fakeInstance = { train: 'not-a-function' };
|
|
91
|
+
registry.register({
|
|
92
|
+
metadata: { name: 'broken-train', version: '1.0.0', framework: 'custom' },
|
|
93
|
+
instance: fakeInstance,
|
|
94
|
+
trainMethod: 'train',
|
|
95
|
+
predictMethod: undefined,
|
|
96
|
+
});
|
|
97
|
+
await expect(trainer.train('broken-train', {})).rejects.toThrow('Training method train not found on model');
|
|
98
|
+
});
|
|
89
99
|
});
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@hazeljs/ml",
|
|
3
|
-
"version": "0.2.0-beta.
|
|
3
|
+
"version": "0.2.0-beta.70",
|
|
4
4
|
"description": "Machine Learning & Model Management for HazelJS framework",
|
|
5
5
|
"main": "dist/index.js",
|
|
6
6
|
"types": "dist/index.d.ts",
|
|
@@ -48,5 +48,5 @@
|
|
|
48
48
|
"url": "https://github.com/hazeljs/hazel-js/issues"
|
|
49
49
|
},
|
|
50
50
|
"homepage": "https://hazeljs.com",
|
|
51
|
-
"gitHead": "
|
|
51
|
+
"gitHead": "ff7e43369af054d18e064bb342d1e6c5f058ca87"
|
|
52
52
|
}
|