tjbot-ce 3.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/README.md +382 -0
- package/dist/camera/camera.d.ts +62 -0
- package/dist/camera/camera.d.ts.map +1 -0
- package/dist/camera/camera.js +155 -0
- package/dist/camera/camera.js.map +1 -0
- package/dist/camera/index.d.ts +18 -0
- package/dist/camera/index.d.ts.map +1 -0
- package/dist/camera/index.js +18 -0
- package/dist/camera/index.js.map +1 -0
- package/dist/config/config-types.d.ts +75 -0
- package/dist/config/config-types.d.ts.map +1 -0
- package/dist/config/config-types.generated.d.ts +495 -0
- package/dist/config/config-types.generated.d.ts.map +1 -0
- package/dist/config/config-types.generated.js +2 -0
- package/dist/config/config-types.generated.js.map +1 -0
- package/dist/config/config-types.js +175 -0
- package/dist/config/config-types.js.map +1 -0
- package/dist/config/index.d.ts +20 -0
- package/dist/config/index.d.ts.map +1 -0
- package/dist/config/index.js +19 -0
- package/dist/config/index.js.map +1 -0
- package/dist/config/tjbot-config.d.ts +98 -0
- package/dist/config/tjbot-config.d.ts.map +1 -0
- package/dist/config/tjbot-config.js +309 -0
- package/dist/config/tjbot-config.js.map +1 -0
- package/dist/config/vendor/colors.yaml +61 -0
- package/dist/config/vendor/model-registry.yaml +275 -0
- package/dist/config/vendor/tjbot-config.schema.yaml +792 -0
- package/dist/config/vendor/tjbot.default.toml +452 -0
- package/dist/led/index.d.ts +20 -0
- package/dist/led/index.d.ts.map +1 -0
- package/dist/led/index.js +20 -0
- package/dist/led/index.js.map +1 -0
- package/dist/led/led-common-anode.d.ts +38 -0
- package/dist/led/led-common-anode.d.ts.map +1 -0
- package/dist/led/led-common-anode.js +79 -0
- package/dist/led/led-common-anode.js.map +1 -0
- package/dist/led/led-neopixel-spi.d.ts +60 -0
- package/dist/led/led-neopixel-spi.d.ts.map +1 -0
- package/dist/led/led-neopixel-spi.js +216 -0
- package/dist/led/led-neopixel-spi.js.map +1 -0
- package/dist/led/led-neopixel-ws281x.js +186 -0
- package/dist/led/led-neopixel.d.ts +57 -0
- package/dist/led/led-neopixel.d.ts.map +1 -0
- package/dist/led/led-neopixel.js +235 -0
- package/dist/led/led-neopixel.js.map +1 -0
- package/dist/microphone/index.d.ts +18 -0
- package/dist/microphone/index.d.ts.map +1 -0
- package/dist/microphone/index.js +18 -0
- package/dist/microphone/index.js.map +1 -0
- package/dist/microphone/microphone.d.ts +65 -0
- package/dist/microphone/microphone.d.ts.map +1 -0
- package/dist/microphone/microphone.js +179 -0
- package/dist/microphone/microphone.js.map +1 -0
- package/dist/rpi-drivers/index.d.ts +22 -0
- package/dist/rpi-drivers/index.d.ts.map +1 -0
- package/dist/rpi-drivers/index.js +22 -0
- package/dist/rpi-drivers/index.js.map +1 -0
- package/dist/rpi-drivers/rpi-detect.d.ts +24 -0
- package/dist/rpi-drivers/rpi-detect.d.ts.map +1 -0
- package/dist/rpi-drivers/rpi-detect.js +49 -0
- package/dist/rpi-drivers/rpi-detect.js.map +1 -0
- package/dist/rpi-drivers/rpi-driver.d.ts +116 -0
- package/dist/rpi-drivers/rpi-driver.d.ts.map +1 -0
- package/dist/rpi-drivers/rpi-driver.js +261 -0
- package/dist/rpi-drivers/rpi-driver.js.map +1 -0
- package/dist/rpi-drivers/rpi3-driver.d.ts +47 -0
- package/dist/rpi-drivers/rpi3-driver.d.ts.map +1 -0
- package/dist/rpi-drivers/rpi3-driver.js +145 -0
- package/dist/rpi-drivers/rpi3-driver.js.map +1 -0
- package/dist/rpi-drivers/rpi4-driver.d.ts +35 -0
- package/dist/rpi-drivers/rpi4-driver.d.ts.map +1 -0
- package/dist/rpi-drivers/rpi4-driver.js +101 -0
- package/dist/rpi-drivers/rpi4-driver.js.map +1 -0
- package/dist/rpi-drivers/rpi5-driver.d.ts +33 -0
- package/dist/rpi-drivers/rpi5-driver.d.ts.map +1 -0
- package/dist/rpi-drivers/rpi5-driver.js +78 -0
- package/dist/rpi-drivers/rpi5-driver.js.map +1 -0
- package/dist/servo/index.d.ts +19 -0
- package/dist/servo/index.d.ts.map +1 -0
- package/dist/servo/index.js +19 -0
- package/dist/servo/index.js.map +1 -0
- package/dist/servo/servo-constants.d.ts +33 -0
- package/dist/servo/servo-constants.d.ts.map +1 -0
- package/dist/servo/servo-constants.js +34 -0
- package/dist/servo/servo-constants.js.map +1 -0
- package/dist/servo/servo-lgpio.d.ts +82 -0
- package/dist/servo/servo-lgpio.d.ts.map +1 -0
- package/dist/servo/servo-lgpio.js +178 -0
- package/dist/servo/servo-lgpio.js.map +1 -0
- package/dist/speaker/audio-player.d.ts +30 -0
- package/dist/speaker/audio-player.d.ts.map +1 -0
- package/dist/speaker/audio-player.js +68 -0
- package/dist/speaker/audio-player.js.map +1 -0
- package/dist/speaker/index.d.ts +18 -0
- package/dist/speaker/index.d.ts.map +1 -0
- package/dist/speaker/index.js +18 -0
- package/dist/speaker/index.js.map +1 -0
- package/dist/speaker/speaker.d.ts +53 -0
- package/dist/speaker/speaker.d.ts.map +1 -0
- package/dist/speaker/speaker.js +125 -0
- package/dist/speaker/speaker.js.map +1 -0
- package/dist/stt/backends/azure-stt.d.ts +32 -0
- package/dist/stt/backends/azure-stt.d.ts.map +1 -0
- package/dist/stt/backends/azure-stt.js +227 -0
- package/dist/stt/backends/azure-stt.js.map +1 -0
- package/dist/stt/backends/google-cloud-stt.d.ts +31 -0
- package/dist/stt/backends/google-cloud-stt.d.ts.map +1 -0
- package/dist/stt/backends/google-cloud-stt.js +371 -0
- package/dist/stt/backends/google-cloud-stt.js.map +1 -0
- package/dist/stt/backends/ibm-watson-stt.d.ts +32 -0
- package/dist/stt/backends/ibm-watson-stt.d.ts.map +1 -0
- package/dist/stt/backends/ibm-watson-stt.js +190 -0
- package/dist/stt/backends/ibm-watson-stt.js.map +1 -0
- package/dist/stt/backends/sherpa-onnx-stt.d.ts +117 -0
- package/dist/stt/backends/sherpa-onnx-stt.d.ts.map +1 -0
- package/dist/stt/backends/sherpa-onnx-stt.js +694 -0
- package/dist/stt/backends/sherpa-onnx-stt.js.map +1 -0
- package/dist/stt/index.d.ts +20 -0
- package/dist/stt/index.d.ts.map +1 -0
- package/dist/stt/index.js +21 -0
- package/dist/stt/index.js.map +1 -0
- package/dist/stt/stt-engine.d.ts +68 -0
- package/dist/stt/stt-engine.d.ts.map +1 -0
- package/dist/stt/stt-engine.js +99 -0
- package/dist/stt/stt-engine.js.map +1 -0
- package/dist/stt/stt-utils.d.ts +36 -0
- package/dist/stt/stt-utils.d.ts.map +1 -0
- package/dist/stt/stt-utils.js +112 -0
- package/dist/stt/stt-utils.js.map +1 -0
- package/dist/stt/stt.d.ts +52 -0
- package/dist/stt/stt.d.ts.map +1 -0
- package/dist/stt/stt.js +100 -0
- package/dist/stt/stt.js.map +1 -0
- package/dist/tjbot.d.ts +317 -0
- package/dist/tjbot.d.ts.map +1 -0
- package/dist/tjbot.js +736 -0
- package/dist/tjbot.js.map +1 -0
- package/dist/tts/backends/azure-tts.d.ts +30 -0
- package/dist/tts/backends/azure-tts.d.ts.map +1 -0
- package/dist/tts/backends/azure-tts.js +92 -0
- package/dist/tts/backends/azure-tts.js.map +1 -0
- package/dist/tts/backends/google-cloud-tts.d.ts +38 -0
- package/dist/tts/backends/google-cloud-tts.d.ts.map +1 -0
- package/dist/tts/backends/google-cloud-tts.js +116 -0
- package/dist/tts/backends/google-cloud-tts.js.map +1 -0
- package/dist/tts/backends/ibm-watson-tts.d.ts +42 -0
- package/dist/tts/backends/ibm-watson-tts.d.ts.map +1 -0
- package/dist/tts/backends/ibm-watson-tts.js +99 -0
- package/dist/tts/backends/ibm-watson-tts.js.map +1 -0
- package/dist/tts/backends/sherpa-onnx-tts.d.ts +80 -0
- package/dist/tts/backends/sherpa-onnx-tts.d.ts.map +1 -0
- package/dist/tts/backends/sherpa-onnx-tts.js +237 -0
- package/dist/tts/backends/sherpa-onnx-tts.js.map +1 -0
- package/dist/tts/index.d.ts +19 -0
- package/dist/tts/index.d.ts.map +1 -0
- package/dist/tts/index.js +20 -0
- package/dist/tts/index.js.map +1 -0
- package/dist/tts/tts-engine.d.ts +67 -0
- package/dist/tts/tts-engine.d.ts.map +1 -0
- package/dist/tts/tts-engine.js +109 -0
- package/dist/tts/tts-engine.js.map +1 -0
- package/dist/tts/tts.d.ts +47 -0
- package/dist/tts/tts.d.ts.map +1 -0
- package/dist/tts/tts.js +101 -0
- package/dist/tts/tts.js.map +1 -0
- package/dist/utils/colors.d.ts +39 -0
- package/dist/utils/colors.d.ts.map +1 -0
- package/dist/utils/colors.js +155 -0
- package/dist/utils/colors.js.map +1 -0
- package/dist/utils/constants.d.ts +41 -0
- package/dist/utils/constants.d.ts.map +1 -0
- package/dist/utils/constants.js +43 -0
- package/dist/utils/constants.js.map +1 -0
- package/dist/utils/credentials.d.ts +43 -0
- package/dist/utils/credentials.d.ts.map +1 -0
- package/dist/utils/credentials.js +121 -0
- package/dist/utils/credentials.js.map +1 -0
- package/dist/utils/errors.d.ts +26 -0
- package/dist/utils/errors.d.ts.map +1 -0
- package/dist/utils/errors.js +32 -0
- package/dist/utils/errors.js.map +1 -0
- package/dist/utils/index.d.ts +25 -0
- package/dist/utils/index.d.ts.map +1 -0
- package/dist/utils/index.js +23 -0
- package/dist/utils/index.js.map +1 -0
- package/dist/utils/logging.d.ts +44 -0
- package/dist/utils/logging.d.ts.map +1 -0
- package/dist/utils/logging.js +113 -0
- package/dist/utils/logging.js.map +1 -0
- package/dist/utils/model-registry.d.ts +142 -0
- package/dist/utils/model-registry.d.ts.map +1 -0
- package/dist/utils/model-registry.js +391 -0
- package/dist/utils/model-registry.js.map +1 -0
- package/dist/utils/utils.d.ts +33 -0
- package/dist/utils/utils.d.ts.map +1 -0
- package/dist/utils/utils.js +50 -0
- package/dist/utils/utils.js.map +1 -0
- package/dist/vision/backends/azure-vision.d.ts +33 -0
- package/dist/vision/backends/azure-vision.d.ts.map +1 -0
- package/dist/vision/backends/azure-vision.js +151 -0
- package/dist/vision/backends/azure-vision.js.map +1 -0
- package/dist/vision/backends/google-cloud-vision.d.ts +32 -0
- package/dist/vision/backends/google-cloud-vision.d.ts.map +1 -0
- package/dist/vision/backends/google-cloud-vision.js +193 -0
- package/dist/vision/backends/google-cloud-vision.js.map +1 -0
- package/dist/vision/backends/onnx.d.ts +116 -0
- package/dist/vision/backends/onnx.d.ts.map +1 -0
- package/dist/vision/backends/onnx.js +781 -0
- package/dist/vision/backends/onnx.js.map +1 -0
- package/dist/vision/index.d.ts +19 -0
- package/dist/vision/index.d.ts.map +1 -0
- package/dist/vision/index.js +20 -0
- package/dist/vision/index.js.map +1 -0
- package/dist/vision/vision-engine.d.ts +131 -0
- package/dist/vision/vision-engine.d.ts.map +1 -0
- package/dist/vision/vision-engine.js +97 -0
- package/dist/vision/vision-engine.js.map +1 -0
- package/dist/vision/vision.d.ts +48 -0
- package/dist/vision/vision.d.ts.map +1 -0
- package/dist/vision/vision.js +83 -0
- package/dist/vision/vision.js.map +1 -0
- package/package.json +124 -0
|
@@ -0,0 +1,781 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright 2026-present TJBot Contributors. All Rights Reserved.
|
|
3
|
+
*
|
|
4
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
* you may not use this file except in compliance with the License.
|
|
6
|
+
* You may obtain a copy of the License at
|
|
7
|
+
*
|
|
8
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
*
|
|
10
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
* See the License for the specific language governing permissions and
|
|
14
|
+
* limitations under the License.
|
|
15
|
+
*/
|
|
16
|
+
import fs from 'fs';
|
|
17
|
+
import * as ort from 'onnxruntime-node';
|
|
18
|
+
import path from 'path';
|
|
19
|
+
import sharp from 'sharp';
|
|
20
|
+
import { ModelRegistry, TJBotError } from '../../utils/index.js';
|
|
21
|
+
import { getLogger } from '../../utils/logging.js';
|
|
22
|
+
import { VisionEngine, } from '../vision-engine.js';
|
|
23
|
+
const logger = getLogger(import.meta.url);
|
|
24
|
+
export class ONNXVisionEngine extends VisionEngine {
|
|
25
|
+
manager = ModelRegistry.getInstance();
|
|
26
|
+
models = new Map();
|
|
27
|
+
/**
|
|
28
|
+
* Initialize the ONNX vision engine.
|
|
29
|
+
*/
|
|
30
|
+
async initialize() {
|
|
31
|
+
const config = this.config;
|
|
32
|
+
if (!config.objectDetectionModel) {
|
|
33
|
+
throw new TJBotError('ONNX vision engine config is missing required parameter: objectDetectionModel');
|
|
34
|
+
}
|
|
35
|
+
if (!config.imageClassificationModel) {
|
|
36
|
+
throw new TJBotError('ONNX vision engine config is missing required parameter: imageClassificationModel');
|
|
37
|
+
}
|
|
38
|
+
if (!config.faceDetectionModel) {
|
|
39
|
+
throw new TJBotError('ONNX vision engine config is missing required parameter: faceDetectionModel');
|
|
40
|
+
}
|
|
41
|
+
// Eagerly load all models
|
|
42
|
+
await this.loadModel(config.objectDetectionModel);
|
|
43
|
+
await this.loadModel(config.imageClassificationModel);
|
|
44
|
+
await this.loadModel(config.faceDetectionModel);
|
|
45
|
+
logger.info('ONNX vision engine initialized');
|
|
46
|
+
logger.debug(`Initialized ONNXVisionEngine with config:
|
|
47
|
+
objectDetectionModel: ${config.objectDetectionModel},
|
|
48
|
+
objectDetectionConfidence: ${config.objectDetectionConfidence},
|
|
49
|
+
imageClassificationModel: ${config.imageClassificationModel},
|
|
50
|
+
imageClassificationConfidence: ${config.imageClassificationConfidence},
|
|
51
|
+
faceDetectionModel: ${config.faceDetectionModel},
|
|
52
|
+
faceDetectionConfidence: ${config.faceDetectionConfidence}`);
|
|
53
|
+
}
|
|
54
|
+
/**
|
|
55
|
+
* Load a model
|
|
56
|
+
*/
|
|
57
|
+
async loadModel(modelName) {
|
|
58
|
+
if (modelName === undefined) {
|
|
59
|
+
throw new TJBotError('Cannot load model: modelName is undefined');
|
|
60
|
+
}
|
|
61
|
+
if (this.models.has(modelName)) {
|
|
62
|
+
return; // Already loaded
|
|
63
|
+
}
|
|
64
|
+
logger.verbose(`Loading ONNX model: ${modelName}`);
|
|
65
|
+
// Get model metadata and download
|
|
66
|
+
const metadata = await this.manager.loadModel(modelName);
|
|
67
|
+
// Build model path
|
|
68
|
+
const modelCacheDir = this.manager.getModelCacheDirForType('vision');
|
|
69
|
+
const modelDir = path.join(modelCacheDir, metadata.folder);
|
|
70
|
+
// Find the ONNX model file in the required files
|
|
71
|
+
const onnxFile = metadata.required.find((file) => file.endsWith('.onnx'));
|
|
72
|
+
if (!onnxFile) {
|
|
73
|
+
throw new TJBotError(`No ONNX file found in model requirements for: ${modelName}`);
|
|
74
|
+
}
|
|
75
|
+
const modelPath = path.join(modelDir, onnxFile);
|
|
76
|
+
// Create ONNX session
|
|
77
|
+
const session = await ort.InferenceSession.create(modelPath);
|
|
78
|
+
// Load labels if available
|
|
79
|
+
let labels = [];
|
|
80
|
+
if (metadata.labelUrl && metadata.kind !== 'face-detection') {
|
|
81
|
+
labels = await this.loadLabels(modelName, metadata, modelDir);
|
|
82
|
+
}
|
|
83
|
+
// Get input shape from metadata
|
|
84
|
+
const inputShape = metadata.inputShape ?? [1, 3, 640, 640];
|
|
85
|
+
this.models.set(modelName, {
|
|
86
|
+
session,
|
|
87
|
+
labels,
|
|
88
|
+
inputShape,
|
|
89
|
+
kind: metadata.kind,
|
|
90
|
+
});
|
|
91
|
+
logger.info(`Loaded ONNX model: ${modelName} (${metadata.kind})`);
|
|
92
|
+
}
|
|
93
|
+
/**
|
|
94
|
+
* Load label file for a model
|
|
95
|
+
*/
|
|
96
|
+
async loadLabels(modelName, metadata, modelDir) {
|
|
97
|
+
logger.info(`Loading labels for model: ${modelName}`);
|
|
98
|
+
try {
|
|
99
|
+
// Try common label file names based on model kind
|
|
100
|
+
let labelFile;
|
|
101
|
+
if (metadata.kind === 'detection') {
|
|
102
|
+
// Look for classes.txt, coco.yaml or coco.names
|
|
103
|
+
const possibleNames = ['classes.txt', 'coco.yaml', 'coco.names'];
|
|
104
|
+
for (const name of possibleNames) {
|
|
105
|
+
if (fs.existsSync(path.join(modelDir, name))) {
|
|
106
|
+
labelFile = path.join(modelDir, name);
|
|
107
|
+
break;
|
|
108
|
+
}
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
else if (metadata.kind === 'classification') {
|
|
112
|
+
// Look for imagenet_classes.txt or similar
|
|
113
|
+
const possibleNames = ['imagenet_classes.txt', 'labels.txt', 'classes.txt'];
|
|
114
|
+
for (const name of possibleNames) {
|
|
115
|
+
if (fs.existsSync(path.join(modelDir, name))) {
|
|
116
|
+
labelFile = path.join(modelDir, name);
|
|
117
|
+
break;
|
|
118
|
+
}
|
|
119
|
+
}
|
|
120
|
+
}
|
|
121
|
+
if (!labelFile) {
|
|
122
|
+
logger.warn(`No label file found for model: ${modelName}`);
|
|
123
|
+
return [];
|
|
124
|
+
}
|
|
125
|
+
logger.debug(`Found label file for ${modelName}: ${labelFile}`);
|
|
126
|
+
const content = fs.readFileSync(labelFile, 'utf8');
|
|
127
|
+
// Parse YAML files for detection models
|
|
128
|
+
if (labelFile.endsWith('.yaml') && metadata.kind === 'detection') {
|
|
129
|
+
// Extract class names from YAML
|
|
130
|
+
// YAML format 1: names: ['person', 'bicycle', ...]
|
|
131
|
+
let namesMatch = content.match(/names:\s*\[(.*?)\]/s);
|
|
132
|
+
if (namesMatch) {
|
|
133
|
+
const namesStr = namesMatch[1];
|
|
134
|
+
// Split by comma and clean up each class name
|
|
135
|
+
return namesStr
|
|
136
|
+
.split(',')
|
|
137
|
+
.map((name) => name.trim().replace(/^['"]|['"]$/g, ''))
|
|
138
|
+
.filter((name) => name.length > 0);
|
|
139
|
+
}
|
|
140
|
+
// YAML format 2: names: \n 0: person \n 1: bicycle \n ...
|
|
141
|
+
namesMatch = content.match(/names:\s*\n([\s\S]*?)(?:\n[a-z]|$)/);
|
|
142
|
+
if (namesMatch) {
|
|
143
|
+
const namesStr = namesMatch[1];
|
|
144
|
+
// Extract values from "index: 'value'" format
|
|
145
|
+
const lines = namesStr.split('\n');
|
|
146
|
+
const labels = lines
|
|
147
|
+
.map((line) => {
|
|
148
|
+
// Match pattern like "67: 'cell phone'" or "67: cell phone"
|
|
149
|
+
const match = line.match(/^\s*\d+:\s*['"]?([^'"]+)['"]?\s*$/);
|
|
150
|
+
return match ? match[1].trim() : null;
|
|
151
|
+
})
|
|
152
|
+
.filter((name) => name !== null && name.length > 0);
|
|
153
|
+
return labels;
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
// For non-YAML files (txt), split by newlines
|
|
157
|
+
let labels = content
|
|
158
|
+
.split('\n')
|
|
159
|
+
.map((line) => line.trim())
|
|
160
|
+
.filter((line) => line.length > 0);
|
|
161
|
+
// Remove numeric prefixes if present (e.g., "67: cell phone" -> "cell phone")
|
|
162
|
+
if (labels.length > 0 && labels[0].includes(':')) {
|
|
163
|
+
labels = labels.map((line) => {
|
|
164
|
+
const match = line.match(/^\d+:\s*(.+)$/);
|
|
165
|
+
return match ? match[1].trim() : line;
|
|
166
|
+
});
|
|
167
|
+
}
|
|
168
|
+
return labels;
|
|
169
|
+
}
|
|
170
|
+
catch (error) {
|
|
171
|
+
logger.warn(`Failed to load labels for ${modelName}:`, error);
|
|
172
|
+
return [];
|
|
173
|
+
}
|
|
174
|
+
}
|
|
175
|
+
/**
|
|
176
|
+
* Get a model, loading it if necessary
|
|
177
|
+
*/
|
|
178
|
+
async getOrLoadModel(modelName) {
|
|
179
|
+
let model = this.models.get(modelName);
|
|
180
|
+
if (!model) {
|
|
181
|
+
logger.debug(`model ${modelName} not yet loaded, loading now...`);
|
|
182
|
+
await this.loadModel(modelName);
|
|
183
|
+
model = this.models.get(modelName);
|
|
184
|
+
}
|
|
185
|
+
if (!model) {
|
|
186
|
+
throw new TJBotError(`Failed to load model: ${modelName}`);
|
|
187
|
+
}
|
|
188
|
+
return model;
|
|
189
|
+
}
|
|
190
|
+
getObjectDetectionConfidenceThreshold() {
|
|
191
|
+
const config = this.config;
|
|
192
|
+
if (config.objectDetectionConfidence === undefined) {
|
|
193
|
+
throw new TJBotError('Object detection confidence threshold is not configured for ONNX vision engine');
|
|
194
|
+
}
|
|
195
|
+
return config.objectDetectionConfidence;
|
|
196
|
+
}
|
|
197
|
+
getImageClassificationConfidenceThreshold() {
|
|
198
|
+
const config = this.config;
|
|
199
|
+
if (config.imageClassificationConfidence === undefined) {
|
|
200
|
+
throw new TJBotError('Image classification confidence threshold is not configured for ONNX vision engine');
|
|
201
|
+
}
|
|
202
|
+
return config.imageClassificationConfidence;
|
|
203
|
+
}
|
|
204
|
+
getFaceDetectionConfidenceThreshold() {
|
|
205
|
+
const config = this.config;
|
|
206
|
+
if (config.faceDetectionConfidence === undefined) {
|
|
207
|
+
throw new TJBotError('Face detection confidence threshold is not configured for ONNX vision engine');
|
|
208
|
+
}
|
|
209
|
+
return config.faceDetectionConfidence;
|
|
210
|
+
}
|
|
211
|
+
/**
|
|
212
|
+
* Detect objects in an image.
|
|
213
|
+
*/
|
|
214
|
+
async detectObjects(image) {
|
|
215
|
+
const config = this.config;
|
|
216
|
+
if (config.objectDetectionModel === undefined) {
|
|
217
|
+
throw new TJBotError('Object detection model is not configured for ONNX vision engine');
|
|
218
|
+
}
|
|
219
|
+
const modelName = config.objectDetectionModel;
|
|
220
|
+
const resolvedConfidenceThreshold = this.getObjectDetectionConfidenceThreshold();
|
|
221
|
+
logger.info(`Running object detection using model ${modelName} with confidence threshold ${resolvedConfidenceThreshold}`);
|
|
222
|
+
const model = await this.getOrLoadModel(modelName);
|
|
223
|
+
try {
|
|
224
|
+
// Preprocess image using model's expected input size
|
|
225
|
+
const [, , height, width] = model.inputShape;
|
|
226
|
+
const input = await this.preprocessImage(image, [width, height]);
|
|
227
|
+
// Run inference
|
|
228
|
+
const feeds = {};
|
|
229
|
+
feeds[model.session.inputNames[0]] = input;
|
|
230
|
+
const results = await model.session.run(feeds);
|
|
231
|
+
// Postprocess YOLO output
|
|
232
|
+
return this.postprocessDetection(results, model.labels, model.session.outputNames, resolvedConfidenceThreshold);
|
|
233
|
+
}
|
|
234
|
+
catch (error) {
|
|
235
|
+
throw new TJBotError('Object detection failed', { cause: error });
|
|
236
|
+
}
|
|
237
|
+
}
|
|
238
|
+
/**
|
|
239
|
+
* Classify an image.
|
|
240
|
+
*/
|
|
241
|
+
async classifyImage(image) {
|
|
242
|
+
const config = this.config;
|
|
243
|
+
if (config.imageClassificationModel === undefined) {
|
|
244
|
+
throw new TJBotError('Image classification model is not configured for ONNX vision engine');
|
|
245
|
+
}
|
|
246
|
+
const modelName = config.imageClassificationModel;
|
|
247
|
+
const resolvedConfidenceThreshold = this.getImageClassificationConfidenceThreshold();
|
|
248
|
+
logger.info(`Running image classification using model ${modelName} with confidence threshold ${resolvedConfidenceThreshold}`);
|
|
249
|
+
const model = await this.getOrLoadModel(modelName);
|
|
250
|
+
try {
|
|
251
|
+
// Preprocess image using model's expected input size
|
|
252
|
+
const [, , height, width] = model.inputShape;
|
|
253
|
+
const input = await this.preprocessImage(image, [width, height]);
|
|
254
|
+
// Run inference
|
|
255
|
+
const feeds = {};
|
|
256
|
+
feeds[model.session.inputNames[0]] = input;
|
|
257
|
+
const results = await model.session.run(feeds);
|
|
258
|
+
// Postprocess classification output
|
|
259
|
+
return this.postprocessClassification(results, model.labels, resolvedConfidenceThreshold, model.session.outputNames);
|
|
260
|
+
}
|
|
261
|
+
catch (error) {
|
|
262
|
+
throw new TJBotError('Image classification failed', { cause: error });
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
/**
|
|
266
|
+
* Detect faces in an image.
|
|
267
|
+
*/
|
|
268
|
+
async detectFaces(image) {
|
|
269
|
+
const config = this.config;
|
|
270
|
+
if (config.faceDetectionModel === undefined) {
|
|
271
|
+
throw new TJBotError('Face detection model is not configured for ONNX vision engine');
|
|
272
|
+
}
|
|
273
|
+
const modelName = config.faceDetectionModel;
|
|
274
|
+
const confidenceThreshold = this.getFaceDetectionConfidenceThreshold();
|
|
275
|
+
logger.info(`Running face detection using model ${modelName} with confidence threshold ${confidenceThreshold}`);
|
|
276
|
+
const model = await this.getOrLoadModel(modelName);
|
|
277
|
+
try {
|
|
278
|
+
// Preprocess image using model's expected input size
|
|
279
|
+
const [, , height, width] = model.inputShape;
|
|
280
|
+
const input = await this.preprocessFaceImage(image, [width, height], modelName);
|
|
281
|
+
// Run inference
|
|
282
|
+
const feeds = {};
|
|
283
|
+
feeds[model.session.inputNames[0]] = input;
|
|
284
|
+
const results = await model.session.run(feeds);
|
|
285
|
+
logger.debug(`Face model output: ${model.session.outputNames.join(', ')}`);
|
|
286
|
+
const outputTensor = results[model.session.outputNames[0]];
|
|
287
|
+
logger.debug(`Output shape: [${outputTensor.dims.join(', ')}], size: ${outputTensor.size}`);
|
|
288
|
+
// Postprocess face detection output
|
|
289
|
+
const metadata = this.postprocessFaceDetection(results, confidenceThreshold, [width, height]);
|
|
290
|
+
return {
|
|
291
|
+
isFaceDetected: metadata.length > 0,
|
|
292
|
+
metadata,
|
|
293
|
+
};
|
|
294
|
+
}
|
|
295
|
+
catch (error) {
|
|
296
|
+
throw new TJBotError('Face detection failed', { cause: error });
|
|
297
|
+
}
|
|
298
|
+
}
|
|
299
|
+
/**
|
|
300
|
+
* Describe an image - not supported by ONNX backend.
|
|
301
|
+
*/
|
|
302
|
+
async describeImage(_image) {
|
|
303
|
+
throw new TJBotError('Image description is only available with Azure Vision backend. Configure see.backend.type to "azure-vision".');
|
|
304
|
+
}
|
|
305
|
+
/**
|
|
306
|
+
* Sigmoid function to normalize logits to 0-1 range
|
|
307
|
+
*/
|
|
308
|
+
sigmoid(x) {
|
|
309
|
+
return 1 / (1 + Math.exp(-x));
|
|
310
|
+
}
|
|
311
|
+
postprocessDetection(results, labels, outputNames, confidenceThreshold = 0.8) {
|
|
312
|
+
const isSSDMobileNetV2 = outputNames.some((name) => name.includes('BoxPredictor_'));
|
|
313
|
+
if (isSSDMobileNetV2) {
|
|
314
|
+
return this.postprocessSSDMobileNetV2(results, labels, confidenceThreshold);
|
|
315
|
+
}
|
|
316
|
+
// Fallback for YOLO-style output [batch, num_detections, (x, y, w, h, conf, class_scores...)]
|
|
317
|
+
const outputName = outputNames[0];
|
|
318
|
+
const outputData = results[outputName].data;
|
|
319
|
+
let detections = [];
|
|
320
|
+
const numClasses = labels.length || 80;
|
|
321
|
+
const valuesPerDetection = 5 + numClasses;
|
|
322
|
+
for (let i = 0; i < outputData.length; i += valuesPerDetection) {
|
|
323
|
+
const confidence = this.sigmoid(outputData[i + 4]);
|
|
324
|
+
if (confidence < confidenceThreshold)
|
|
325
|
+
continue;
|
|
326
|
+
let maxClassScore = 0;
|
|
327
|
+
let maxClassIdx = 0;
|
|
328
|
+
for (let j = 0; j < numClasses; j++) {
|
|
329
|
+
const score = this.sigmoid(outputData[i + 5 + j]);
|
|
330
|
+
if (score > maxClassScore) {
|
|
331
|
+
maxClassScore = score;
|
|
332
|
+
maxClassIdx = j;
|
|
333
|
+
}
|
|
334
|
+
}
|
|
335
|
+
const label = labels[maxClassIdx] || `class${maxClassIdx}`;
|
|
336
|
+
const x = outputData[i];
|
|
337
|
+
const y = outputData[i + 1];
|
|
338
|
+
const w = outputData[i + 2];
|
|
339
|
+
const h = outputData[i + 3];
|
|
340
|
+
detections.push({
|
|
341
|
+
label,
|
|
342
|
+
confidence: maxClassScore,
|
|
343
|
+
bbox: [x, y, w, h],
|
|
344
|
+
});
|
|
345
|
+
}
|
|
346
|
+
detections = this.nonMaxSuppression(detections);
|
|
347
|
+
return detections;
|
|
348
|
+
}
|
|
349
|
+
/**
|
|
350
|
+
* Decode SSD MobileNet v2 raw predictor outputs into object detections.
|
|
351
|
+
*/
|
|
352
|
+
postprocessSSDMobileNetV2(results, labels, confidenceThreshold) {
|
|
353
|
+
const boxScales = {
|
|
354
|
+
x: 10,
|
|
355
|
+
y: 10,
|
|
356
|
+
w: 5,
|
|
357
|
+
h: 5,
|
|
358
|
+
};
|
|
359
|
+
const featureMapShapes = [
|
|
360
|
+
[19, 19],
|
|
361
|
+
[10, 10],
|
|
362
|
+
[5, 5],
|
|
363
|
+
[3, 3],
|
|
364
|
+
[2, 2],
|
|
365
|
+
[1, 1],
|
|
366
|
+
];
|
|
367
|
+
const anchorsByLayer = this.generateSSDMobileNetV2Anchors(featureMapShapes);
|
|
368
|
+
const detections = [];
|
|
369
|
+
for (let layer = 0; layer < featureMapShapes.length; layer++) {
|
|
370
|
+
const boxTensor = results[`BoxPredictor_${layer}/BoxEncodingPredictor/BiasAdd:0`];
|
|
371
|
+
const classTensor = results[`BoxPredictor_${layer}/ClassPredictor/BiasAdd:0`];
|
|
372
|
+
if (!boxTensor || !classTensor) {
|
|
373
|
+
continue;
|
|
374
|
+
}
|
|
375
|
+
const boxData = boxTensor.data;
|
|
376
|
+
const classData = classTensor.data;
|
|
377
|
+
const [, boxChannels, h, w] = boxTensor.dims;
|
|
378
|
+
const [, classChannels] = classTensor.dims;
|
|
379
|
+
const numAnchorsPerCell = boxChannels / 4;
|
|
380
|
+
const numClassesWithBackground = classChannels / numAnchorsPerCell;
|
|
381
|
+
for (let y = 0; y < h; y++) {
|
|
382
|
+
for (let x = 0; x < w; x++) {
|
|
383
|
+
for (let a = 0; a < numAnchorsPerCell; a++) {
|
|
384
|
+
const anchorIdxInLayer = (y * w + x) * numAnchorsPerCell + a;
|
|
385
|
+
const anchor = anchorsByLayer[layer][anchorIdxInLayer];
|
|
386
|
+
if (!anchor)
|
|
387
|
+
continue;
|
|
388
|
+
const classLogits = new Float32Array(numClassesWithBackground);
|
|
389
|
+
for (let c = 0; c < numClassesWithBackground; c++) {
|
|
390
|
+
const classChannel = a * numClassesWithBackground + c;
|
|
391
|
+
const classOffset = (classChannel * h + y) * w + x;
|
|
392
|
+
classLogits[c] = classData[classOffset];
|
|
393
|
+
}
|
|
394
|
+
const probs = this.softmax(classLogits);
|
|
395
|
+
// Class index 0 is background for SSD models.
|
|
396
|
+
let bestClass = 0;
|
|
397
|
+
let bestScore = 0;
|
|
398
|
+
for (let c = 1; c < probs.length; c++) {
|
|
399
|
+
if (probs[c] > bestScore) {
|
|
400
|
+
bestScore = probs[c];
|
|
401
|
+
bestClass = c;
|
|
402
|
+
}
|
|
403
|
+
}
|
|
404
|
+
if (bestScore < confidenceThreshold) {
|
|
405
|
+
continue;
|
|
406
|
+
}
|
|
407
|
+
// Box tensor channel layout per anchor: [ty, tx, th, tw]
|
|
408
|
+
const ty = boxData[((a * 4 + 0) * h + y) * w + x];
|
|
409
|
+
const tx = boxData[((a * 4 + 1) * h + y) * w + x];
|
|
410
|
+
const th = boxData[((a * 4 + 2) * h + y) * w + x];
|
|
411
|
+
const tw = boxData[((a * 4 + 3) * h + y) * w + x];
|
|
412
|
+
const yCenter = (ty / boxScales.y) * anchor.h + anchor.cy;
|
|
413
|
+
const xCenter = (tx / boxScales.x) * anchor.w + anchor.cx;
|
|
414
|
+
const boxH = Math.exp(th / boxScales.h) * anchor.h;
|
|
415
|
+
const boxW = Math.exp(tw / boxScales.w) * anchor.w;
|
|
416
|
+
const xMin = Math.max(0, Math.min(1, xCenter - boxW / 2));
|
|
417
|
+
const yMin = Math.max(0, Math.min(1, yCenter - boxH / 2));
|
|
418
|
+
const xMax = Math.max(0, Math.min(1, xCenter + boxW / 2));
|
|
419
|
+
const yMax = Math.max(0, Math.min(1, yCenter + boxH / 2));
|
|
420
|
+
const width = xMax - xMin;
|
|
421
|
+
const height = yMax - yMin;
|
|
422
|
+
if (width <= 0 || height <= 0) {
|
|
423
|
+
continue;
|
|
424
|
+
}
|
|
425
|
+
const labelIndex = bestClass - 1;
|
|
426
|
+
const label = labels[labelIndex] || `class${labelIndex}`;
|
|
427
|
+
detections.push({
|
|
428
|
+
label,
|
|
429
|
+
confidence: bestScore,
|
|
430
|
+
bbox: [xMin, yMin, width, height],
|
|
431
|
+
});
|
|
432
|
+
}
|
|
433
|
+
}
|
|
434
|
+
}
|
|
435
|
+
}
|
|
436
|
+
return this.nonMaxSuppression(detections);
|
|
437
|
+
}
|
|
438
|
+
/**
|
|
439
|
+
* Generate normalized anchors for SSD MobileNet v2 with input size 300x300.
|
|
440
|
+
*/
|
|
441
|
+
generateSSDMobileNetV2Anchors(featureMapShapes) {
|
|
442
|
+
const minScale = 0.2;
|
|
443
|
+
const maxScale = 0.95;
|
|
444
|
+
const aspectRatios = [1.0, 2.0, 0.5, 3.0, 1.0 / 3.0];
|
|
445
|
+
const anchorsByLayer = [];
|
|
446
|
+
const scaleForLayer = (layer) => {
|
|
447
|
+
if (featureMapShapes.length === 1) {
|
|
448
|
+
return (minScale + maxScale) * 0.5;
|
|
449
|
+
}
|
|
450
|
+
return minScale + ((maxScale - minScale) * layer) / (featureMapShapes.length - 1);
|
|
451
|
+
};
|
|
452
|
+
for (let layer = 0; layer < featureMapShapes.length; layer++) {
|
|
453
|
+
const [featH, featW] = featureMapShapes[layer];
|
|
454
|
+
const scale = scaleForLayer(layer);
|
|
455
|
+
const nextScale = layer === featureMapShapes.length - 1 ? 1.0 : scaleForLayer(layer + 1);
|
|
456
|
+
const layerAnchors = [];
|
|
457
|
+
const anchorSizes = [];
|
|
458
|
+
if (layer === 0) {
|
|
459
|
+
// Reduced anchor set on first layer per TF SSD config.
|
|
460
|
+
anchorSizes.push({ w: 0.1, h: 0.1 });
|
|
461
|
+
anchorSizes.push({ w: scale * Math.sqrt(2.0), h: scale / Math.sqrt(2.0) });
|
|
462
|
+
anchorSizes.push({ w: scale / Math.sqrt(2.0), h: scale * Math.sqrt(2.0) });
|
|
463
|
+
}
|
|
464
|
+
else {
|
|
465
|
+
for (const ratio of aspectRatios) {
|
|
466
|
+
const ratioSqrt = Math.sqrt(ratio);
|
|
467
|
+
anchorSizes.push({ w: scale * ratioSqrt, h: scale / ratioSqrt });
|
|
468
|
+
}
|
|
469
|
+
// Interpolated scale anchor with aspect ratio 1.0.
|
|
470
|
+
const interpolated = Math.sqrt(scale * nextScale);
|
|
471
|
+
anchorSizes.push({ w: interpolated, h: interpolated });
|
|
472
|
+
}
|
|
473
|
+
for (let y = 0; y < featH; y++) {
|
|
474
|
+
for (let x = 0; x < featW; x++) {
|
|
475
|
+
const cy = (y + 0.5) / featH;
|
|
476
|
+
const cx = (x + 0.5) / featW;
|
|
477
|
+
for (const sz of anchorSizes) {
|
|
478
|
+
layerAnchors.push({ cx, cy, w: sz.w, h: sz.h });
|
|
479
|
+
}
|
|
480
|
+
}
|
|
481
|
+
}
|
|
482
|
+
anchorsByLayer.push(layerAnchors);
|
|
483
|
+
}
|
|
484
|
+
return anchorsByLayer;
|
|
485
|
+
}
|
|
486
|
+
softmax(values) {
|
|
487
|
+
let max = -Infinity;
|
|
488
|
+
for (let i = 0; i < values.length; i++) {
|
|
489
|
+
if (values[i] > max)
|
|
490
|
+
max = values[i];
|
|
491
|
+
}
|
|
492
|
+
const exps = new Float32Array(values.length);
|
|
493
|
+
let sum = 0;
|
|
494
|
+
for (let i = 0; i < values.length; i++) {
|
|
495
|
+
const e = Math.exp(values[i] - max);
|
|
496
|
+
exps[i] = e;
|
|
497
|
+
sum += e;
|
|
498
|
+
}
|
|
499
|
+
if (sum === 0)
|
|
500
|
+
return exps;
|
|
501
|
+
for (let i = 0; i < exps.length; i++) {
|
|
502
|
+
exps[i] /= sum;
|
|
503
|
+
}
|
|
504
|
+
return exps;
|
|
505
|
+
}
|
|
506
|
+
/**
|
|
507
|
+
* Apply Non-Maximum Suppression to remove overlapping detections
|
|
508
|
+
*/
|
|
509
|
+
nonMaxSuppression(detections, iouThreshold = 0.5) {
|
|
510
|
+
if (detections.length === 0)
|
|
511
|
+
return [];
|
|
512
|
+
// Sort by confidence descending
|
|
513
|
+
const sorted = [...detections].sort((a, b) => b.confidence - a.confidence);
|
|
514
|
+
const kept = [];
|
|
515
|
+
for (const detection of sorted) {
|
|
516
|
+
// Check if this detection overlaps with any kept detection
|
|
517
|
+
let overlaps = false;
|
|
518
|
+
for (const kept_det of kept) {
|
|
519
|
+
const iou = this.calculateIoU(detection.bbox, kept_det.bbox);
|
|
520
|
+
if (iou > iouThreshold) {
|
|
521
|
+
overlaps = true;
|
|
522
|
+
break;
|
|
523
|
+
}
|
|
524
|
+
}
|
|
525
|
+
if (!overlaps) {
|
|
526
|
+
kept.push(detection);
|
|
527
|
+
}
|
|
528
|
+
}
|
|
529
|
+
return kept;
|
|
530
|
+
}
|
|
531
|
+
/**
|
|
532
|
+
* Calculate Intersection over Union (IoU) between two bounding boxes
|
|
533
|
+
* bbox format: [x, y, w, h]
|
|
534
|
+
*/
|
|
535
|
+
calculateIoU(bbox1, bbox2) {
|
|
536
|
+
const [x1, y1, w1, h1] = bbox1;
|
|
537
|
+
const [x2, y2, w2, h2] = bbox2;
|
|
538
|
+
// Convert to [x_min, y_min, x_max, y_max] format
|
|
539
|
+
const box1_x_min = x1;
|
|
540
|
+
const box1_y_min = y1;
|
|
541
|
+
const box1_x_max = x1 + w1;
|
|
542
|
+
const box1_y_max = y1 + h1;
|
|
543
|
+
const box2_x_min = x2;
|
|
544
|
+
const box2_y_min = y2;
|
|
545
|
+
const box2_x_max = x2 + w2;
|
|
546
|
+
const box2_y_max = y2 + h2;
|
|
547
|
+
// Calculate intersection area
|
|
548
|
+
const inter_x_min = Math.max(box1_x_min, box2_x_min);
|
|
549
|
+
const inter_y_min = Math.max(box1_y_min, box2_y_min);
|
|
550
|
+
const inter_x_max = Math.min(box1_x_max, box2_x_max);
|
|
551
|
+
const inter_y_max = Math.min(box1_y_max, box2_y_max);
|
|
552
|
+
const inter_width = Math.max(0, inter_x_max - inter_x_min);
|
|
553
|
+
const inter_height = Math.max(0, inter_y_max - inter_y_min);
|
|
554
|
+
const intersection = inter_width * inter_height;
|
|
555
|
+
// Calculate union area
|
|
556
|
+
const box1_area = w1 * h1;
|
|
557
|
+
const box2_area = w2 * h2;
|
|
558
|
+
const union = box1_area + box2_area - intersection;
|
|
559
|
+
// Avoid division by zero
|
|
560
|
+
if (union === 0)
|
|
561
|
+
return 0;
|
|
562
|
+
return intersection / union;
|
|
563
|
+
}
|
|
564
|
+
/**
|
|
565
|
+
* Postprocess classification output
|
|
566
|
+
*/
|
|
567
|
+
postprocessClassification(results, labels, confidenceThreshold, outputNames) {
|
|
568
|
+
const outputName = outputNames[0];
|
|
569
|
+
const logits = results[outputName].data;
|
|
570
|
+
const scores = this.softmax(logits);
|
|
571
|
+
// Create results for all classes, then filter by threshold and sort
|
|
572
|
+
const allResults = Array.from(scores)
|
|
573
|
+
.map((score, i) => ({
|
|
574
|
+
label: labels[i] || `class${i}`,
|
|
575
|
+
confidence: score,
|
|
576
|
+
}))
|
|
577
|
+
.filter((result) => result.confidence >= confidenceThreshold)
|
|
578
|
+
.sort((a, b) => b.confidence - a.confidence);
|
|
579
|
+
return allResults;
|
|
580
|
+
}
|
|
581
|
+
/**
|
|
582
|
+
* Postprocess face detection output.
|
|
583
|
+
*/
|
|
584
|
+
postprocessFaceDetection(results, confidenceThreshold = 0.5, modelInputSize) {
|
|
585
|
+
return this.postprocessSCRFDFaceDetection(results, confidenceThreshold, modelInputSize);
|
|
586
|
+
}
|
|
587
|
+
/**
|
|
588
|
+
* Postprocess face detection output from SCRFD-2.5G.
|
|
589
|
+
*/
|
|
590
|
+
postprocessSCRFDFaceDetection(results, confidenceThreshold = 0.5, modelInputSize) {
|
|
591
|
+
const [modelWidth, modelHeight] = modelInputSize || [640, 640];
|
|
592
|
+
const faces = [];
|
|
593
|
+
const scales = [
|
|
594
|
+
{ stride: 8, scoreKey: '446', bboxKey: '449', kpsKey: '452' },
|
|
595
|
+
{ stride: 16, scoreKey: '466', bboxKey: '469', kpsKey: '472' },
|
|
596
|
+
{ stride: 32, scoreKey: '486', bboxKey: '489', kpsKey: '492' },
|
|
597
|
+
];
|
|
598
|
+
logger.debug('Processing SCRFD-2.5G multi-scale output...');
|
|
599
|
+
for (const scale of scales) {
|
|
600
|
+
const scoreTensor = results[scale.scoreKey];
|
|
601
|
+
const bboxTensor = results[scale.bboxKey];
|
|
602
|
+
const kpsTensor = results[scale.kpsKey];
|
|
603
|
+
if (!scoreTensor || !bboxTensor) {
|
|
604
|
+
logger.warn(`Missing SCRFD tensors for stride ${scale.stride}`);
|
|
605
|
+
continue;
|
|
606
|
+
}
|
|
607
|
+
const scores = scoreTensor.data;
|
|
608
|
+
const bboxes = bboxTensor.data;
|
|
609
|
+
const kps = kpsTensor?.data;
|
|
610
|
+
const gridSize = modelWidth / scale.stride;
|
|
611
|
+
const numAnchors = 2;
|
|
612
|
+
for (let i = 0; i < scores.length; i++) {
|
|
613
|
+
const confidence = scores[i];
|
|
614
|
+
if (confidence < confidenceThreshold)
|
|
615
|
+
continue;
|
|
616
|
+
const anchorIndex = Math.floor(i / numAnchors);
|
|
617
|
+
const gridY = Math.floor(anchorIndex / gridSize);
|
|
618
|
+
const gridX = anchorIndex % gridSize;
|
|
619
|
+
const anchorX = (gridX + 0.5) * scale.stride;
|
|
620
|
+
const anchorY = (gridY + 0.5) * scale.stride;
|
|
621
|
+
const left = bboxes[i * 4 + 0] * scale.stride;
|
|
622
|
+
const top = bboxes[i * 4 + 1] * scale.stride;
|
|
623
|
+
const right = bboxes[i * 4 + 2] * scale.stride;
|
|
624
|
+
const bottom = bboxes[i * 4 + 3] * scale.stride;
|
|
625
|
+
const x1 = Math.max(0, anchorX - left);
|
|
626
|
+
const y1 = Math.max(0, anchorY - top);
|
|
627
|
+
const x2 = Math.min(modelWidth, anchorX + right);
|
|
628
|
+
const y2 = Math.min(modelHeight, anchorY + bottom);
|
|
629
|
+
if (x2 <= x1 || y2 <= y1)
|
|
630
|
+
continue;
|
|
631
|
+
const boxW = x2 - x1;
|
|
632
|
+
const boxH = y2 - y1;
|
|
633
|
+
const landmarks = [];
|
|
634
|
+
if (kps && kps.length >= i * 10 + 10) {
|
|
635
|
+
const landmarkTypes = ['eye-left', 'eye-right', 'nose', 'mouth-left', 'mouth-right'];
|
|
636
|
+
for (let j = 0; j < 5; j++) {
|
|
637
|
+
const kx = (kps[i * 10 + j * 2] * scale.stride + anchorX) / modelWidth;
|
|
638
|
+
const ky = (kps[i * 10 + j * 2 + 1] * scale.stride + anchorY) / modelHeight;
|
|
639
|
+
landmarks.push({
|
|
640
|
+
x: Math.min(1, Math.max(0, kx)),
|
|
641
|
+
y: Math.min(1, Math.max(0, ky)),
|
|
642
|
+
type: landmarkTypes[j],
|
|
643
|
+
});
|
|
644
|
+
}
|
|
645
|
+
}
|
|
646
|
+
faces.push({
|
|
647
|
+
boundingBox: [x1 / modelWidth, y1 / modelHeight, boxW / modelWidth, boxH / modelHeight],
|
|
648
|
+
confidence,
|
|
649
|
+
landmarks,
|
|
650
|
+
});
|
|
651
|
+
}
|
|
652
|
+
}
|
|
653
|
+
return this.applyNMS(faces, 0.45);
|
|
654
|
+
}
|
|
655
|
+
/**
|
|
656
|
+
* Preprocess face image for SCRFD input requirements.
|
|
657
|
+
*/
|
|
658
|
+
async preprocessFaceImage(image, size, _modelName) {
|
|
659
|
+
let imgBuf;
|
|
660
|
+
if (typeof image === 'string') {
|
|
661
|
+
imgBuf = fs.readFileSync(image);
|
|
662
|
+
}
|
|
663
|
+
else {
|
|
664
|
+
imgBuf = image;
|
|
665
|
+
}
|
|
666
|
+
const { data } = await sharp(imgBuf)
|
|
667
|
+
.resize(size[0], size[1])
|
|
668
|
+
.removeAlpha()
|
|
669
|
+
.raw()
|
|
670
|
+
.toBuffer({ resolveWithObject: true });
|
|
671
|
+
const [W, H] = size;
|
|
672
|
+
const input = new Float32Array(3 * H * W);
|
|
673
|
+
for (let y = 0; y < H; ++y) {
|
|
674
|
+
for (let x = 0; x < W; ++x) {
|
|
675
|
+
const offset = y * W * 3 + x * 3;
|
|
676
|
+
const r = data[offset] / 255.0;
|
|
677
|
+
const g = data[offset + 1] / 255.0;
|
|
678
|
+
const b = data[offset + 2] / 255.0;
|
|
679
|
+
input[0 * H * W + y * W + x] = b * 2.0 - 1.0;
|
|
680
|
+
input[1 * H * W + y * W + x] = g * 2.0 - 1.0;
|
|
681
|
+
input[2 * H * W + y * W + x] = r * 2.0 - 1.0;
|
|
682
|
+
}
|
|
683
|
+
}
|
|
684
|
+
return new ort.Tensor('float32', input, [1, 3, H, W]);
|
|
685
|
+
}
|
|
686
|
+
/**
|
|
687
|
+
* Apply Non-Maximum Suppression to remove overlapping face detections
|
|
688
|
+
* @param faces Array of detected faces
|
|
689
|
+
* @param iouThreshold IoU threshold for suppression (default 0.5)
|
|
690
|
+
* @returns Filtered array of non-overlapping faces
|
|
691
|
+
*/
|
|
692
|
+
applyNMS(faces, iouThreshold = 0.5) {
|
|
693
|
+
if (faces.length === 0)
|
|
694
|
+
return faces;
|
|
695
|
+
// Sort by confidence descending
|
|
696
|
+
const sortedFaces = [...faces].sort((a, b) => b.confidence - a.confidence);
|
|
697
|
+
const result = [];
|
|
698
|
+
const suppressed = new Array(sortedFaces.length).fill(false);
|
|
699
|
+
for (let i = 0; i < sortedFaces.length; i++) {
|
|
700
|
+
if (suppressed[i])
|
|
701
|
+
continue;
|
|
702
|
+
result.push(sortedFaces[i]);
|
|
703
|
+
// Suppress overlapping faces
|
|
704
|
+
for (let j = i + 1; j < sortedFaces.length; j++) {
|
|
705
|
+
if (suppressed[j])
|
|
706
|
+
continue;
|
|
707
|
+
const iou = this.computeIoU(sortedFaces[i].boundingBox, sortedFaces[j].boundingBox);
|
|
708
|
+
if (iou > iouThreshold) {
|
|
709
|
+
suppressed[j] = true;
|
|
710
|
+
}
|
|
711
|
+
}
|
|
712
|
+
}
|
|
713
|
+
return result;
|
|
714
|
+
}
|
|
715
|
+
/**
|
|
716
|
+
* Compute Intersection over Union (IoU) between two bounding boxes
|
|
717
|
+
* @param box1 [x, y, w, h]
|
|
718
|
+
* @param box2 [x, y, w, h]
|
|
719
|
+
* @returns IoU value between 0 and 1
|
|
720
|
+
*/
|
|
721
|
+
computeIoU(box1, box2) {
|
|
722
|
+
const [x1, y1, w1, h1] = box1;
|
|
723
|
+
const [x2, y2, w2, h2] = box2;
|
|
724
|
+
// Convert to [x_min, y_min, x_max, y_max] format
|
|
725
|
+
const x1_min = x1;
|
|
726
|
+
const y1_min = y1;
|
|
727
|
+
const x1_max = x1 + w1;
|
|
728
|
+
const y1_max = y1 + h1;
|
|
729
|
+
const x2_min = x2;
|
|
730
|
+
const y2_min = y2;
|
|
731
|
+
const x2_max = x2 + w2;
|
|
732
|
+
const y2_max = y2 + h2;
|
|
733
|
+
// Compute intersection
|
|
734
|
+
const inter_x_min = Math.max(x1_min, x2_min);
|
|
735
|
+
const inter_y_min = Math.max(y1_min, y2_min);
|
|
736
|
+
const inter_x_max = Math.min(x1_max, x2_max);
|
|
737
|
+
const inter_y_max = Math.min(y1_max, y2_max);
|
|
738
|
+
if (inter_x_min >= inter_x_max || inter_y_min >= inter_y_max) {
|
|
739
|
+
return 0; // No intersection
|
|
740
|
+
}
|
|
741
|
+
const interArea = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min);
|
|
742
|
+
const box1Area = w1 * h1;
|
|
743
|
+
const box2Area = w2 * h2;
|
|
744
|
+
const unionArea = box1Area + box2Area - interArea;
|
|
745
|
+
return interArea / unionArea;
|
|
746
|
+
}
|
|
747
|
+
/**
|
|
748
|
+
* Preprocess image to Float32 tensor for ONNX model
|
|
749
|
+
*/
|
|
750
|
+
async preprocessImage(image, size) {
|
|
751
|
+
let imgBuf;
|
|
752
|
+
if (typeof image === 'string') {
|
|
753
|
+
imgBuf = fs.readFileSync(image);
|
|
754
|
+
}
|
|
755
|
+
else {
|
|
756
|
+
imgBuf = image;
|
|
757
|
+
}
|
|
758
|
+
// Use sharp to resize and get raw RGB
|
|
759
|
+
const { data, info: _info } = await sharp(imgBuf)
|
|
760
|
+
.resize(size[0], size[1])
|
|
761
|
+
.removeAlpha()
|
|
762
|
+
.raw()
|
|
763
|
+
.toBuffer({ resolveWithObject: true });
|
|
764
|
+
// Normalize to [0,1] and convert to Float32Array
|
|
765
|
+
const float = new Float32Array(data.length);
|
|
766
|
+
for (let i = 0; i < data.length; ++i)
|
|
767
|
+
float[i] = data[i] / 255.0;
|
|
768
|
+
// ONNX expects NCHW: [1,3,H,W]
|
|
769
|
+
const [W, H] = size;
|
|
770
|
+
const input = new Float32Array(3 * H * W);
|
|
771
|
+
for (let y = 0; y < H; ++y) {
|
|
772
|
+
for (let x = 0; x < W; ++x) {
|
|
773
|
+
for (let c = 0; c < 3; ++c) {
|
|
774
|
+
input[c * H * W + y * W + x] = float[y * W * 3 + x * 3 + c];
|
|
775
|
+
}
|
|
776
|
+
}
|
|
777
|
+
}
|
|
778
|
+
return new ort.Tensor('float32', input, [1, 3, H, W]);
|
|
779
|
+
}
|
|
780
|
+
}
|
|
781
|
+
//# sourceMappingURL=onnx.js.map
|