@hazeljs/ml 0.2.0-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +192 -0
- package/README.md +300 -0
- package/dist/decorators/index.d.ts +4 -0
- package/dist/decorators/index.d.ts.map +1 -0
- package/dist/decorators/index.js +15 -0
- package/dist/decorators/model.decorator.d.ts +26 -0
- package/dist/decorators/model.decorator.d.ts.map +1 -0
- package/dist/decorators/model.decorator.js +48 -0
- package/dist/decorators/model.decorator.test.d.ts +2 -0
- package/dist/decorators/model.decorator.test.d.ts.map +1 -0
- package/dist/decorators/model.decorator.test.js +128 -0
- package/dist/decorators/predict.decorator.d.ts +20 -0
- package/dist/decorators/predict.decorator.d.ts.map +1 -0
- package/dist/decorators/predict.decorator.js +40 -0
- package/dist/decorators/train.decorator.d.ts +21 -0
- package/dist/decorators/train.decorator.d.ts.map +1 -0
- package/dist/decorators/train.decorator.js +41 -0
- package/dist/evaluation/metrics.service.d.ts +54 -0
- package/dist/evaluation/metrics.service.d.ts.map +1 -0
- package/dist/evaluation/metrics.service.js +163 -0
- package/dist/evaluation/metrics.service.test.d.ts +2 -0
- package/dist/evaluation/metrics.service.test.d.ts.map +1 -0
- package/dist/evaluation/metrics.service.test.js +253 -0
- package/dist/index.d.ts +16 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +42 -0
- package/dist/inference/batch.service.d.ts +16 -0
- package/dist/inference/batch.service.d.ts.map +1 -0
- package/dist/inference/batch.service.js +52 -0
- package/dist/inference/batch.service.test.d.ts +2 -0
- package/dist/inference/batch.service.test.d.ts.map +1 -0
- package/dist/inference/batch.service.test.js +86 -0
- package/dist/inference/predictor.service.d.ts +13 -0
- package/dist/inference/predictor.service.d.ts.map +1 -0
- package/dist/inference/predictor.service.js +65 -0
- package/dist/inference/predictor.service.test.d.ts +2 -0
- package/dist/inference/predictor.service.test.d.ts.map +1 -0
- package/dist/inference/predictor.service.test.js +115 -0
- package/dist/ml-model.base.d.ts +20 -0
- package/dist/ml-model.base.d.ts.map +1 -0
- package/dist/ml-model.base.js +33 -0
- package/dist/ml-model.base.test.d.ts +2 -0
- package/dist/ml-model.base.test.d.ts.map +1 -0
- package/dist/ml-model.base.test.js +57 -0
- package/dist/ml.module.d.ts +27 -0
- package/dist/ml.module.d.ts.map +1 -0
- package/dist/ml.module.js +126 -0
- package/dist/ml.module.test.d.ts +2 -0
- package/dist/ml.module.test.d.ts.map +1 -0
- package/dist/ml.module.test.js +60 -0
- package/dist/ml.types.d.ts +30 -0
- package/dist/ml.types.d.ts.map +1 -0
- package/dist/ml.types.js +5 -0
- package/dist/registry/model.registry.d.ts +21 -0
- package/dist/registry/model.registry.d.ts.map +1 -0
- package/dist/registry/model.registry.js +64 -0
- package/dist/registry/model.registry.test.d.ts +2 -0
- package/dist/registry/model.registry.test.d.ts.map +1 -0
- package/dist/registry/model.registry.test.js +93 -0
- package/dist/training/pipeline.service.d.ts +25 -0
- package/dist/training/pipeline.service.d.ts.map +1 -0
- package/dist/training/pipeline.service.js +65 -0
- package/dist/training/pipeline.service.test.d.ts +2 -0
- package/dist/training/pipeline.service.test.d.ts.map +1 -0
- package/dist/training/pipeline.service.test.js +52 -0
- package/dist/training/trainer.service.d.ts +13 -0
- package/dist/training/trainer.service.d.ts.map +1 -0
- package/dist/training/trainer.service.js +69 -0
- package/dist/training/trainer.service.test.d.ts +2 -0
- package/dist/training/trainer.service.test.d.ts.map +1 -0
- package/dist/training/trainer.service.test.js +99 -0
- package/package.json +52 -0
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
const model_registry_1 = require("./model.registry");
|
|
4
|
+
/**
 * Unit tests for ModelRegistry: registration, lookup by name/version,
 * listing, version history and unregistration.
 */
describe('ModelRegistry', () => {
    let registry;
    // Builds a registered-model record with a fresh empty instance; `extra`
    // carries optional fields such as trainMethod/predictMethod.
    const entry = (name, version, framework = 'tensorflow', extra = {}) => ({
        metadata: { name, version, framework },
        instance: {},
        ...extra,
    });
    beforeEach(() => {
        registry = new model_registry_1.ModelRegistry();
    });
    it('registers and retrieves model by name and version', () => {
        const model = entry('sentiment', '1.0.0', 'tensorflow', {
            trainMethod: 'train',
            predictMethod: 'predict',
        });
        registry.register(model);
        expect(registry.get('sentiment', '1.0.0')).toBe(model);
    });
    it('returns latest version when version not specified', () => {
        const first = entry('model', '1.0.0');
        const second = entry('model', '2.0.0');
        for (const m of [first, second]) {
            registry.register(m);
        }
        expect(registry.get('model')).toBe(second);
        expect(registry.get('model', '1.0.0')).toBe(first);
    });
    it('returns undefined for unknown model', () => {
        expect(registry.get('unknown')).toBeUndefined();
        expect(registry.get('unknown', '1.0.0')).toBeUndefined();
    });
    it('lists all registered models', () => {
        registry.register(entry('a', '1'));
        registry.register(entry('b', '1', 'onnx'));
        const names = registry.list().map(({ name }) => name);
        expect(names).toHaveLength(2);
        expect(names).toEqual(expect.arrayContaining(['a', 'b']));
    });
    it('getVersions returns empty array for unknown model', () => {
        expect(registry.getVersions('unknown')).toEqual([]);
    });
    it('getVersions returns version history', () => {
        for (const v of ['1.0.0', '2.0.0']) {
            registry.register(entry('model', v));
        }
        const history = registry.getVersions('model');
        expect(history).toHaveLength(2);
        expect(history.map(({ version }) => version)).toEqual(['1.0.0', '2.0.0']);
    });
    it('unregister removes model', () => {
        registry.register(entry('model', '1.0.0'));
        expect(registry.get('model', '1.0.0')).toBeDefined();
        expect(registry.unregister('model', '1.0.0')).toBe(true);
        expect(registry.get('model', '1.0.0')).toBeUndefined();
        // A second removal of the same version reports that nothing was deleted.
        expect(registry.unregister('model', '1.0.0')).toBe(false);
    });
    it('unregister updates versions list when deleting', () => {
        registry.register(entry('m', '1.0.0'));
        registry.register(entry('m', '2.0.0'));
        registry.unregister('m', '1.0.0');
        const remaining = registry.getVersions('m');
        expect(remaining).toHaveLength(1);
        expect(remaining[0].version).toBe('2.0.0');
    });
});
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import { TrainingData } from '../ml.types';
|
|
2
|
+
/**
 * A single named stage of a training-data pipeline.
 */
export interface PipelineStep {
    /** Step label; it appears in the debug log line emitted while the step runs. */
    name: string;
    /** Receives the previous step's output (or the pipeline input) and returns the next value, directly or as a Promise. */
    transform: (data: unknown) => Promise<unknown> | unknown;
}
|
|
6
|
+
/**
 * Pipeline Service - ETL pipelines for training data preparation
 * Handles data transformation before model training
 */
export declare class PipelineService {
    private pipelines;
    /** Stores `steps` under `name`; registering an existing name replaces its steps. */
    registerPipeline(name: string, steps: PipelineStep[]): void;
    /**
     * Run a registered pipeline by name.
     * Rejects with `Pipeline not found: <name>` when nothing is registered under `name`.
     */
    run(name: string, data: TrainingData): Promise<TrainingData>;
    /**
     * Run an ad-hoc pipeline with inline steps (no registration required).
     */
    run(data: TrainingData, steps: PipelineStep[]): Promise<TrainingData>;
    private executeSteps;
    /** Returns the steps registered under `name`, or undefined if there are none. */
    getPipeline(name: string): PipelineStep[] | undefined;
    /** Names of all registered pipelines, in the order they were first registered. */
    listPipelines(): string[];
}
|
|
25
|
+
//# sourceMappingURL=pipeline.service.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"pipeline.service.d.ts","sourceRoot":"","sources":["../../src/training/pipeline.service.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,YAAY,EAAE,MAAM,aAAa,CAAC;AAG3C,MAAM,WAAW,YAAY;IAC3B,IAAI,EAAE,MAAM,CAAC;IACb,SAAS,EAAE,CAAC,IAAI,EAAE,OAAO,KAAK,OAAO,CAAC,OAAO,CAAC,GAAG,OAAO,CAAC;CAC1D;AAED;;;GAGG;AACH,qBACa,eAAe;IAC1B,OAAO,CAAC,SAAS,CAA0C;IAE3D,gBAAgB,CAAC,IAAI,EAAE,MAAM,EAAE,KAAK,EAAE,YAAY,EAAE,GAAG,IAAI;IAK3D;;OAEG;IACG,GAAG,CAAC,IAAI,EAAE,MAAM,EAAE,IAAI,EAAE,YAAY,GAAG,OAAO,CAAC,YAAY,CAAC;IAElE;;OAEG;IACG,GAAG,CAAC,IAAI,EAAE,YAAY,EAAE,KAAK,EAAE,YAAY,EAAE,GAAG,OAAO,CAAC,YAAY,CAAC;YA2B7D,YAAY;IAa1B,WAAW,CAAC,IAAI,EAAE,MAAM,GAAG,YAAY,EAAE,GAAG,SAAS;IAIrD,aAAa,IAAI,MAAM,EAAE;CAG1B"}
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
/**
 * TypeScript decorator-application helper (compiler emit).
 * Reuses an existing `this.__decorate` if one is already defined.
 *
 * Call forms seen in this package's compiled output:
 *   __decorate(decorators, Class)                  -> decorates the class value itself
 *   __decorate(decorators, proto, "method", null)  -> decorates a method through its descriptor
 * Returns the resulting class or descriptor.
 */
var __decorate = (this && this.__decorate) || function (decorators, target, key, desc) {
    // c (argument count) selects the form: <3 decorates `target`, 3 calls d(target, key),
    // >3 works on a property descriptor; a null `desc` is replaced by target's own descriptor for `key`.
    var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d;
    // Delegate when a Reflect.decorate implementation is installed (presumably a metadata polyfill).
    if (typeof Reflect === "object" && typeof Reflect.decorate === "function") r = Reflect.decorate(decorators, target, key, desc);
    // Otherwise apply decorators last-to-first, skipping falsy entries; a falsy result keeps the previous r.
    else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;
    // Descriptor form: write the final descriptor back onto target. Always return r.
    return c > 3 && r && Object.defineProperty(target, key, r), r;
};
|
|
8
|
+
/**
 * Default-import interop helper (compiler emit).
 * Modules flagged with `__esModule` are returned unchanged; any other value
 * (including null/undefined) is wrapped as `{ default: value }`.
 */
var __importDefault = (this && this.__importDefault) || function (mod) {
    if (mod && mod.__esModule) {
        return mod;
    }
    return { "default": mod };
};
|
|
11
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
12
|
+
exports.PipelineService = void 0;
|
|
13
|
+
const core_1 = require("@hazeljs/core");
|
|
14
|
+
const core_2 = __importDefault(require("@hazeljs/core"));
|
|
15
|
+
/**
 * Throw a TypeError unless `steps` is an array.
 * @param {unknown} steps - Candidate step list.
 * @param {string} label - Pipeline name (or 'inline') used in the error message.
 */
function assertPipelineSteps(steps, label) {
    if (!Array.isArray(steps)) {
        throw new TypeError(`Pipeline ${label}: steps must be an array`);
    }
}
/**
 * Pipeline Service - ETL pipelines for training data preparation
 * Handles data transformation before model training
 *
 * A pipeline is an ordered list of `{ name, transform }` steps. Each step's
 * `transform` receives the previous step's output (sync or async) and its
 * result is fed into the next step.
 */
let PipelineService = class PipelineService {
    constructor() {
        // Registered pipelines keyed by name (insertion-ordered).
        this.pipelines = new Map();
    }
    /**
     * Register (or replace) a named pipeline.
     *
     * The steps array is validated before anything is stored. It is also copied,
     * so later mutation of the caller's array does not change the registered pipeline.
     *
     * @param {string} name - Pipeline name used by `run(name, data)`.
     * @param {Array<{name: string, transform: Function}>} steps - Ordered steps.
     * @throws {TypeError} If `steps` is not an array; nothing is registered in that case.
     */
    registerPipeline(name, steps) {
        assertPipelineSteps(steps, name);
        this.pipelines.set(name, [...steps]);
        core_2.default.debug(`Registered pipeline: ${name} with ${steps.length} steps`);
    }
    /**
     * Run a pipeline. Two call forms are supported:
     *   - `run(name, data)` runs the pipeline registered under `name`;
     *   - `run(data, steps)` runs an ad-hoc pipeline without registering it.
     *
     * @returns {Promise<unknown>} Output of the last step, or the input data when there are no steps.
     * @throws {Error} `Pipeline not found: <name>` for an unknown registered name.
     * @throws {TypeError} If inline `steps` is not an array, or a step has no `transform` function.
     */
    async run(nameOrData, dataOrSteps) {
        if (typeof nameOrData === 'string') {
            const registered = this.pipelines.get(nameOrData);
            if (!registered) {
                throw new Error(`Pipeline not found: ${nameOrData}`);
            }
            return this.executeSteps(dataOrSteps, registered, nameOrData);
        }
        assertPipelineSteps(dataOrSteps, 'inline');
        return this.executeSteps(nameOrData, dataOrSteps, 'inline');
    }
    /**
     * Feed `data` through `steps` in order, awaiting each transform.
     * @private
     * @param {unknown} data - Pipeline input.
     * @param {Array<{name: string, transform: Function}>} steps - Steps to execute.
     * @param {string} label - Pipeline name used in logs and error messages.
     */
    async executeSteps(data, steps, label) {
        let result = data;
        for (const step of steps) {
            if (typeof step?.transform !== 'function') {
                throw new TypeError(`Pipeline ${label}: step ${step?.name} has no transform function`);
            }
            core_2.default.debug(`Pipeline ${label}: executing step ${step.name}`);
            // Promise.resolve lets sync and async transforms be handled uniformly.
            result = await Promise.resolve(step.transform(result));
        }
        return result;
    }
    /**
     * @param {string} name - Pipeline name.
     * @returns {Array|undefined} A copy of the registered steps, or undefined if `name` is unknown.
     */
    getPipeline(name) {
        const steps = this.pipelines.get(name);
        return steps === undefined ? undefined : [...steps];
    }
    /**
     * @returns {string[]} Names of all registered pipelines, in first-registration order.
     */
    listPipelines() {
        return Array.from(this.pipelines.keys());
    }
};
|
|
62
|
+
// Compiler emit: the undecorated class is exported first, so the export already exists while
// decorators run. The binding is then replaced by __decorate's result for @Service() and re-exported.
exports.PipelineService = PipelineService;
exports.PipelineService = PipelineService = __decorate([
    (0, core_1.Service)()
], PipelineService);
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"pipeline.service.test.d.ts","sourceRoot":"","sources":["../../src/training/pipeline.service.test.ts"],"names":[],"mappings":""}
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
const pipeline_service_1 = require("./pipeline.service");
|
|
4
|
+
/**
 * Unit tests for PipelineService: registered pipelines, async steps,
 * lookup helpers and ad-hoc (inline) pipelines.
 */
describe('PipelineService', () => {
    let service;
    // Returns a step whose transform shallow-merges `patch` into its input.
    const merge = (name, patch) => ({ name, transform: (d) => ({ ...d, ...patch }) });
    beforeEach(() => {
        service = new pipeline_service_1.PipelineService();
    });
    it('registers and runs pipeline', async () => {
        service.registerPipeline('test', [merge('step1', { a: 1 }), merge('step2', { b: 2 })]);
        await expect(service.run('test', { x: 0 })).resolves.toEqual({ x: 0, a: 1, b: 2 });
    });
    it('handles async transforms', async () => {
        service.registerPipeline('async', [
            { name: 'async', transform: async (d) => ({ ...d, done: true }) },
        ]);
        await expect(service.run('async', {})).resolves.toEqual({ done: true });
    });
    it('throws when pipeline not found', async () => {
        await expect(service.run('missing', {})).rejects.toThrow('Pipeline not found: missing');
    });
    it('getPipeline returns steps', () => {
        const steps = [{ name: 's1', transform: (d) => d }];
        service.registerPipeline('p', steps);
        expect(service.getPipeline('p')).toEqual(steps);
        expect(service.getPipeline('x')).toBeUndefined();
    });
    it('listPipelines returns names', () => {
        for (const name of ['a', 'b']) {
            service.registerPipeline(name, []);
        }
        expect(service.listPipelines()).toEqual(['a', 'b']);
    });
    it('run with inline steps (no registration)', async () => {
        const lower = {
            name: 'lower',
            transform: (d) => ({ ...d, x: d.x?.toLowerCase() ?? '' }),
        };
        const output = await service.run({ x: 'HELLO' }, [lower, merge('add', { y: 1 })]);
        expect(output).toEqual({ x: 'hello', y: 1 });
    });
});
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import { ModelRegistry } from '../registry/model.registry';
|
|
2
|
+
import { TrainingData, TrainingResult } from '../ml.types';
|
|
3
|
+
/**
 * Trainer Service - Training orchestration for ML models
 * Coordinates training pipelines and model updates
 */
export declare class TrainerService {
    private readonly modelRegistry;
    constructor(modelRegistry: ModelRegistry);
    /**
     * Looks up `modelName` (optionally a specific `version`) in the registry and calls
     * the entry's `trainMethod` on its instance with `data`, resolving to that method's result.
     * Rejects if the model is not registered, has no `trainMethod`, or that method is not a function.
     */
    train(modelName: string, data: TrainingData, version?: string): Promise<TrainingResult>;
    /**
     * Returns the name of a method carrying `@Train` metadata, or undefined when the
     * instance's class has no `@Model` metadata or no such method is found.
     */
    discoverTrainMethod(instance: object): string | undefined;
}
|
|
13
|
+
//# sourceMappingURL=trainer.service.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"trainer.service.d.ts","sourceRoot":"","sources":["../../src/training/trainer.service.ts"],"names":[],"mappings":"AACA,OAAO,EAAE,aAAa,EAAE,MAAM,4BAA4B,CAAC;AAE3D,OAAO,EAAE,YAAY,EAAE,cAAc,EAAE,MAAM,aAAa,CAAC;AAG3D;;;GAGG;AACH,qBACa,cAAc;IACb,OAAO,CAAC,QAAQ,CAAC,aAAa;gBAAb,aAAa,EAAE,aAAa;IAEnD,KAAK,CAAC,SAAS,EAAE,MAAM,EAAE,IAAI,EAAE,YAAY,EAAE,OAAO,CAAC,EAAE,MAAM,GAAG,OAAO,CAAC,cAAc,CAAC;IA2B7F,mBAAmB,CAAC,QAAQ,EAAE,MAAM,GAAG,MAAM,GAAG,SAAS;CAgB1D"}
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
/**
 * TypeScript decorator-application helper (compiler emit).
 * Reuses an existing `this.__decorate` if one is already defined.
 *
 * Call forms seen in this package's compiled output:
 *   __decorate(decorators, Class)                  -> decorates the class value itself
 *   __decorate(decorators, proto, "method", null)  -> decorates a method through its descriptor
 * Returns the resulting class or descriptor.
 */
var __decorate = (this && this.__decorate) || function (decorators, target, key, desc) {
    // c (argument count) selects the form: <3 decorates `target`, 3 calls d(target, key),
    // >3 works on a property descriptor; a null `desc` is replaced by target's own descriptor for `key`.
    var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d;
    // Delegate when a Reflect.decorate implementation is installed (presumably a metadata polyfill).
    if (typeof Reflect === "object" && typeof Reflect.decorate === "function") r = Reflect.decorate(decorators, target, key, desc);
    // Otherwise apply decorators last-to-first, skipping falsy entries; a falsy result keeps the previous r.
    else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;
    // Descriptor form: write the final descriptor back onto target. Always return r.
    return c > 3 && r && Object.defineProperty(target, key, r), r;
};
|
|
8
|
+
/**
 * TypeScript `design:*` metadata helper (compiler emit).
 * Returns `Reflect.metadata(k, v)` when a Reflect metadata API is installed;
 * otherwise returns undefined.
 */
var __metadata = (this && this.__metadata) || function (k, v) {
    const metadataApiAvailable = typeof Reflect === "object" && typeof Reflect.metadata === "function";
    if (!metadataApiAvailable) {
        return undefined;
    }
    return Reflect.metadata(k, v);
};
|
|
11
|
+
/**
 * Default-import interop helper (compiler emit).
 * Modules flagged with `__esModule` are returned unchanged; any other value
 * (including null/undefined) is wrapped as `{ default: value }`.
 */
var __importDefault = (this && this.__importDefault) || function (mod) {
    if (mod && mod.__esModule) {
        return mod;
    }
    return { "default": mod };
};
|
|
14
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
15
|
+
exports.TrainerService = void 0;
|
|
16
|
+
const core_1 = require("@hazeljs/core");
|
|
17
|
+
const model_registry_1 = require("../registry/model.registry");
|
|
18
|
+
const decorators_1 = require("../decorators");
|
|
19
|
+
const core_2 = __importDefault(require("@hazeljs/core"));
|
|
20
|
+
/**
 * Trainer Service - Training orchestration for ML models
 * Coordinates training pipelines and model updates
 */
let TrainerService = class TrainerService {
    /**
     * @param {ModelRegistry} modelRegistry - Registry used to resolve models by name/version.
     */
    constructor(modelRegistry) {
        this.modelRegistry = modelRegistry;
    }
    /**
     * Train a registered model by invoking its training method with `data`.
     *
     * @param {string} modelName - Registered model name.
     * @param {unknown} data - Training data, passed unchanged to the model's train method.
     * @param {string} [version] - Specific version; when omitted, the registry's own
     *   lookup rule for `modelName` applies.
     * @returns {Promise<unknown>} What the model's training method returns or resolves to.
     * @throws {Error} If the model (or requested version) is not registered, the entry has
     *   no `trainMethod`, or that method is not a function on the instance.
     */
    async train(modelName, data, version) {
        const model = this.modelRegistry.get(modelName, version);
        if (!model) {
            // Include the requested version so "exists, but not at this version" is diagnosable.
            const target = version == null ? modelName : `${modelName}@${version}`;
            throw new Error(`Model not found: ${target}`);
        }
        const { trainMethod, instance } = model;
        if (!trainMethod) {
            throw new Error(`Model ${modelName} has no training method`);
        }
        // Optional chaining turns a missing instance into the "not found" error below
        // instead of an opaque TypeError.
        const trainFn = instance?.[trainMethod];
        if (typeof trainFn !== 'function') {
            throw new Error(`Training method ${trainMethod} not found on model`);
        }
        core_2.default.debug(`Starting training for model: ${modelName}`);
        const result = await trainFn.call(instance, data);
        core_2.default.debug(`Training completed for model: ${modelName}`, result);
        return result;
    }
    /**
     * Find the name of the `@Train`-decorated method on a `@Model` class instance.
     *
     * The search walks the prototype chain from the most-derived class, stopping before
     * `Object.prototype`, so `@Train` methods inherited from a base class are found too.
     *
     * @param {object} instance - Model instance.
     * @returns {string|undefined} The method name. Undefined for a null/undefined instance,
     *   when the class has no `@Model` metadata, or when no `@Train` method is found.
     */
    discoverTrainMethod(instance) {
        if (instance == null) {
            return undefined;
        }
        const metadata = (0, decorators_1.getModelMetadata)(instance.constructor);
        if (!metadata) {
            return undefined;
        }
        for (let proto = Object.getPrototypeOf(instance); proto && proto !== Object.prototype; proto = Object.getPrototypeOf(proto)) {
            for (const key of Object.getOwnPropertyNames(proto)) {
                if (key === 'constructor') {
                    continue;
                }
                const descriptor = Object.getOwnPropertyDescriptor(proto, key);
                // Only plain methods qualify; accessors and data fields are skipped.
                if (typeof descriptor?.value === 'function' && (0, decorators_1.getTrainMetadata)(proto, key)) {
                    return key;
                }
            }
        }
        return undefined;
    }
};
|
|
65
|
+
// Compiler emit: the undecorated class is exported first, so the export already exists while
// decorators run. The binding is then replaced by __decorate's result and re-exported.
exports.TrainerService = TrainerService;
exports.TrainerService = TrainerService = __decorate([
    (0, core_1.Service)(),
    // Records the constructor parameter types ([ModelRegistry]); presumably read by the
    // @hazeljs/core injector to supply the registry — confirm against core's DI.
    __metadata("design:paramtypes", [model_registry_1.ModelRegistry])
], TrainerService);
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"trainer.service.test.d.ts","sourceRoot":"","sources":["../../src/training/trainer.service.test.ts"],"names":[],"mappings":""}
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
/**
 * TypeScript decorator-application helper (compiler emit).
 * Reuses an existing `this.__decorate` if one is already defined.
 *
 * Call forms used in this test file:
 *   __decorate(decorators, Class)                  -> decorates the class value itself
 *   __decorate(decorators, proto, "method", null)  -> decorates a method through its descriptor
 * Returns the resulting class or descriptor.
 */
var __decorate = (this && this.__decorate) || function (decorators, target, key, desc) {
    // c (argument count) selects the form: <3 decorates `target`, 3 calls d(target, key),
    // >3 works on a property descriptor; a null `desc` is replaced by target's own descriptor for `key`.
    var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d;
    // Delegate when a Reflect.decorate implementation is installed (presumably a metadata polyfill).
    if (typeof Reflect === "object" && typeof Reflect.decorate === "function") r = Reflect.decorate(decorators, target, key, desc);
    // Otherwise apply decorators last-to-first, skipping falsy entries; a falsy result keeps the previous r.
    else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;
    // Descriptor form: write the final descriptor back onto target. Always return r.
    return c > 3 && r && Object.defineProperty(target, key, r), r;
};
|
|
8
|
+
/**
 * TypeScript `design:*` metadata helper (compiler emit).
 * Returns `Reflect.metadata(k, v)` when a Reflect metadata API is installed;
 * otherwise returns undefined.
 */
var __metadata = (this && this.__metadata) || function (k, v) {
    const metadataApiAvailable = typeof Reflect === "object" && typeof Reflect.metadata === "function";
    if (!metadataApiAvailable) {
        return undefined;
    }
    return Reflect.metadata(k, v);
};
|
|
11
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
12
|
+
const model_registry_1 = require("../registry/model.registry");
|
|
13
|
+
const trainer_service_1 = require("./trainer.service");
|
|
14
|
+
const decorators_1 = require("../decorators");
|
|
15
|
+
/**
 * Unit tests for TrainerService: training a registered model, the error paths of
 * train(), and discoverTrainMethod() on decorated and plain classes.
 */
describe('TrainerService', () => {
    let registry;
    let trainer;
    // Fixture model (compiled from a decorated TS class): train() echoes data.size as inputSize.
    let TestModel = class TestModel {
        async train(data) {
            return { accuracy: 0.95, loss: 0.05, inputSize: data?.size };
        }
        async predict() {
            return { value: 1 };
        }
    };
    // @Train() on TestModel.prototype.train (plus emitted design:* metadata).
    __decorate([
        (0, decorators_1.Train)(),
        __metadata("design:type", Function),
        __metadata("design:paramtypes", [Object]),
        __metadata("design:returntype", Promise)
    ], TestModel.prototype, "train", null);
    // @Predict() on TestModel.prototype.predict.
    __decorate([
        (0, decorators_1.Predict)(),
        __metadata("design:type", Function),
        __metadata("design:paramtypes", []),
        __metadata("design:returntype", Promise)
    ], TestModel.prototype, "predict", null);
    // Class-level @Model(...) metadata; TestModel is rebound to the decorated class.
    TestModel = __decorate([
        (0, decorators_1.Model)({ name: 'test-model', version: '1.0.0', framework: 'tensorflow' })
    ], TestModel);
    // Fresh registry/trainer per test, with test-model@1.0.0 registered explicitly.
    beforeEach(() => {
        registry = new model_registry_1.ModelRegistry();
        trainer = new trainer_service_1.TrainerService(registry);
        const instance = new TestModel();
        registry.register({
            metadata: { name: 'test-model', version: '1.0.0', framework: 'tensorflow' },
            instance,
            trainMethod: 'train',
            predictMethod: 'predict',
        });
    });
    it('trains model and returns result', async () => {
        const result = await trainer.train('test-model', { size: 100 });
        expect(result).toEqual({ accuracy: 0.95, loss: 0.05, inputSize: 100 });
    });
    it('throws when model not found', async () => {
        await expect(trainer.train('unknown', {})).rejects.toThrow('Model not found: unknown');
    });
    it('throws when model has no training method', async () => {
        // Model class with only a @Predict method; its registry entry omits trainMethod.
        let NoTrainModel = class NoTrainModel {
            predict() { }
        };
        __decorate([
            (0, decorators_1.Predict)(),
            __metadata("design:type", Function),
            __metadata("design:paramtypes", []),
            __metadata("design:returntype", void 0)
        ], NoTrainModel.prototype, "predict", null);
        NoTrainModel = __decorate([
            (0, decorators_1.Model)({ name: 'no-train', version: '1.0.0', framework: 'tensorflow' })
        ], NoTrainModel);
        registry.register({
            metadata: { name: 'no-train', version: '1.0.0', framework: 'tensorflow' },
            instance: new NoTrainModel(),
            predictMethod: 'predict',
        });
        await expect(trainer.train('no-train', {})).rejects.toThrow('no-train has no training method');
    });
    it('discoverTrainMethod finds @Train decorated method', () => {
        const instance = new TestModel();
        expect(trainer.discoverTrainMethod(instance)).toBe('train');
    });
    it('discoverTrainMethod returns undefined for class without @Model', () => {
        // Has a method named `train`, but no @Model/@Train metadata.
        class PlainClass {
            train() { }
        }
        expect(trainer.discoverTrainMethod(new PlainClass())).toBeUndefined();
    });
    it('throws when train method is not a function', async () => {
        // Registry entry names a trainMethod whose value on the instance is a string.
        const fakeInstance = { train: 'not-a-function' };
        registry.register({
            metadata: { name: 'broken-train', version: '1.0.0', framework: 'custom' },
            instance: fakeInstance,
            trainMethod: 'train',
            predictMethod: undefined,
        });
        await expect(trainer.train('broken-train', {})).rejects.toThrow('Training method train not found on model');
    });
});
|
package/package.json
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "@hazeljs/ml",
|
|
3
|
+
"version": "0.2.0-alpha.1",
|
|
4
|
+
"description": "Machine Learning & Model Management for HazelJS framework",
|
|
5
|
+
"main": "dist/index.js",
|
|
6
|
+
"types": "dist/index.d.ts",
|
|
7
|
+
"files": [
|
|
8
|
+
"dist"
|
|
9
|
+
],
|
|
10
|
+
"scripts": {
|
|
11
|
+
"build": "tsc --skipLibCheck",
|
|
12
|
+
"test": "jest --coverage --maxWorkers=1",
|
|
13
|
+
"lint": "eslint \"src/**/*.ts\"",
|
|
14
|
+
"lint:fix": "eslint \"src/**/*.ts\" --fix",
|
|
15
|
+
"clean": "rm -rf dist"
|
|
16
|
+
},
|
|
17
|
+
"devDependencies": {
|
|
18
|
+
"@types/node": "^20.17.50",
|
|
19
|
+
"@typescript-eslint/eslint-plugin": "^8.18.2",
|
|
20
|
+
"@typescript-eslint/parser": "^8.18.2",
|
|
21
|
+
"eslint": "^8.56.0",
|
|
22
|
+
"jest": "^29.7.0",
|
|
23
|
+
"ts-jest": "^29.1.2",
|
|
24
|
+
"typescript": "^5.3.3"
|
|
25
|
+
},
|
|
26
|
+
"peerDependencies": {
|
|
27
|
+
"@hazeljs/core": ">=0.2.0-beta.0"
|
|
28
|
+
},
|
|
29
|
+
"publishConfig": {
|
|
30
|
+
"access": "public"
|
|
31
|
+
},
|
|
32
|
+
"repository": {
|
|
33
|
+
"type": "git",
|
|
34
|
+
"url": "git+https://github.com/hazel-js/hazeljs.git",
|
|
35
|
+
"directory": "packages/ml"
|
|
36
|
+
},
|
|
37
|
+
"keywords": [
|
|
38
|
+
"hazeljs",
|
|
39
|
+
"ml",
|
|
40
|
+
"machine-learning",
|
|
41
|
+
"model-registry",
|
|
42
|
+
"tensorflow",
|
|
43
|
+
"onnx"
|
|
44
|
+
],
|
|
45
|
+
"author": "Muhammad Arslan <muhammad.arslan@hazeljs.com>",
|
|
46
|
+
"license": "Apache-2.0",
|
|
47
|
+
"bugs": {
|
|
48
|
+
"url": "https://github.com/hazel-js/hazeljs/issues"
|
|
49
|
+
},
|
|
50
|
+
"homepage": "https://hazeljs.com",
|
|
51
|
+
"gitHead": "cbc5ee2c12ced28fd0576faf13c5f078c1e8421e"
|
|
52
|
+
}
|