wrapture 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.browserslistrc +4 -0
- package/.editorconfig +10 -0
- package/.eslintrc +0 -0
- package/.github/FUNDING.yml +4 -0
- package/.github/ISSUE_TEMPLATE/bug_report.md +38 -0
- package/.github/ISSUE_TEMPLATE/feature_request.md +20 -0
- package/.github/dependabot.yml +25 -0
- package/.github/labeler.yml +72 -0
- package/.github/labels.yml +24 -0
- package/.github/workflows/auto-assign.yml +19 -0
- package/.github/workflows/check.yml +72 -0
- package/.github/workflows/label.yml +16 -0
- package/.github/workflows/publish.yml +66 -0
- package/.github/workflows/triage.yml +18 -0
- package/.markdownlint.json +7 -0
- package/.prettierrc +7 -0
- package/.release-it.json +67 -0
- package/CHANGELOG.md +8 -0
- package/CODE_OF_CONDUCT.md +128 -0
- package/CONTRIBUTING.md +200 -0
- package/LICENSE +21 -0
- package/README.md +182 -0
- package/SECURITY.md +16 -0
- package/api/README.md +35 -0
- package/api/utils/convert.md +96 -0
- package/api/utils/generate-wrapper.md +91 -0
- package/api/wrapture.md +31 -0
- package/bin/wrapture.js +148 -0
- package/custom-typedoc-plugin.js +55 -0
- package/eslint.config.mjs +16 -0
- package/package.json +74 -0
- package/public/docs/README.md +1 -0
- package/python/convert.py +72 -0
- package/python/scripts/basic_model.py +20 -0
- package/scripts/test.ts +11 -0
- package/src/utils/__tests__/convert.unit.ts +47 -0
- package/src/utils/__tests__/generate-wrapper.unit.ts +29 -0
- package/src/utils/convert.ts +98 -0
- package/src/utils/generate-wrapper.ts +107 -0
- package/src/wrapture.ts +51 -0
- package/test/fixtures/basic_model.pt +0 -0
- package/tsconfig.json +27 -0
- package/tsup.config.ts +14 -0
- package/typedoc.json +33 -0
package/bin/wrapture.js
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
#!/usr/bin/env node
|
|
2
|
+
|
|
3
|
+
// src/wrapture.ts
|
|
4
|
+
import chalk2 from "chalk";
|
|
5
|
+
import { Command } from "commander";
|
|
6
|
+
import { existsSync } from "node:fs";
|
|
7
|
+
import path3 from "node:path";
|
|
8
|
+
|
|
9
|
+
// src/utils/convert.ts
|
|
10
|
+
import chalk from "chalk";
|
|
11
|
+
import ora from "ora";
|
|
12
|
+
import { spawn } from "node:child_process";
|
|
13
|
+
import fs from "node:fs";
|
|
14
|
+
import path, { dirname } from "node:path";
|
|
15
|
+
import { fileURLToPath } from "node:url";
|
|
16
|
+
// ESM shim for CommonJS-style __filename/__dirname (bundled output).
var __filename = fileURLToPath(import.meta.url);
var __dirname = dirname(__filename);
// Converts a PyTorch model by shelling out to python/convert.py.
// Resolves when the Python process exits 0; rejects otherwise.
var convert = async (inputPath, outputDir, opts) => {
  if (!fs.existsSync(inputPath)) {
    throw new Error(`Input model file not found: ${inputPath}`);
  }
  const spinner = ora("\u{1F504} Converting model to ONNX...").start();
  return new Promise((resolve, reject) => {
    // NOTE(review): __dirname here is package/bin, so "../../python" resolves
    // ABOVE the package root, while convert.py ships at package/python/ —
    // confirm the bundler rewrites this path or fix to "../python/convert.py".
    const scriptPath = path.resolve(__dirname, "../../python/convert.py");
    const args = [
      scriptPath,
      "--input",
      inputPath,
      "--output",
      outputDir,
      "--format",
      opts.format || "onnx"
    ];
    if (opts.quantize) args.push("--quantize");
    const python = spawn("python3", args);
    // Stream the Python script's own progress output through to the user.
    python.stdout.on("data", (data) => process.stdout.write(data));
    python.stderr.on(
      "data",
      (data) => process.stderr.write(chalk.red(data.toString()))
    );
    // NOTE(review): no "error" listener — if python3 is missing from PATH the
    // close code is null and the message is misleading; confirm and handle.
    python.on("close", (code) => {
      if (code === 0) {
        spinner.succeed("\u2705 Model converted successfully.");
        resolve();
      } else {
        spinner.fail("\u274C Model conversion failed.");
        reject(new Error(`convert.py exited with code ${code}`));
      }
    });
  });
};
|
|
52
|
+
|
|
53
|
+
// src/utils/generate-wrapper.ts
|
|
54
|
+
import ora2 from "ora";
|
|
55
|
+
import fs2 from "node:fs";
|
|
56
|
+
import path2 from "node:path";
|
|
57
|
+
// Writes wrapped.ts (runtime wrapper around onnxruntime-web) and wrapped.d.ts
// (its type declarations) into outputDir. The wrapper loads model_quant.onnx
// when opts.backend === "wasm", otherwise model.onnx.
var generateWrapper = async (outputDir, opts) => {
  const spinner = ora2("\u{1F6E0} Generating wrapper files...").start();
  // Template for the generated runtime wrapper. Everything inside the
  // backticks is emitted verbatim into wrapped.ts.
  const wrapper = `import { InferenceSession, Tensor } from 'onnxruntime-web';

const softmax = (logits) => {
  const exps = logits.map(Math.exp);
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

const argmax = (arr) => {
  return arr.reduce((maxIdx, val, idx, src) => val > src[maxIdx] ? idx : maxIdx, 0);
}

export const loadModel = async () => {
  const session = await InferenceSession.create(
    new URL('./${opts.backend === "wasm" ? "model_quant.onnx" : "model.onnx"}', import.meta.url).href
  );
  return {
    predict: async (input) => {
      const feeds = { input: new Tensor('float32', input.data, input.dims) };
      const results = await session.run(feeds);
      const raw = results.output.data;

      if (!(raw instanceof Float32Array)) {
        throw new Error('Expected Float32Array logits but got something else');
      }

      const logits = raw;
      const probabilities = softmax(Array.from(logits));
      const predictedClass = argmax(probabilities);
      return { logits, probabilities, predictedClass };
    }
  };
};
`;
  // Companion .d.ts content describing the generated wrapper's API.
  const typings = `export interface ModelInput {
  data: Float32Array;
  dims: number[];
}

export interface ModelOutput {
  logits: Float32Array;
  probabilities: number[];
  predictedClass: number;
}

export interface LoadedModel {
  predict(input: ModelInput): Promise<ModelOutput>;
}

/**
 * Load the ONNX model and return a wrapper with \`predict()\` function.
 */
export function loadModel(): Promise<LoadedModel>;`;
  try {
    // NOTE(review): throws ENOENT if outputDir does not exist yet — the CLI
    // relies on convert.py having created it first; confirm that ordering.
    fs2.writeFileSync(path2.join(outputDir, "wrapped.ts"), wrapper);
    fs2.writeFileSync(path2.join(outputDir, "wrapped.d.ts"), typings);
    spinner.succeed("\u2705 Wrapper files generated.");
  } catch (error) {
    spinner.fail("\u274C Failed to generate wrapper files.");
    throw error;
  }
};
|
|
121
|
+
|
|
122
|
+
// src/wrapture.ts
|
|
123
|
+
// CLI entry point: defines the wrapture command, validates the input path,
// then runs conversion followed by wrapper generation.
var program = new Command();
// NOTE(review): CLI reports version "0.1.0" while package.json declares
// 0.0.1 — confirm which is canonical and keep them in sync.
program.name("wrapture").description("\u{1F300} One-click model exporter: from PyTorch to Web-ready JS/TS").version("0.1.0").requiredOption("-i, --input <file>", "Path to the PyTorch model (.pt)").requiredOption(
  "-o, --output <dir>",
  "Output directory for the wrapped model"
).option("--quantize", "Apply quantization to reduce model size").option("--format <type>", "Export format: onnx (default)", "onnx").option(
  "--backend <backend>",
  "Inference backend: webgpu | wasm | cpu",
  "webgpu"
).action(async (opts) => {
  // Normalize user-supplied paths to absolute before handing them on.
  const input = path3.resolve(opts.input);
  const output = path3.resolve(opts.output);
  if (!existsSync(input)) {
    console.error(chalk2.red(`Input file not found: ${input}`));
    process.exit(1);
  }
  console.log(chalk2.cyan("\u2728 Wrapture: Exporting model..."));
  try {
    // Step 1: PyTorch -> ONNX via the Python helper.
    await convert(input, output, opts);
    // Step 2: emit the TypeScript wrapper + typings next to the model.
    await generateWrapper(output, opts);
    console.log(chalk2.green("\u2705 Done! Your model is wrapped and ready."));
  } catch (err) {
    console.error(chalk2.red("\u274C Failed to export model:"), err);
    process.exit(1);
  }
});
program.parse(process.argv);
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
/* eslint-disable import/no-unused-modules */
|
|
2
|
+
import { MarkdownTheme, MarkdownThemeContext } from 'typedoc-plugin-markdown';
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* @param {import('typedoc-plugin-markdown').MarkdownApplication} app
|
|
6
|
+
*/
|
|
7
|
+
/**
 * TypeDoc plugin entry point: registers the custom markdown theme and wires
 * the markdown hooks that stamp every page with a freshness banner and a
 * shared contributing/sponsorship footer.
 *
 * @param {import('typedoc-plugin-markdown').MarkdownApplication} app
 */
export function load(app) {
  const { renderer } = app;

  renderer.defineTheme('custom-markdown-theme', MyMarkdownTheme);

  // Prepend a "last updated" banner to the body of every rendered page.
  renderer.markdownHooks.on(
    'content.begin',
    () => `> Last updated ${new Date().toISOString()}`
  );

  const footerText = `***

**Contributing**

Want to contribute? Please read the [CONTRIBUTING.md](https://github.com/phun-ky/wrapture/blob/main/CONTRIBUTING.md) and [CODE_OF_CONDUCT.md](https://github.com/phun-ky/wrapture/blob/main/CODE_OF_CONDUCT.md)

**Sponsor me**

I'm an Open Source evangelist, creating stuff that does not exist yet to help get rid of secondary activities and to enhance systems already in place, be it documentation or web sites.

The sponsorship is an unique opportunity to alleviate more hours for me to maintain my projects, create new ones and contribute to the large community we're all part of :)

[Support me on GitHub Sponsors](https://github.com/sponsors/phun-ky).

***

This project created by [Alexander Vassbotn Røyne-Helgesen](http://phun-ky.net) is licensed under a [MIT License](https://choosealicense.com/licenses/mit/).
`;

  // The same footer goes on regular pages and on index pages.
  for (const hook of ['page.end', 'index.page.end']) {
    renderer.markdownHooks.on(hook, () => footerText);
  }
}
|
|
38
|
+
|
|
39
|
+
// Theme subclass whose only job is to swap in our custom render context.
class MyMarkdownTheme extends MarkdownTheme {
  /**
   * Returns the render context used for a page; overridden to supply
   * MyMarkdownThemeContext instead of the default context.
   *
   * @param {import('typedoc-plugin-markdown').MarkdownPageEvent} page
   */
  getRenderContext(page) {
    return new MyMarkdownThemeContext(this, page, this.application.options);
  }
}
|
|
47
|
+
|
|
48
|
+
// Render context that keeps all default partials but suppresses the page
// header by rendering it as an empty string.
class MyMarkdownThemeContext extends MarkdownThemeContext {
  partials = {
    // Inherit every default partial...
    ...this.partials,
    // ...except the header, which we blank out entirely.
    header: () => {
      return '';
    }
  };
}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
// ESLint flat config: shared preset plus markdown-specific relaxations.
import { defineConfig } from 'eslint/config';
import customConfig from 'eslint-config-phun-ky';

// eslint-disable-next-line import/no-unused-modules
export default defineConfig([
  {
    // Baseline rules from the author's shared config.
    extends: [customConfig]
  },
  {
    // Markdown code blocks: whitespace/indent rules produce false positives,
    // so they are disabled for .md files.
    files: ['**/*.md'],
    rules: {
      'no-irregular-whitespace': 'off',
      '@stylistic/indent': 'off'
    }
  }
]);
|
package/package.json
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
{
|
|
2
|
+
"name": "wrapture",
|
|
3
|
+
"version": "0.0.1",
|
|
4
|
+
"description": "",
|
|
5
|
+
"homepage": "https://github.com/phun-ky/wrapture#readme",
|
|
6
|
+
"bugs": {
|
|
7
|
+
"url": "https://github.com/phun-ky/wrapture/issues"
|
|
8
|
+
},
|
|
9
|
+
"repository": {
|
|
10
|
+
"type": "git",
|
|
11
|
+
"url": "git+https://github.com/phun-ky/wrapture.git"
|
|
12
|
+
},
|
|
13
|
+
"funding": "https://github.com/phun-ky/wrapture?sponsor=1",
|
|
14
|
+
"license": "MIT",
|
|
15
|
+
"author": "Alexander Vassbotn Røyne-Helgesen <alexander@phun-ky.net>",
|
|
16
|
+
"type": "module",
|
|
17
|
+
"bin": {
|
|
18
|
+
"wrapture": "./bin/wrapture.js"
|
|
19
|
+
},
|
|
20
|
+
"scripts": {
|
|
21
|
+
"commit": "npx git-cz",
|
|
22
|
+
"docs:gen": "node ./node_modules/.bin/typedoc",
|
|
23
|
+
"release": "release-it",
|
|
24
|
+
"start": "node ./bin/wrapture.js",
|
|
25
|
+
"build": "tsup",
|
|
26
|
+
"style:code": "npx putout src",
|
|
27
|
+
"style:format": "./node_modules/.bin/eslint -c ./eslint.config.mjs src --fix && ./node_modules/.bin/prettier --write ./eslint.config.mjs src",
|
|
28
|
+
"style:lint": "./node_modules/.bin/eslint -c ./eslint.config.mjs src && ./node_modules/.bin/prettier --check src",
|
|
29
|
+
"test": "NODE_ENV=test glob -c \"node --import tsx --test --no-warnings\" \"./src/**/__tests__/**/*.[jt]s\"",
|
|
30
|
+
"pretest:ci": "rm -rf coverage && mkdir -p coverage",
|
|
31
|
+
"test:ci": "NODE_ENV=test glob -c \"node --import tsx --test --no-warnings --experimental-test-coverage --test-reporter=cobertura --test-reporter-destination=coverage/cobertura-coverage.xml --test-reporter=spec --test-reporter-destination=stdout\" \"./src/**/__tests__/**/*.[jt]s\""
|
|
32
|
+
},
|
|
33
|
+
"engines": {
|
|
34
|
+
"node": ">=22.0.0",
|
|
35
|
+
"npm": ">=10.8.2"
|
|
36
|
+
},
|
|
37
|
+
"publishConfig": {
|
|
38
|
+
"access": "public"
|
|
39
|
+
},
|
|
40
|
+
"devDependencies": {
|
|
41
|
+
"@release-it/conventional-changelog": "^10.0.0",
|
|
42
|
+
"@rollup/plugin-node-resolve": "^16.0.1",
|
|
43
|
+
"@rollup/plugin-terser": "^0.4.4",
|
|
44
|
+
"@stylistic/eslint-plugin": "^4.2.0",
|
|
45
|
+
"@types/node": "^22.15.3",
|
|
46
|
+
"cobertura": "^1.0.1",
|
|
47
|
+
"eslint": "^9.20.0",
|
|
48
|
+
"eslint-config-phun-ky": "^1.0.0",
|
|
49
|
+
"git-cz": "^4.9.0",
|
|
50
|
+
"onnxruntime-web": "^1.22.0",
|
|
51
|
+
"prettier": "^3.2.5",
|
|
52
|
+
"putout": "^40.1.9",
|
|
53
|
+
"release-it": "^19.0.1",
|
|
54
|
+
"remark-github": "^12.0.0",
|
|
55
|
+
"remark-toc": "^9.0.0",
|
|
56
|
+
"tslib": "^2.3.1",
|
|
57
|
+
"tsup": "^8.4.0",
|
|
58
|
+
"tsx": "^4.7.1",
|
|
59
|
+
"typedoc": "^0.28.3",
|
|
60
|
+
"typedoc-plugin-frontmatter": "^1.0.0",
|
|
61
|
+
"typedoc-plugin-markdown": "^4.2.3",
|
|
62
|
+
"typedoc-plugin-mdn-links": "^5.0.1",
|
|
63
|
+
"typedoc-plugin-no-inherit": "^1.4.0",
|
|
64
|
+
"typedoc-plugin-remark": "^2.0.0",
|
|
65
|
+
"typedoc-plugin-rename-defaults": "^0.7.1",
|
|
66
|
+
"typescript": "^5.0.0",
|
|
67
|
+
"unified-prettier": "^2.0.1"
|
|
68
|
+
},
|
|
69
|
+
"dependencies": {
|
|
70
|
+
"chalk": "^5.4.1",
|
|
71
|
+
"commander": "^13.1.0",
|
|
72
|
+
"ora": "^8.2.0"
|
|
73
|
+
}
|
|
74
|
+
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# wrapture API documentation
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import os
|
|
5
|
+
from onnxsim import simplify
|
|
6
|
+
import onnx
|
|
7
|
+
from onnxruntime.quantization import quantize_dynamic, QuantType
|
|
8
|
+
|
|
9
|
+
def convert_to_onnx(input_path, output_path, quantize=False):
    """Export a PyTorch ``state_dict`` checkpoint to ONNX.

    Loads the weights into a hard-coded ``BasicCNN`` architecture, exports it
    to ``<output_path>/model.onnx`` with a dynamic batch axis, simplifies the
    graph in place, and optionally writes a dynamically quantized copy.

    :param input_path: Path to the ``.pt`` file holding the state dict.
    :param output_path: Directory the ONNX file(s) are written into.
    :param quantize: When True, also emit ``model_quant.onnx`` (QInt8 weights).
    """
    # NOTE(review): the architecture is hard-coded, so only checkpoints saved
    # from this exact BasicCNN layout can be converted — confirm intended.
    class BasicCNN(nn.Module): # Define the same model architecture
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                nn.Linear(16, 10)
            )

        def forward(self, x):
            return self.net(x)

    model = BasicCNN()

    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects;
    # loading an untrusted .pt file can execute code. The fixture is a plain
    # state dict, so weights_only=True should suffice — confirm and tighten.
    model.load_state_dict(torch.load(input_path, map_location='cpu', weights_only=False))

    model.eval()

    # Tracing input: batch of one 3x224x224 image (batch axis made dynamic
    # below via dynamic_axes).
    dummy_input = torch.randn(1, 3, 224, 224)
    output_model_path = os.path.join(output_path, 'model.onnx')
    torch.onnx.export(
        model,
        dummy_input,
        output_model_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        opset_version=11
    )
    # Simplify the exported graph and overwrite model.onnx with the result.
    model = onnx.load(output_model_path)
    model_simp, check = simplify(model)
    assert check, "Simplified model could not be validated"
    onnx.checker.check_model(model_simp) # throw if broken
    onnx.save(model_simp, output_model_path)
    print("✅ ONNX model simplified.")

    print(f"✅ Exported ONNX model to: {output_model_path}")

    if quantize:
        # Dynamic (weight-only) quantization to shrink the model for wasm use.
        quant_model_path = os.path.join(output_path, 'model_quant.onnx')
        quantize_dynamic(output_model_path, quant_model_path, weight_type=QuantType.QInt8)
|
|
53
|
+
|
|
54
|
+
def main():
    """CLI entry point: parse arguments, prepare the output directory and
    dispatch to :func:`convert_to_onnx`."""
    parser = argparse.ArgumentParser(
        description='Wrapture: Convert PyTorch model to ONNX.'
    )
    parser.add_argument('--input', required=True,
                        help='Path to the PyTorch model (.pt)')
    parser.add_argument('--output', required=True,
                        help='Directory to save the ONNX model')
    parser.add_argument('--format', default='onnx',
                        help='Export format (currently only supports ONNX)')
    parser.add_argument('--quantize', action='store_true',
                        help='Apply quantization')
    args = parser.parse_args()

    # Create the output directory on first use.
    if not os.path.exists(args.output):
        os.makedirs(args.output)

    # Guard clause: only ONNX is implemented today.
    if args.format.lower() != 'onnx':
        raise ValueError(f"Unsupported format: {args.format}")

    convert_to_onnx(args.input, args.output, args.quantize)


if __name__ == '__main__':
    main()
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Fixture generator: saves the untrained weights of a minimal CNN to
# test/fixtures/basic_model.pt so the conversion tests have a real .pt file.
import torch
import torch.nn as nn

class BasicCNN(nn.Module):
    """Tiny 3-channel CNN (conv -> ReLU -> global pool -> linear, 10 logits).

    Must stay byte-for-byte compatible with the BasicCNN defined inside
    python/convert.py, which reloads these weights by state_dict.
    """
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(16, 10)
        )

    def forward(self, x):
        return self.net(x)

# Persist only the state dict (not the pickled module) — the converter
# rebuilds the architecture and calls load_state_dict on this file.
model = BasicCNN()
torch.save(model.state_dict(), 'test/fixtures/basic_model.pt')
print("✅ Saved model weights to test/fixtures/basic_model.pt")
|
package/scripts/test.ts
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
/* eslint-disable no-console */
// Smoke-test script for the generated wrapper: loads the model produced by
// `wrapture` and runs one prediction on an all-zeros input.
// NOTE(review): '../wrapped.js' is a build artifact — run the exporter first.
import { loadModel } from '../wrapped.js';

const model = await loadModel();
// Zero-filled 1x3x224x224 tensor matching the export's tracing shape.
const input = {
  data: new Float32Array(1 * 3 * 224 * 224),
  dims: [1, 3, 224, 224]
};
const out = await model.predict(input);

console.log(out);
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
// test/convert.test.js
// Unit tests for convert(): missing-input rejection and (when the Python
// toolchain is present) a full subprocess round-trip.
import test from 'node:test';
import assert from 'node:assert/strict';
import fs from 'node:fs';
import { tmpdir } from 'node:os';
import { randomUUID } from 'node:crypto';
import path, { dirname } from 'node:path';
import { fileURLToPath } from 'node:url';

import { convert } from '../convert';

const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);

// Fixture checkpoint created by python/scripts/basic_model.py; resolved
// relative to the process CWD, so run tests from the package root.
const fixtureModelPath = path.resolve('test', 'fixtures', 'basic_model.pt');
// Unique temp dir per run so parallel/test reruns cannot collide.
const outputDir = path.join(tmpdir(), `wrapture-test-${randomUUID()}`);

test('convert() should reject if file does not exist', async () => {
  const nonExistentPath = 'some/missing/model.pt';
  const invalidOut = path.join(tmpdir(), `wrapture-test-${randomUUID()}`);

  // convert() should fail fast, before spawning python3.
  await assert.rejects(
    () => convert(nonExistentPath, invalidOut, {}),
    /Input model file not found/i
  );
});

test('convert() resolves if subprocess exits cleanly', async (t) => {
  // Mock subprocess call by making convert.py short-circuit if needed
  const mockInput = fixtureModelPath;
  const output = outputDir;

  // Make sure the output dir exists
  fs.mkdirSync(output, { recursive: true });

  // You can only run this if the python script exists and is safe to test
  // Skip the test if script is missing or too slow
  const scriptPath = path.resolve(__dirname, '../../../python', 'convert.py');
  if (!fs.existsSync(scriptPath)) {
    t.skip('Python script does not exist, skipping subprocess test');
    return;
  }

  // End-to-end: spawns python3 and waits for a zero exit code.
  await assert.doesNotReject(async () => {
    await convert(mockInput, output, { format: 'onnx', quantize: false });
  });
});
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
// Unit test for generateWrapper(): verifies both output files are written
// and that the runtime wrapper exposes loadModel.
import test from 'node:test';
import assert from 'node:assert/strict';
import fs from 'node:fs';
import path from 'node:path';
import { generateWrapper } from '../generate-wrapper.js';

// Scratch directory relative to CWD; recreated fresh on every run.
const testOutputDir = './.test-output';

test('generateWrapper creates wrapper files', async (t) => {
  // Clean up previous run
  fs.rmSync(testOutputDir, { recursive: true, force: true });
  fs.mkdirSync(testOutputDir);

  // 'webgpu' backend -> the wrapper should reference the non-quantized model.
  await generateWrapper(testOutputDir, { backend: 'webgpu' });

  const files = fs.readdirSync(testOutputDir);
  assert.ok(files.includes('wrapped.ts'), 'wrapped.ts should be generated');
  assert.ok(files.includes('wrapped.d.ts'), 'wrapped.d.ts should be generated');

  const wrapperContent = fs.readFileSync(
    path.join(testOutputDir, 'wrapped.ts'),
    'utf-8'
  );
  assert.match(
    wrapperContent,
    /loadModel/,
    'Wrapper should contain loadModel function'
  );
});
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
/* eslint-disable import/no-unused-modules */
|
|
2
|
+
/* global process */
|
|
3
|
+
import chalk from 'chalk';
|
|
4
|
+
import ora from 'ora';
|
|
5
|
+
|
|
6
|
+
import { spawn } from 'node:child_process';
|
|
7
|
+
import fs from 'node:fs';
|
|
8
|
+
import path, { dirname } from 'node:path';
|
|
9
|
+
import { fileURLToPath } from 'node:url';
|
|
10
|
+
|
|
11
|
+
const __filename = fileURLToPath(import.meta.url);
|
|
12
|
+
const __dirname = dirname(__filename);
|
|
13
|
+
|
|
14
|
+
/**
|
|
15
|
+
* Options for the {@link convert} function.
|
|
16
|
+
*/
|
|
17
|
+
export interface ConvertOptionsInterface {
|
|
18
|
+
/**
|
|
19
|
+
* The output format for the converted model (e.g., 'onnx').
|
|
20
|
+
* Defaults to 'onnx' if not provided.
|
|
21
|
+
*/
|
|
22
|
+
format?: string;
|
|
23
|
+
|
|
24
|
+
/**
|
|
25
|
+
* Whether to apply quantization to the model.
|
|
26
|
+
*/
|
|
27
|
+
quantize?: boolean;
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
/**
|
|
31
|
+
* Converts a machine learning model to ONNX or another supported format
|
|
32
|
+
* by delegating to a Python script (`convert.py`).
|
|
33
|
+
*
|
|
34
|
+
* This function spawns a subprocess using `python3` and monitors stdout/stderr.
|
|
35
|
+
* It supports additional options such as format selection and quantization.
|
|
36
|
+
*
|
|
37
|
+
* @function convert
|
|
38
|
+
* @param {string} inputPath - Path to the input model file
|
|
39
|
+
* @param {string} outputDir - Directory where the converted model should be saved
|
|
40
|
+
* @param {ConvertOptionsInterface} opts - Conversion options
|
|
41
|
+
* @returns A promise that resolves on success or rejects if conversion fails
|
|
42
|
+
*
|
|
43
|
+
* @throws {Error} If the Python process exits with a non-zero code
|
|
44
|
+
*
|
|
45
|
+
* @example
|
|
46
|
+
* ```ts
|
|
47
|
+
* await convert('models/model.pt', 'out/', { format: 'onnx', quantize: true });
|
|
48
|
+
* ```
|
|
49
|
+
*
|
|
50
|
+
* @see https://nodejs.org/api/child_process.html#child_processspawncommand-args-options
|
|
51
|
+
* @see https://github.com/sindresorhus/ora - Ora spinner for CLI feedback
|
|
52
|
+
* @see https://github.com/chalk/chalk - Chalk for terminal coloring
|
|
53
|
+
*/
|
|
54
|
+
export const convert = async (
|
|
55
|
+
inputPath: string,
|
|
56
|
+
outputDir: string,
|
|
57
|
+
opts: ConvertOptionsInterface
|
|
58
|
+
): Promise<void> => {
|
|
59
|
+
if (!fs.existsSync(inputPath)) {
|
|
60
|
+
throw new Error(`Input model file not found: ${inputPath}`);
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
const spinner = ora('🔄 Converting model to ONNX...').start();
|
|
64
|
+
|
|
65
|
+
return new Promise((resolve, reject) => {
|
|
66
|
+
const scriptPath = path.resolve(__dirname, '../../python/convert.py');
|
|
67
|
+
const args = [
|
|
68
|
+
scriptPath,
|
|
69
|
+
'--input',
|
|
70
|
+
inputPath,
|
|
71
|
+
'--output',
|
|
72
|
+
outputDir,
|
|
73
|
+
'--format',
|
|
74
|
+
opts.format || 'onnx'
|
|
75
|
+
];
|
|
76
|
+
|
|
77
|
+
if (opts.quantize) args.push('--quantize');
|
|
78
|
+
|
|
79
|
+
const python = spawn('python3', args);
|
|
80
|
+
|
|
81
|
+
python.stdout.on('data', (data) => process.stdout.write(data));
|
|
82
|
+
python.stderr.on('data', (data) =>
|
|
83
|
+
process.stderr.write(chalk.red(data.toString()))
|
|
84
|
+
);
|
|
85
|
+
|
|
86
|
+
python.on('close', (code) => {
|
|
87
|
+
if (code === 0) {
|
|
88
|
+
spinner.succeed('✅ Model converted successfully.');
|
|
89
|
+
|
|
90
|
+
resolve();
|
|
91
|
+
} else {
|
|
92
|
+
spinner.fail('❌ Model conversion failed.');
|
|
93
|
+
|
|
94
|
+
reject(new Error(`convert.py exited with code ${code}`));
|
|
95
|
+
}
|
|
96
|
+
});
|
|
97
|
+
});
|
|
98
|
+
};
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import ora from 'ora';
|
|
2
|
+
|
|
3
|
+
import fs from 'node:fs';
|
|
4
|
+
import path from 'node:path';
|
|
5
|
+
|
|
6
|
+
/**
|
|
7
|
+
* Options for generating ONNX wrapper files.
|
|
8
|
+
*/
|
|
9
|
+
export interface GenerateWrapperOptionsInterface {
|
|
10
|
+
/**
|
|
11
|
+
* The backend to use for inference. This affects the model file used.
|
|
12
|
+
* If set to `'wasm'`, the generated wrapper will load `model_quant.onnx`,
|
|
13
|
+
* otherwise it will load `model.onnx`.
|
|
14
|
+
*/
|
|
15
|
+
backend: 'wasm' | 'webgl' | string;
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
/**
|
|
19
|
+
* Generates a TypeScript wrapper and type definition file (`wrapped.ts` and `wrapped.d.ts`)
|
|
20
|
+
* for use with `onnxruntime-web`, including utility functions like `softmax`, `argmax`,
|
|
21
|
+
* and a typed `predict()` function.
|
|
22
|
+
*
|
|
23
|
+
* The generated code loads the correct ONNX model based on the provided backend.
|
|
24
|
+
*
|
|
25
|
+
* @param {string} outputDir - The directory where the wrapper files will be written.
|
|
26
|
+
* @param {GenerateWrapperOptionsInterface} opts - Wrapper generation options, including backend type.
|
|
27
|
+
* @returns A Promise that resolves when the wrapper files are successfully written.
|
|
28
|
+
*
|
|
29
|
+
* @throws Will throw an error if file writing fails.
|
|
30
|
+
*
|
|
31
|
+
* @example
|
|
32
|
+
* ```ts
|
|
33
|
+
* await generateWrapper('./dist', { backend: 'wasm' });
|
|
34
|
+
* // Creates `wrapped.ts` and `wrapped.d.ts` in ./dist
|
|
35
|
+
* ```
|
|
36
|
+
*
|
|
37
|
+
* @see https://www.npmjs.com/package/onnxruntime-web
|
|
38
|
+
*/
|
|
39
|
+
export const generateWrapper = async (
|
|
40
|
+
outputDir: string,
|
|
41
|
+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
42
|
+
opts: GenerateWrapperOptionsInterface
|
|
43
|
+
): Promise<void> => {
|
|
44
|
+
const spinner = ora('🛠 Generating wrapper files...').start();
|
|
45
|
+
const wrapper = `import { InferenceSession, Tensor } from 'onnxruntime-web';
|
|
46
|
+
|
|
47
|
+
const softmax = (logits) => {
|
|
48
|
+
const exps = logits.map(Math.exp);
|
|
49
|
+
const sum = exps.reduce((a, b) => a + b, 0);
|
|
50
|
+
return exps.map(e => e / sum);
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
const argmax = (arr) => {
|
|
54
|
+
return arr.reduce((maxIdx, val, idx, src) => val > src[maxIdx] ? idx : maxIdx, 0);
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
export const loadModel = async () => {
|
|
58
|
+
const session = await InferenceSession.create(
|
|
59
|
+
new URL('./${opts.backend === 'wasm' ? 'model_quant.onnx' : 'model.onnx'}', import.meta.url).href
|
|
60
|
+
);
|
|
61
|
+
return {
|
|
62
|
+
predict: async (input) => {
|
|
63
|
+
const feeds = { input: new Tensor('float32', input.data, input.dims) };
|
|
64
|
+
const results = await session.run(feeds);
|
|
65
|
+
const raw = results.output.data;
|
|
66
|
+
|
|
67
|
+
if (!(raw instanceof Float32Array)) {
|
|
68
|
+
throw new Error('Expected Float32Array logits but got something else');
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
const logits = raw;
|
|
72
|
+
const probabilities = softmax(Array.from(logits));
|
|
73
|
+
const predictedClass = argmax(probabilities);
|
|
74
|
+
return { logits, probabilities, predictedClass };
|
|
75
|
+
}
|
|
76
|
+
};
|
|
77
|
+
};
|
|
78
|
+
`;
|
|
79
|
+
const typings = `export interface ModelInput {
|
|
80
|
+
data: Float32Array;
|
|
81
|
+
dims: number[];
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
export interface ModelOutput {
|
|
85
|
+
logits: Float32Array;
|
|
86
|
+
probabilities: number[];
|
|
87
|
+
predictedClass: number;
|
|
88
|
+
}
|
|
89
|
+
|
|
90
|
+
export interface LoadedModel {
|
|
91
|
+
predict(input: ModelInput): Promise<ModelOutput>;
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
/**
|
|
95
|
+
* Load the ONNX model and return a wrapper with \`predict()\` function.
|
|
96
|
+
*/
|
|
97
|
+
export function loadModel(): Promise<LoadedModel>;`;
|
|
98
|
+
|
|
99
|
+
try {
|
|
100
|
+
fs.writeFileSync(path.join(outputDir, 'wrapped.ts'), wrapper);
|
|
101
|
+
fs.writeFileSync(path.join(outputDir, 'wrapped.d.ts'), typings);
|
|
102
|
+
spinner.succeed('✅ Wrapper files generated.');
|
|
103
|
+
} catch (error) {
|
|
104
|
+
spinner.fail('❌ Failed to generate wrapper files.');
|
|
105
|
+
throw error;
|
|
106
|
+
}
|
|
107
|
+
};
|