wrapture 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. package/.browserslistrc +4 -0
  2. package/.editorconfig +10 -0
  3. package/.eslintrc +0 -0
  4. package/.github/FUNDING.yml +4 -0
  5. package/.github/ISSUE_TEMPLATE/bug_report.md +38 -0
  6. package/.github/ISSUE_TEMPLATE/feature_request.md +20 -0
  7. package/.github/dependabot.yml +25 -0
  8. package/.github/labeler.yml +72 -0
  9. package/.github/labels.yml +24 -0
  10. package/.github/workflows/auto-assign.yml +19 -0
  11. package/.github/workflows/check.yml +72 -0
  12. package/.github/workflows/label.yml +16 -0
  13. package/.github/workflows/publish.yml +66 -0
  14. package/.github/workflows/triage.yml +18 -0
  15. package/.markdownlint.json +7 -0
  16. package/.prettierrc +7 -0
  17. package/.release-it.json +67 -0
  18. package/CHANGELOG.md +8 -0
  19. package/CODE_OF_CONDUCT.md +128 -0
  20. package/CONTRIBUTING.md +200 -0
  21. package/LICENSE +21 -0
  22. package/README.md +182 -0
  23. package/SECURITY.md +16 -0
  24. package/api/README.md +35 -0
  25. package/api/utils/convert.md +96 -0
  26. package/api/utils/generate-wrapper.md +91 -0
  27. package/api/wrapture.md +31 -0
  28. package/bin/wrapture.js +148 -0
  29. package/custom-typedoc-plugin.js +55 -0
  30. package/eslint.config.mjs +16 -0
  31. package/package.json +74 -0
  32. package/public/docs/README.md +1 -0
  33. package/python/convert.py +72 -0
  34. package/python/scripts/basic_model.py +20 -0
  35. package/scripts/test.ts +11 -0
  36. package/src/utils/__tests__/convert.unit.ts +47 -0
  37. package/src/utils/__tests__/generate-wrapper.unit.ts +29 -0
  38. package/src/utils/convert.ts +98 -0
  39. package/src/utils/generate-wrapper.ts +107 -0
  40. package/src/wrapture.ts +51 -0
  41. package/test/fixtures/basic_model.pt +0 -0
  42. package/tsconfig.json +27 -0
  43. package/tsup.config.ts +14 -0
  44. package/typedoc.json +33 -0
@@ -0,0 +1,148 @@
1
+ #!/usr/bin/env node
2
+
3
+ // src/wrapture.ts
4
+ import chalk2 from "chalk";
5
+ import { Command } from "commander";
6
+ import { existsSync } from "node:fs";
7
+ import path3 from "node:path";
8
+
9
+ // src/utils/convert.ts
10
+ import chalk from "chalk";
11
+ import ora from "ora";
12
+ import { spawn } from "node:child_process";
13
+ import fs from "node:fs";
14
+ import path, { dirname } from "node:path";
15
+ import { fileURLToPath } from "node:url";
16
// ES modules have no CommonJS __filename/__dirname; derive them from import.meta.url.
var __filename = fileURLToPath(import.meta.url);
var __dirname = dirname(__filename);
// Possible locations of python/convert.py relative to this module: "../" matches the
// bundled CLI (bin/wrapture.js), "../../" matches the source layout (src/utils/convert.ts).
var CONVERT_SCRIPT_CANDIDATES = ["../python/convert.py", "../../python/convert.py"];
// Return the absolute path of the first existing candidate, or undefined when none exists.
var resolveConvertScript = () => CONVERT_SCRIPT_CANDIDATES.map((relative) => path.resolve(__dirname, relative)).find((candidate) => fs.existsSync(candidate));
/**
 * Convert a PyTorch model by running python/convert.py in a `python3` subprocess.
 *
 * @param {string} inputPath - Path to the input model file.
 * @param {string} outputDir - Directory the Python script writes into.
 * @param {{ format?: string, quantize?: boolean }} opts - Conversion options.
 * @returns {Promise<void>} Resolves when the script exits with code 0.
 * @throws {Error} If the input file or convert.py is missing, python3 cannot be
 *   started, or the script exits non-zero / is killed by a signal.
 */
var convert = async (inputPath, outputDir, opts) => {
  if (!fs.existsSync(inputPath)) {
    throw new Error(`Input model file not found: ${inputPath}`);
  }
  const scriptPath = resolveConvertScript();
  if (!scriptPath) {
    throw new Error(
      `Conversion script not found (looked for ${CONVERT_SCRIPT_CANDIDATES.join(", ")} relative to ${__dirname})`
    );
  }
  const spinner = ora("\u{1F504} Converting model to ONNX...").start();
  return new Promise((resolve, reject) => {
    const args = [
      scriptPath,
      "--input",
      inputPath,
      "--output",
      outputDir,
      "--format",
      opts.format || "onnx"
    ];
    if (opts.quantize) args.push("--quantize");
    const python = spawn("python3", args);
    // 'error' and 'close' may both fire for one failure; settle exactly once.
    let settled = false;
    const settle = (error) => {
      if (settled) return;
      settled = true;
      if (error) {
        spinner.fail("\u274C Model conversion failed.");
        reject(error);
      } else {
        spinner.succeed("\u2705 Model converted successfully.");
        resolve();
      }
    };
    python.stdout.on("data", (data) => process.stdout.write(data));
    python.stderr.on(
      "data",
      (data) => process.stderr.write(chalk.red(data.toString()))
    );
    // Without this listener a missing python3 (ENOENT) is an unhandled 'error' event.
    python.on("error", (error) => settle(new Error(`Failed to start python3: ${error.message}`)));
    python.on("close", (code, signal) => {
      if (code === 0) {
        settle();
        return;
      }
      settle(
        new Error(
          signal ? `convert.py was terminated by signal ${signal}` : `convert.py exited with code ${code}`
        )
      );
    });
  });
};
52
+
53
+ // src/utils/generate-wrapper.ts
54
+ import ora2 from "ora";
55
+ import fs2 from "node:fs";
56
+ import path2 from "node:path";
57
/**
 * Write wrapped.ts (an onnxruntime-web loader with softmax/argmax helpers) and
 * wrapped.d.ts into `outputDir`, creating the directory if needed.
 *
 * @param {string} outputDir - Destination directory.
 * @param {{ backend: string }} opts - `backend === "wasm"` makes the wrapper load
 *   model_quant.onnx; any other value loads model.onnx.
 * @returns {Promise<void>}
 * @throws Rethrows any filesystem error after failing the spinner.
 */
var generateWrapper = async (outputDir, opts) => {
  const spinner = ora2("\u{1F6E0} Generating wrapper files...").start();
  // NOTE(review): python/convert.py only writes model_quant.onnx when run with
  // --quantize, so backend "wasm" without quantization points at a missing file.
  const modelFile = opts.backend === "wasm" ? "model_quant.onnx" : "model.onnx";
  const wrapper = `import { InferenceSession, Tensor } from 'onnxruntime-web';

const softmax = (logits) => {
  // Shift by the largest logit so Math.exp cannot overflow to Infinity (which
  // would make every probability NaN); the result is mathematically unchanged.
  const max = logits.reduce((a, b) => Math.max(a, b), -Infinity);
  const exps = logits.map((x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

const argmax = (arr) => {
  return arr.reduce((maxIdx, val, idx, src) => val > src[maxIdx] ? idx : maxIdx, 0);
}

export const loadModel = async () => {
  const session = await InferenceSession.create(
    new URL('./${modelFile}', import.meta.url).href
  );
  return {
    predict: async (input) => {
      const feeds = { input: new Tensor('float32', input.data, input.dims) };
      const results = await session.run(feeds);
      const raw = results.output.data;

      if (!(raw instanceof Float32Array)) {
        throw new Error('Expected Float32Array logits but got something else');
      }

      const logits = raw;
      const probabilities = softmax(Array.from(logits));
      const predictedClass = argmax(probabilities);
      return { logits, probabilities, predictedClass };
    }
  };
};
`;
  const typings = `export interface ModelInput {
  data: Float32Array;
  dims: number[];
}

export interface ModelOutput {
  logits: Float32Array;
  probabilities: number[];
  predictedClass: number;
}

export interface LoadedModel {
  predict(input: ModelInput): Promise<ModelOutput>;
}

/**
 * Load the ONNX model and return a wrapper with \`predict()\` function.
 */
export function loadModel(): Promise<LoadedModel>;
`;
  try {
    // Direct API callers may not have created the directory yet.
    fs2.mkdirSync(outputDir, { recursive: true });
    fs2.writeFileSync(path2.join(outputDir, "wrapped.ts"), wrapper);
    fs2.writeFileSync(path2.join(outputDir, "wrapped.d.ts"), typings);
    spinner.succeed("\u2705 Wrapper files generated.");
  } catch (error) {
    spinner.fail("\u274C Failed to generate wrapper files.");
    throw error;
  }
};
121
+
122
+ // src/wrapture.ts
123
// CLI definition: validate the input path, convert the model via Python, then emit the JS/TS wrapper.
var program = new Command();
program
  .name("wrapture")
  .description("\u{1F300} One-click model exporter: from PyTorch to Web-ready JS/TS")
  // Keep in sync with "version" in package.json.
  .version("0.0.1")
  .requiredOption("-i, --input <file>", "Path to the PyTorch model (.pt)")
  .requiredOption(
    "-o, --output <dir>",
    "Output directory for the wrapped model"
  )
  .option("--quantize", "Apply quantization to reduce model size")
  .option("--format <type>", "Export format: onnx (default)", "onnx")
  .option(
    "--backend <backend>",
    "Inference backend: webgpu | wasm | cpu",
    "webgpu"
  )
  .action(async (opts) => {
    const input = path3.resolve(opts.input);
    const output = path3.resolve(opts.output);
    if (!existsSync(input)) {
      console.error(chalk2.red(`Input file not found: ${input}`));
      process.exit(1);
    }
    console.log(chalk2.cyan("\u2728 Wrapture: Exporting model..."));
    try {
      await convert(input, output, opts);
      await generateWrapper(output, opts);
      console.log(chalk2.green("\u2705 Done! Your model is wrapped and ready."));
    } catch (err) {
      console.error(chalk2.red("\u274C Failed to export model:"), err);
      process.exit(1);
    }
  });
// Commander requires parseAsync() for async action handlers so the action's
// promise is awaited rather than left floating.
await program.parseAsync(process.argv);
@@ -0,0 +1,55 @@
1
+ /* eslint-disable import/no-unused-modules */
2
+ import { MarkdownTheme, MarkdownThemeContext } from 'typedoc-plugin-markdown';
3
+
4
/**
 * TypeDoc plugin entry point. Registers the custom markdown theme, prepends a
 * "Last updated" banner to page content, and appends a shared
 * contributing/sponsor/license footer to regular and index pages.
 *
 * @param {import('typedoc-plugin-markdown').MarkdownApplication} app
 */
export function load(app) {
  app.renderer.defineTheme('custom-markdown-theme', MyMarkdownTheme);

  // The timestamp is computed when the hook runs, so it changes on every docs build.
  app.renderer.markdownHooks.on(
    'content.begin',
    () => `> Last updated ${new Date().toISOString()}`
  );

  const footerText = `***

**Contributing**

Want to contribute? Please read the [CONTRIBUTING.md](https://github.com/phun-ky/wrapture/blob/main/CONTRIBUTING.md) and [CODE_OF_CONDUCT.md](https://github.com/phun-ky/wrapture/blob/main/CODE_OF_CONDUCT.md).

**Sponsor me**

I'm an Open Source evangelist, creating stuff that does not exist yet to help get rid of secondary activities and to enhance systems already in place, be it documentation or web sites.

The sponsorship is a unique opportunity to alleviate more hours for me to maintain my projects, create new ones and contribute to the large community we're all part of :)

[Support me on GitHub Sponsors](https://github.com/sponsors/phun-ky).

***

This project, created by [Alexander Vassbotn Røyne-Helgesen](http://phun-ky.net), is licensed under the [MIT License](https://choosealicense.com/licenses/mit/).
`;

  // The same footer closes both regular pages and the index page.
  app.renderer.markdownHooks.on('page.end', () => footerText);

  app.renderer.markdownHooks.on('index.page.end', () => footerText);
}
38
+
39
/**
 * Markdown theme whose only change is to render pages through
 * {@link MyMarkdownThemeContext}.
 */
class MyMarkdownTheme extends MarkdownTheme {
  /**
   * @param {import('typedoc-plugin-markdown').MarkdownPageEvent} pageEvent
   */
  getRenderContext(pageEvent) {
    const { options } = this.application;

    return new MyMarkdownThemeContext(this, pageEvent, options);
  }
}
47
+
48
/**
 * Render context that inherits every partial from the base context but
 * replaces the `header` partial with one that renders nothing.
 */
class MyMarkdownThemeContext extends MarkdownThemeContext {
  partials = {
    ...this.partials,
    header: () => ''
  };
}
@@ -0,0 +1,16 @@
1
+ import { defineConfig } from 'eslint/config';
2
+ import customConfig from 'eslint-config-phun-ky';
3
+
4
/** Overrides for Markdown files: turn off rules that misfire on prose. */
const markdownOverrides = {
  files: ['**/*.md'],
  rules: {
    'no-irregular-whitespace': 'off',
    '@stylistic/indent': 'off'
  }
};

// eslint-disable-next-line import/no-unused-modules
export default defineConfig([{ extends: [customConfig] }, markdownOverrides]);
package/package.json ADDED
@@ -0,0 +1,74 @@
1
+ {
2
+ "name": "wrapture",
3
+ "version": "0.0.1",
4
+ "description": "One-click model exporter: from PyTorch to Web-ready JS/TS",
5
+ "homepage": "https://github.com/phun-ky/wrapture#readme",
6
+ "bugs": {
7
+ "url": "https://github.com/phun-ky/wrapture/issues"
8
+ },
9
+ "repository": {
10
+ "type": "git",
11
+ "url": "git+https://github.com/phun-ky/wrapture.git"
12
+ },
13
+ "funding": "https://github.com/phun-ky/wrapture?sponsor=1",
14
+ "license": "MIT",
15
+ "author": "Alexander Vassbotn Røyne-Helgesen <alexander@phun-ky.net>",
16
+ "type": "module",
17
+ "bin": {
18
+ "wrapture": "./bin/wrapture.js"
19
+ },
20
+ "scripts": {
21
+ "commit": "npx git-cz",
22
+ "docs:gen": "node ./node_modules/.bin/typedoc",
23
+ "release": "release-it",
24
+ "start": "node ./bin/wrapture.js",
25
+ "build": "tsup",
26
+ "style:code": "npx putout src",
27
+ "style:format": "./node_modules/.bin/eslint -c ./eslint.config.mjs src --fix && ./node_modules/.bin/prettier --write ./eslint.config.mjs src",
28
+ "style:lint": "./node_modules/.bin/eslint -c ./eslint.config.mjs src && ./node_modules/.bin/prettier --check src",
29
+ "test": "NODE_ENV=test glob -c \"node --import tsx --test --no-warnings\" \"./src/**/__tests__/**/*.[jt]s\"",
30
+ "pretest:ci": "rm -rf coverage && mkdir -p coverage",
31
+ "test:ci": "NODE_ENV=test glob -c \"node --import tsx --test --no-warnings --experimental-test-coverage --test-reporter=cobertura --test-reporter-destination=coverage/cobertura-coverage.xml --test-reporter=spec --test-reporter-destination=stdout\" \"./src/**/__tests__/**/*.[jt]s\""
32
+ },
33
+ "engines": {
34
+ "node": ">=22.0.0",
35
+ "npm": ">=10.8.2"
36
+ },
37
+ "publishConfig": {
38
+ "access": "public"
39
+ },
40
+ "devDependencies": {
41
+ "@release-it/conventional-changelog": "^10.0.0",
42
+ "@rollup/plugin-node-resolve": "^16.0.1",
43
+ "@rollup/plugin-terser": "^0.4.4",
44
+ "@stylistic/eslint-plugin": "^4.2.0",
45
+ "@types/node": "^22.15.3",
46
+ "cobertura": "^1.0.1",
47
+ "eslint": "^9.20.0",
48
+ "eslint-config-phun-ky": "^1.0.0",
49
+ "git-cz": "^4.9.0",
50
+ "onnxruntime-web": "^1.22.0",
51
+ "prettier": "^3.2.5",
52
+ "putout": "^40.1.9",
53
+ "release-it": "^19.0.1",
54
+ "remark-github": "^12.0.0",
55
+ "remark-toc": "^9.0.0",
56
+ "tslib": "^2.3.1",
57
+ "tsup": "^8.4.0",
58
+ "tsx": "^4.7.1",
59
+ "typedoc": "^0.28.3",
60
+ "typedoc-plugin-frontmatter": "^1.0.0",
61
+ "typedoc-plugin-markdown": "^4.2.3",
62
+ "typedoc-plugin-mdn-links": "^5.0.1",
63
+ "typedoc-plugin-no-inherit": "^1.4.0",
64
+ "typedoc-plugin-remark": "^2.0.0",
65
+ "typedoc-plugin-rename-defaults": "^0.7.1",
66
+ "typescript": "^5.0.0",
67
+ "unified-prettier": "^2.0.1"
68
+ },
69
+ "dependencies": {
70
+ "chalk": "^5.4.1",
71
+ "commander": "^13.1.0",
72
+ "ora": "^8.2.0"
73
+ }
74
+ }
@@ -0,0 +1 @@
1
+ # wrapture API documentation
@@ -0,0 +1,72 @@
1
+ import argparse
2
+ import torch
3
+ import torch.nn as nn
4
+ import os
5
+ from onnxsim import simplify
6
+ import onnx
7
+ from onnxruntime.quantization import quantize_dynamic, QuantType
8
+
9
def convert_to_onnx(input_path, output_path, quantize=False):
    """Export a ``BasicCNN`` checkpoint to ONNX, simplify it, and optionally quantize it.

    Args:
        input_path: Path to a ``.pt`` file containing a ``BasicCNN`` ``state_dict``.
        output_path: Existing directory that receives ``model.onnx`` (and
            ``model_quant.onnx`` when ``quantize`` is true).
        quantize: When true, also write a dynamically INT8-quantized copy.

    Raises:
        RuntimeError: If onnxsim cannot validate the simplified graph.
        Exception: Whatever ``torch.load``, ``load_state_dict`` or
            ``onnx.checker.check_model`` raise for an invalid checkpoint or graph.
    """
    class BasicCNN(nn.Module):  # Must match the architecture in python/scripts/basic_model.py
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten(),
                nn.Linear(16, 10)
            )

        def forward(self, x):
            return self.net(x)

    model = BasicCNN()

    # The input file is user-supplied. weights_only=True refuses arbitrary pickled
    # objects (which could execute code on load); a plain state_dict only needs tensors.
    state_dict = torch.load(input_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()

    # Fixed 224x224 RGB example input used only for tracing; the batch axis is dynamic.
    dummy_input = torch.randn(1, 3, 224, 224)
    output_model_path = os.path.join(output_path, 'model.onnx')
    torch.onnx.export(
        model,
        dummy_input,
        output_model_path,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
        opset_version=11
    )

    onnx_model = onnx.load(output_model_path)
    simplified, check = simplify(onnx_model)
    # Explicit raise instead of `assert`, which is stripped when Python runs with -O.
    if not check:
        raise RuntimeError("Simplified model could not be validated")
    onnx.checker.check_model(simplified)  # throw if broken
    onnx.save(simplified, output_model_path)
    print("✅ ONNX model simplified.")

    print(f"✅ Exported ONNX model to: {output_model_path}")

    if quantize:
        quant_model_path = os.path.join(output_path, 'model_quant.onnx')
        quantize_dynamic(output_model_path, quant_model_path, weight_type=QuantType.QInt8)
        print(f"✅ Quantized ONNX model saved to: {quant_model_path}")
53
+
54
def main():
    """Command-line entry point: parse arguments and run the requested exporter.

    Raises:
        ValueError: If ``--format`` is anything other than ``onnx`` (case-insensitive).
    """
    parser = argparse.ArgumentParser(description='Wrapture: Convert PyTorch model to ONNX.')
    parser.add_argument('--input', required=True, help='Path to the PyTorch model (.pt)')
    parser.add_argument('--output', required=True, help='Directory to save the ONNX model')
    parser.add_argument('--format', default='onnx', help='Export format (currently only supports ONNX)')
    parser.add_argument('--quantize', action='store_true', help='Apply quantization')

    args = parser.parse_args()

    # Validate before touching the filesystem so an unsupported format does not
    # leave an empty output directory behind.
    if args.format.lower() != 'onnx':
        raise ValueError(f"Unsupported format: {args.format}")

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(args.output, exist_ok=True)

    convert_to_onnx(args.input, args.output, args.quantize)


if __name__ == '__main__':
    main()
@@ -0,0 +1,20 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class BasicCNN(nn.Module):
    """Tiny CNN used as a test fixture.

    The layer layout must stay identical to the ``BasicCNN`` defined inside
    ``python/convert.py``; otherwise ``load_state_dict`` there will reject the fixture.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(16, 10)
        )

    def forward(self, x):
        return self.net(x)


def main():
    """Save a freshly initialized ``BasicCNN`` state_dict to ``test/fixtures/basic_model.pt``."""
    import os

    # Resolve relative to this file (python/scripts/) instead of the current working
    # directory, and create the fixtures directory if it does not exist yet.
    package_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    fixtures_dir = os.path.join(package_root, 'test', 'fixtures')
    os.makedirs(fixtures_dir, exist_ok=True)
    target = os.path.join(fixtures_dir, 'basic_model.pt')

    model = BasicCNN()
    torch.save(model.state_dict(), target)
    print(f"✅ Saved model weights to {target}")


if __name__ == '__main__':
    main()
@@ -0,0 +1,11 @@
1
+ /* eslint-disable no-console */
2
+ import { loadModel } from '../wrapped.js';
3
+
4
// Smoke test: run one all-zero 1x3x224x224 image through the generated wrapper.
const dims = [1, 3, 224, 224];
const input = {
  data: new Float32Array(dims.reduce((total, size) => total * size, 1)),
  dims
};

const model = await loadModel();
const out = await model.predict(input);

console.log(out);
@@ -0,0 +1,47 @@
1
+ // src/utils/__tests__/convert.unit.ts
2
+ import test from 'node:test';
3
+ import assert from 'node:assert/strict';
4
+ import fs from 'node:fs';
5
+ import { tmpdir } from 'node:os';
6
+ import { randomUUID } from 'node:crypto';
7
+ import path, { dirname } from 'node:path';
8
+ import { fileURLToPath } from 'node:url';
9
+
10
+ import { convert } from '../convert';
11
+
12
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);

// Resolve the fixture relative to this file (src/utils/__tests__) rather than
// process.cwd(), matching how the convert.py path is resolved, so the suite
// works no matter which directory the test runner is launched from.
const fixtureModelPath = path.resolve(
  __dirname,
  '../../../test/fixtures/basic_model.pt'
);
const outputDir = path.join(tmpdir(), `wrapture-test-${randomUUID()}`);
17
+
18
test('convert() should reject if file does not exist', async () => {
  // convert() checks the input path before doing anything else, so the output
  // directory below is never created or used.
  const missingModel = 'some/missing/model.pt';
  const unusedOutputDir = path.join(tmpdir(), `wrapture-test-${randomUUID()}`);

  await assert.rejects(
    convert(missingModel, unusedOutputDir, {}),
    /Input model file not found/i
  );
});
27
+
28
test('convert() resolves if subprocess exits cleanly', async (t) => {
  // Integration test: this really runs python/convert.py, so python3 with torch,
  // onnx, onnxsim and onnxruntime must be installed for it to pass.
  const scriptPath = path.resolve(__dirname, '../../../python', 'convert.py');

  if (!fs.existsSync(scriptPath)) {
    t.skip('Python script does not exist, skipping subprocess test');

    return;
  }

  if (!fs.existsSync(fixtureModelPath)) {
    t.skip('Fixture model does not exist, skipping subprocess test');

    return;
  }

  fs.mkdirSync(outputDir, { recursive: true });
  // Remove the temporary output directory whether the conversion passes or fails.
  t.after(() => fs.rmSync(outputDir, { recursive: true, force: true }));

  await assert.doesNotReject(
    convert(fixtureModelPath, outputDir, { format: 'onnx', quantize: false })
  );

  assert.ok(
    fs.existsSync(path.join(outputDir, 'model.onnx')),
    'model.onnx should be written to the output directory'
  );
});
@@ -0,0 +1,29 @@
1
+ import test from 'node:test';
2
+ import assert from 'node:assert/strict';
3
+ import fs from 'node:fs';
4
+ import path from 'node:path';
5
+ import { generateWrapper } from '../generate-wrapper.js';
6
+
7
const testOutputDir = './.test-output';

test('generateWrapper creates wrapper files', async (t) => {
  // Start from a clean directory and remove it afterwards so runs leave no files behind.
  fs.rmSync(testOutputDir, { recursive: true, force: true });
  fs.mkdirSync(testOutputDir, { recursive: true });
  t.after(() => fs.rmSync(testOutputDir, { recursive: true, force: true }));

  await generateWrapper(testOutputDir, { backend: 'webgpu' });

  const files = fs.readdirSync(testOutputDir);

  assert.ok(files.includes('wrapped.ts'), 'wrapped.ts should be generated');
  assert.ok(files.includes('wrapped.d.ts'), 'wrapped.d.ts should be generated');

  const wrapperContent = fs.readFileSync(
    path.join(testOutputDir, 'wrapped.ts'),
    'utf-8'
  );

  assert.match(
    wrapperContent,
    /loadModel/,
    'Wrapper should contain loadModel function'
  );
  // Only the 'wasm' backend selects the quantized model file.
  assert.match(wrapperContent, /model\.onnx/, 'Wrapper should load model.onnx');
  assert.doesNotMatch(
    wrapperContent,
    /model_quant\.onnx/,
    'Non-wasm backends should not load the quantized model'
  );

  const typingsContent = fs.readFileSync(
    path.join(testOutputDir, 'wrapped.d.ts'),
    'utf-8'
  );

  assert.match(
    typingsContent,
    /export function loadModel\(\): Promise<LoadedModel>/,
    'Typings should declare loadModel'
  );
});
@@ -0,0 +1,98 @@
1
+ /* eslint-disable import/no-unused-modules */
2
+ /* global process */
3
+ import chalk from 'chalk';
4
+ import ora from 'ora';
5
+
6
+ import { spawn } from 'node:child_process';
7
+ import fs from 'node:fs';
8
+ import path, { dirname } from 'node:path';
9
+ import { fileURLToPath } from 'node:url';
10
+
11
// ES modules lack the CommonJS `__filename` / `__dirname` globals, so rebuild
// them from this module's URL.
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
13
+
14
/**
 * Options for the {@link convert} function.
 */
export interface ConvertOptionsInterface {
  /**
   * The output format for the converted model (e.g., 'onnx').
   * Defaults to 'onnx' if not provided.
   */
  format?: string;

  /**
   * Whether to apply quantization to the model.
   */
  quantize?: boolean;
}

/**
 * Possible locations of `python/convert.py` relative to this module. `../`
 * matches the bundled CLI (`bin/wrapture.js`); `../../` matches the source
 * layout (`src/utils/convert.ts`), which is one directory deeper.
 */
const CONVERT_SCRIPT_CANDIDATES = [
  '../python/convert.py',
  '../../python/convert.py'
];

/** Returns the absolute path of the first existing candidate, or `undefined`. */
const resolveConvertScript = (): string | undefined =>
  CONVERT_SCRIPT_CANDIDATES.map((relative) =>
    path.resolve(__dirname, relative)
  ).find((candidate) => fs.existsSync(candidate));

/**
 * Converts a machine learning model to ONNX or another supported format
 * by delegating to a Python script (`convert.py`).
 *
 * This function spawns a subprocess using `python3`, forwards its stdout, and
 * echoes its stderr in red. It supports format selection and quantization.
 *
 * @function convert
 * @param {string} inputPath - Path to the input model file
 * @param {string} outputDir - Directory where the converted model should be saved
 * @param {ConvertOptionsInterface} opts - Conversion options
 * @returns A promise that resolves on success or rejects if conversion fails
 *
 * @throws {Error} If the input file or `convert.py` cannot be found, if
 * `python3` cannot be started, or if the Python process exits with a non-zero
 * code or is killed by a signal
 *
 * @example
 * ```ts
 * await convert('models/model.pt', 'out/', { format: 'onnx', quantize: true });
 * ```
 *
 * @see https://nodejs.org/api/child_process.html#child_processspawncommand-args-options
 * @see https://github.com/sindresorhus/ora - Ora spinner for CLI feedback
 * @see https://github.com/chalk/chalk - Chalk for terminal coloring
 */
export const convert = async (
  inputPath: string,
  outputDir: string,
  opts: ConvertOptionsInterface
): Promise<void> => {
  if (!fs.existsSync(inputPath)) {
    throw new Error(`Input model file not found: ${inputPath}`);
  }

  const scriptPath = resolveConvertScript();

  if (!scriptPath) {
    throw new Error(
      `Conversion script not found (looked for ${CONVERT_SCRIPT_CANDIDATES.join(', ')} relative to ${__dirname})`
    );
  }

  const spinner = ora('🔄 Converting model to ONNX...').start();

  return new Promise((resolve, reject) => {
    const args = [
      scriptPath,
      '--input',
      inputPath,
      '--output',
      outputDir,
      '--format',
      opts.format || 'onnx'
    ];

    if (opts.quantize) args.push('--quantize');

    const python = spawn('python3', args);

    // 'error' and 'close' can both fire for a single failure; settle only once.
    let settled = false;
    const settle = (error?: Error): void => {
      if (settled) return;

      settled = true;

      if (error) {
        spinner.fail('❌ Model conversion failed.');

        reject(error);
      } else {
        spinner.succeed('✅ Model converted successfully.');

        resolve();
      }
    };

    python.stdout.on('data', (data) => process.stdout.write(data));
    python.stderr.on('data', (data) =>
      process.stderr.write(chalk.red(data.toString()))
    );

    // Without this listener, a missing python3 (ENOENT) is an unhandled
    // 'error' event that crashes the process with the spinner still running.
    python.on('error', (error) =>
      settle(new Error(`Failed to start python3: ${error.message}`))
    );

    python.on('close', (code, signal) => {
      if (code === 0) {
        settle();

        return;
      }

      settle(
        new Error(
          signal
            ? `convert.py was terminated by signal ${signal}`
            : `convert.py exited with code ${code}`
        )
      );
    });
  });
};
@@ -0,0 +1,107 @@
1
+ import ora from 'ora';
2
+
3
+ import fs from 'node:fs';
4
+ import path from 'node:path';
5
+
6
/**
 * Options for generating ONNX wrapper files.
 */
export interface GenerateWrapperOptionsInterface {
  /**
   * The inference backend requested on the CLI (`webgpu`, `wasm` or `cpu`).
   * It only selects which model file the generated wrapper loads: `'wasm'`
   * loads `model_quant.onnx`, every other value loads `model.onnx`. It is not
   * passed to `InferenceSession.create`.
   */
  backend: 'wasm' | 'webgpu' | 'cpu' | string;
}

/**
 * Generates a TypeScript wrapper and type definition file (`wrapped.ts` and `wrapped.d.ts`)
 * for use with `onnxruntime-web`, including utility functions like `softmax`, `argmax`,
 * and a typed `predict()` function.
 *
 * The generated code loads the ONNX model file selected by the provided backend.
 * `outputDir` is created if it does not exist.
 *
 * @param {string} outputDir - The directory where the wrapper files will be written.
 * @param {GenerateWrapperOptionsInterface} opts - Wrapper generation options, including backend type.
 * @returns A Promise that resolves when the wrapper files are successfully written.
 *
 * @throws Will throw an error if creating the directory or writing the files fails.
 *
 * @example
 * ```ts
 * await generateWrapper('./dist', { backend: 'wasm' });
 * // Creates `wrapped.ts` and `wrapped.d.ts` in ./dist
 * ```
 *
 * @see https://www.npmjs.com/package/onnxruntime-web
 */
export const generateWrapper = async (
  outputDir: string,
  opts: GenerateWrapperOptionsInterface
): Promise<void> => {
  const spinner = ora('🛠 Generating wrapper files...').start();
  // NOTE(review): python/convert.py only writes model_quant.onnx when run with
  // --quantize, so backend 'wasm' without quantization points at a missing file.
  const modelFile =
    opts.backend === 'wasm' ? 'model_quant.onnx' : 'model.onnx';
  const wrapper = `import { InferenceSession, Tensor } from 'onnxruntime-web';

const softmax = (logits) => {
  // Shift by the largest logit so Math.exp cannot overflow to Infinity (which
  // would make every probability NaN); the result is mathematically unchanged.
  const max = logits.reduce((a, b) => Math.max(a, b), -Infinity);
  const exps = logits.map((x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

const argmax = (arr) => {
  return arr.reduce((maxIdx, val, idx, src) => val > src[maxIdx] ? idx : maxIdx, 0);
}

export const loadModel = async () => {
  const session = await InferenceSession.create(
    new URL('./${modelFile}', import.meta.url).href
  );
  return {
    predict: async (input) => {
      const feeds = { input: new Tensor('float32', input.data, input.dims) };
      const results = await session.run(feeds);
      const raw = results.output.data;

      if (!(raw instanceof Float32Array)) {
        throw new Error('Expected Float32Array logits but got something else');
      }

      const logits = raw;
      const probabilities = softmax(Array.from(logits));
      const predictedClass = argmax(probabilities);
      return { logits, probabilities, predictedClass };
    }
  };
};
`;
  const typings = `export interface ModelInput {
  data: Float32Array;
  dims: number[];
}

export interface ModelOutput {
  logits: Float32Array;
  probabilities: number[];
  predictedClass: number;
}

export interface LoadedModel {
  predict(input: ModelInput): Promise<ModelOutput>;
}

/**
 * Load the ONNX model and return a wrapper with \`predict()\` function.
 */
export function loadModel(): Promise<LoadedModel>;
`;

  try {
    // Direct API callers may not have created the directory yet.
    fs.mkdirSync(outputDir, { recursive: true });
    fs.writeFileSync(path.join(outputDir, 'wrapped.ts'), wrapper);
    fs.writeFileSync(path.join(outputDir, 'wrapped.d.ts'), typings);
    spinner.succeed('✅ Wrapper files generated.');
  } catch (error) {
    spinner.fail('❌ Failed to generate wrapper files.');
    throw error;
  }
};