@huggingface/transformers 3.0.1 → 3.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +14 -4
- package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
- package/dist/transformers.cjs +16607 -13472
- package/dist/transformers.cjs.map +1 -1
- package/dist/transformers.js +16601 -13451
- package/dist/transformers.js.map +1 -1
- package/dist/transformers.min.cjs +238 -52
- package/dist/transformers.min.cjs.map +1 -1
- package/dist/transformers.min.js +229 -43
- package/dist/transformers.min.js.map +1 -1
- package/dist/transformers.min.mjs +240 -54
- package/dist/transformers.min.mjs.map +1 -1
- package/dist/transformers.mjs +16017 -12878
- package/dist/transformers.mjs.map +1 -1
- package/package.json +7 -7
- package/src/base/feature_extraction_utils.js +54 -0
- package/src/base/image_processors_utils.js +1089 -0
- package/src/base/processing_utils.js +145 -0
- package/src/configs.js +15 -3
- package/src/env.js +15 -4
- package/src/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.js +90 -0
- package/src/models/auto/feature_extraction_auto.js +41 -0
- package/src/models/auto/image_processing_auto.js +29 -0
- package/src/models/auto/processing_auto.js +100 -0
- package/src/models/beit/image_processing_beit.js +5 -0
- package/src/models/bit/image_processing_bit.js +5 -0
- package/src/models/chinese_clip/image_processing_chinese_clip.js +5 -0
- package/src/models/clap/feature_extraction_clap.js +159 -0
- package/src/models/clip/image_processing_clip.js +6 -0
- package/src/models/convnext/image_processing_convnext.js +45 -0
- package/src/models/deit/image_processing_deit.js +6 -0
- package/src/models/detr/image_processing_detr.js +52 -0
- package/src/models/donut/image_processing_donut.js +31 -0
- package/src/models/dpt/image_processing_dpt.js +6 -0
- package/src/models/efficientnet/image_processing_efficientnet.js +13 -0
- package/src/models/feature_extractors.js +12 -0
- package/src/models/florence2/processing_florence2.js +128 -0
- package/src/models/glpn/image_processing_glpn.js +5 -0
- package/src/models/image_processors.js +36 -0
- package/src/models/janus/image_processing_janus.js +26 -0
- package/src/models/janus/processing_janus.js +123 -0
- package/src/models/jina_clip/image_processing_jina_clip.js +26 -0
- package/src/models/jina_clip/processing_jina_clip.js +24 -0
- package/src/models/llava_onevision/image_processing_llava_onevision.js +5 -0
- package/src/models/mask2former/image_processing_mask2former.js +5 -0
- package/src/models/maskformer/image_processing_maskformer.js +18 -0
- package/src/models/mgp_str/processing_mgp_str.js +170 -0
- package/src/models/mobilenet_v1/image_processing_mobilenet_v1.js +7 -0
- package/src/models/mobilenet_v2/image_processing_mobilenet_v2.js +7 -0
- package/src/models/mobilenet_v3/image_processing_mobilenet_v3.js +7 -0
- package/src/models/mobilenet_v4/image_processing_mobilenet_v4.js +7 -0
- package/src/models/mobilevit/image_processing_mobilevit.js +6 -0
- package/src/models/nougat/image_processing_nougat.js +5 -0
- package/src/models/owlv2/image_processing_owlv2.js +5 -0
- package/src/models/owlvit/image_processing_owlvit.js +12 -0
- package/src/models/owlvit/processing_owlvit.js +7 -0
- package/src/models/processors.js +11 -0
- package/src/models/pvt/image_processing_pvt.js +5 -0
- package/src/models/pyannote/feature_extraction_pyannote.js +28 -0
- package/src/models/pyannote/processing_pyannote.js +71 -0
- package/src/models/qwen2_vl/image_processing_qwen2_vl.js +52 -0
- package/src/models/qwen2_vl/processing_qwen2_vl.js +52 -0
- package/src/models/rt_detr/image_processing_rt_detr.js +12 -0
- package/src/models/sam/image_processing_sam.js +242 -0
- package/src/models/sam/processing_sam.js +20 -0
- package/src/models/sapiens/image_processing_sapiens.js +13 -0
- package/src/models/seamless_m4t/feature_extraction_seamless_m4t.js +180 -0
- package/src/models/segformer/image_processing_segformer.js +13 -0
- package/src/models/siglip/image_processing_siglip.js +5 -0
- package/src/models/speecht5/feature_extraction_speecht5.js +4 -0
- package/src/models/speecht5/processing_speecht5.js +17 -0
- package/src/models/swin2sr/image_processing_swin2sr.js +24 -0
- package/src/models/vit/image_processing_vit.js +7 -0
- package/src/models/vitmatte/image_processing_vitmatte.js +50 -0
- package/src/models/vitpose/image_processing_vitpose.js +89 -0
- package/src/models/wav2vec2/feature_extraction_wav2vec2.js +44 -0
- package/src/models/wav2vec2/processing_wav2vec2.js +15 -0
- package/src/models/wespeaker/feature_extraction_wespeaker.js +100 -0
- package/src/models/whisper/feature_extraction_whisper.js +84 -0
- package/src/models/whisper/processing_whisper.js +21 -0
- package/src/models/yolos/image_processing_yolos.js +12 -0
- package/src/models.js +695 -32
- package/src/pipelines.js +8 -8
- package/src/tokenizers.js +5 -0
- package/src/transformers.js +15 -2
- package/src/utils/constants.js +8 -1
- package/src/utils/core.js +37 -9
- package/src/utils/hub.js +2 -1
- package/src/utils/image.js +68 -17
- package/src/utils/tensor.js +33 -1
- package/types/base/feature_extraction_utils.d.ts +41 -0
- package/types/base/feature_extraction_utils.d.ts.map +1 -0
- package/types/base/image_processors_utils.d.ts +323 -0
- package/types/base/image_processors_utils.d.ts.map +1 -0
- package/types/base/processing_utils.d.ts +80 -0
- package/types/base/processing_utils.d.ts.map +1 -0
- package/types/configs.d.ts +4 -1
- package/types/configs.d.ts.map +1 -1
- package/types/env.d.ts.map +1 -1
- package/types/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.d.ts +25 -0
- package/types/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.d.ts.map +1 -0
- package/types/models/auto/feature_extraction_auto.d.ts +5 -0
- package/types/models/auto/feature_extraction_auto.d.ts.map +1 -0
- package/types/models/auto/image_processing_auto.d.ts +5 -0
- package/types/models/auto/image_processing_auto.d.ts.map +1 -0
- package/types/models/auto/processing_auto.d.ts +35 -0
- package/types/models/auto/processing_auto.d.ts.map +1 -0
- package/types/models/beit/image_processing_beit.d.ts +4 -0
- package/types/models/beit/image_processing_beit.d.ts.map +1 -0
- package/types/models/bit/image_processing_bit.d.ts +4 -0
- package/types/models/bit/image_processing_bit.d.ts.map +1 -0
- package/types/models/chinese_clip/image_processing_chinese_clip.d.ts +4 -0
- package/types/models/chinese_clip/image_processing_chinese_clip.d.ts.map +1 -0
- package/types/models/clap/feature_extraction_clap.d.ts +57 -0
- package/types/models/clap/feature_extraction_clap.d.ts.map +1 -0
- package/types/models/clip/image_processing_clip.d.ts +6 -0
- package/types/models/clip/image_processing_clip.d.ts.map +1 -0
- package/types/models/convnext/image_processing_convnext.d.ts +12 -0
- package/types/models/convnext/image_processing_convnext.d.ts.map +1 -0
- package/types/models/deit/image_processing_deit.d.ts +6 -0
- package/types/models/deit/image_processing_deit.d.ts.map +1 -0
- package/types/models/detr/image_processing_detr.d.ts +42 -0
- package/types/models/detr/image_processing_detr.d.ts.map +1 -0
- package/types/models/donut/image_processing_donut.d.ts +7 -0
- package/types/models/donut/image_processing_donut.d.ts.map +1 -0
- package/types/models/dpt/image_processing_dpt.d.ts +6 -0
- package/types/models/dpt/image_processing_dpt.d.ts.map +1 -0
- package/types/models/efficientnet/image_processing_efficientnet.d.ts +6 -0
- package/types/models/efficientnet/image_processing_efficientnet.d.ts.map +1 -0
- package/types/models/feature_extractors.d.ts +10 -0
- package/types/models/feature_extractors.d.ts.map +1 -0
- package/types/models/florence2/processing_florence2.d.ts +39 -0
- package/types/models/florence2/processing_florence2.d.ts.map +1 -0
- package/types/models/glpn/image_processing_glpn.d.ts +4 -0
- package/types/models/glpn/image_processing_glpn.d.ts.map +1 -0
- package/types/models/image_processors.d.ts +36 -0
- package/types/models/image_processors.d.ts.map +1 -0
- package/types/models/janus/image_processing_janus.d.ts +7 -0
- package/types/models/janus/image_processing_janus.d.ts.map +1 -0
- package/types/models/janus/processing_janus.d.ts +77 -0
- package/types/models/janus/processing_janus.d.ts.map +1 -0
- package/types/models/jina_clip/image_processing_jina_clip.d.ts +5 -0
- package/types/models/jina_clip/image_processing_jina_clip.d.ts.map +1 -0
- package/types/models/jina_clip/processing_jina_clip.d.ts +9 -0
- package/types/models/jina_clip/processing_jina_clip.d.ts.map +1 -0
- package/types/models/llava_onevision/image_processing_llava_onevision.d.ts +4 -0
- package/types/models/llava_onevision/image_processing_llava_onevision.d.ts.map +1 -0
- package/types/models/mask2former/image_processing_mask2former.d.ts +4 -0
- package/types/models/mask2former/image_processing_mask2former.d.ts.map +1 -0
- package/types/models/maskformer/image_processing_maskformer.d.ts +22 -0
- package/types/models/maskformer/image_processing_maskformer.d.ts.map +1 -0
- package/types/models/mgp_str/processing_mgp_str.d.ts +64 -0
- package/types/models/mgp_str/processing_mgp_str.d.ts.map +1 -0
- package/types/models/mobilenet_v1/image_processing_mobilenet_v1.d.ts +6 -0
- package/types/models/mobilenet_v1/image_processing_mobilenet_v1.d.ts.map +1 -0
- package/types/models/mobilenet_v2/image_processing_mobilenet_v2.d.ts +6 -0
- package/types/models/mobilenet_v2/image_processing_mobilenet_v2.d.ts.map +1 -0
- package/types/models/mobilenet_v3/image_processing_mobilenet_v3.d.ts +6 -0
- package/types/models/mobilenet_v3/image_processing_mobilenet_v3.d.ts.map +1 -0
- package/types/models/mobilenet_v4/image_processing_mobilenet_v4.d.ts +6 -0
- package/types/models/mobilenet_v4/image_processing_mobilenet_v4.d.ts.map +1 -0
- package/types/models/mobilevit/image_processing_mobilevit.d.ts +6 -0
- package/types/models/mobilevit/image_processing_mobilevit.d.ts.map +1 -0
- package/types/models/nougat/image_processing_nougat.d.ts +4 -0
- package/types/models/nougat/image_processing_nougat.d.ts.map +1 -0
- package/types/models/owlv2/image_processing_owlv2.d.ts +4 -0
- package/types/models/owlv2/image_processing_owlv2.d.ts.map +1 -0
- package/types/models/owlvit/image_processing_owlvit.d.ts +10 -0
- package/types/models/owlvit/image_processing_owlvit.d.ts.map +1 -0
- package/types/models/owlvit/processing_owlvit.d.ts +8 -0
- package/types/models/owlvit/processing_owlvit.d.ts.map +1 -0
- package/types/models/processors.d.ts +12 -0
- package/types/models/processors.d.ts.map +1 -0
- package/types/models/pvt/image_processing_pvt.d.ts +4 -0
- package/types/models/pvt/image_processing_pvt.d.ts.map +1 -0
- package/types/models/pyannote/feature_extraction_pyannote.d.ts +13 -0
- package/types/models/pyannote/feature_extraction_pyannote.d.ts.map +1 -0
- package/types/models/pyannote/processing_pyannote.d.ts +30 -0
- package/types/models/pyannote/processing_pyannote.d.ts.map +1 -0
- package/types/models/qwen2_vl/image_processing_qwen2_vl.d.ts +11 -0
- package/types/models/qwen2_vl/image_processing_qwen2_vl.d.ts.map +1 -0
- package/types/models/qwen2_vl/processing_qwen2_vl.d.ts +17 -0
- package/types/models/qwen2_vl/processing_qwen2_vl.d.ts.map +1 -0
- package/types/models/rt_detr/image_processing_rt_detr.d.ts +8 -0
- package/types/models/rt_detr/image_processing_rt_detr.d.ts.map +1 -0
- package/types/models/sam/image_processing_sam.d.ts +103 -0
- package/types/models/sam/image_processing_sam.d.ts.map +1 -0
- package/types/models/sam/processing_sam.d.ts +9 -0
- package/types/models/sam/processing_sam.d.ts.map +1 -0
- package/types/models/seamless_m4t/feature_extraction_seamless_m4t.d.ts +34 -0
- package/types/models/seamless_m4t/feature_extraction_seamless_m4t.d.ts.map +1 -0
- package/types/models/segformer/image_processing_segformer.d.ts +10 -0
- package/types/models/segformer/image_processing_segformer.d.ts.map +1 -0
- package/types/models/siglip/image_processing_siglip.d.ts +4 -0
- package/types/models/siglip/image_processing_siglip.d.ts.map +1 -0
- package/types/models/speecht5/feature_extraction_speecht5.d.ts +4 -0
- package/types/models/speecht5/feature_extraction_speecht5.d.ts.map +1 -0
- package/types/models/speecht5/processing_speecht5.d.ts +14 -0
- package/types/models/speecht5/processing_speecht5.d.ts.map +1 -0
- package/types/models/swin2sr/image_processing_swin2sr.d.ts +5 -0
- package/types/models/swin2sr/image_processing_swin2sr.d.ts.map +1 -0
- package/types/models/vit/image_processing_vit.d.ts +6 -0
- package/types/models/vit/image_processing_vit.d.ts.map +1 -0
- package/types/models/vitmatte/image_processing_vitmatte.d.ts +12 -0
- package/types/models/vitmatte/image_processing_vitmatte.d.ts.map +1 -0
- package/types/models/vitpose/image_processing_vitpose.d.ts +26 -0
- package/types/models/vitpose/image_processing_vitpose.d.ts.map +1 -0
- package/types/models/wav2vec2/feature_extraction_wav2vec2.d.ts +19 -0
- package/types/models/wav2vec2/feature_extraction_wav2vec2.d.ts.map +1 -0
- package/types/models/wav2vec2/processing_wav2vec2.d.ts +12 -0
- package/types/models/wav2vec2/processing_wav2vec2.d.ts.map +1 -0
- package/types/models/wespeaker/feature_extraction_wespeaker.d.ts +23 -0
- package/types/models/wespeaker/feature_extraction_wespeaker.d.ts.map +1 -0
- package/types/models/whisper/feature_extraction_whisper.d.ts +21 -0
- package/types/models/whisper/feature_extraction_whisper.d.ts.map +1 -0
- package/types/models/whisper/processing_whisper.d.ts +17 -0
- package/types/models/whisper/processing_whisper.d.ts.map +1 -0
- package/types/models/yolos/image_processing_yolos.d.ts +10 -0
- package/types/models/yolos/image_processing_yolos.d.ts.map +1 -0
- package/types/models.d.ts +152 -0
- package/types/models.d.ts.map +1 -1
- package/types/pipelines.d.ts +2 -3
- package/types/pipelines.d.ts.map +1 -1
- package/types/tokenizers.d.ts +3 -0
- package/types/tokenizers.d.ts.map +1 -1
- package/types/transformers.d.ts +10 -1
- package/types/utils/constants.d.ts +6 -0
- package/types/utils/constants.d.ts.map +1 -1
- package/types/utils/core.d.ts +58 -3
- package/types/utils/core.d.ts.map +1 -1
- package/types/utils/hub.d.ts +1 -1
- package/types/utils/hub.d.ts.map +1 -1
- package/types/utils/image.d.ts +10 -2
- package/types/utils/image.d.ts.map +1 -1
- package/types/utils/tensor.d.ts +34 -1
- package/types/utils/tensor.d.ts.map +1 -1
- package/src/processors.js +0 -2655
- package/types/processors.d.ts +0 -924
- package/types/processors.d.ts.map +0 -1
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
import {
|
|
2
|
+
ImageProcessor,
|
|
3
|
+
} from "../../base/image_processors_utils.js";
|
|
4
|
+
import { calculateDimensions } from "../../utils/core.js";
|
|
5
|
+
|
|
6
|
+
import {
|
|
7
|
+
interpolate_4d,
|
|
8
|
+
Tensor,
|
|
9
|
+
} from "../../utils/tensor.js";
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
/**
 * @typedef {object} SamImageProcessorResult
 * @property {Tensor} pixel_values
 * @property {import("../../base/image_processors_utils.js").HeightWidth[]} original_sizes
 * @property {import("../../base/image_processors_utils.js").HeightWidth[]} reshaped_input_sizes
 * @property {Tensor} [input_points]
 * @property {Tensor} [input_labels]
 * @property {Tensor} [input_boxes]
 */

export class SamImageProcessor extends ImageProcessor {

    /**
     * Rescales user-provided point (or bounding-box) coordinates from the original image
     * coordinate space to the model's resized input space, and packs them into a tensor.
     * @param {any} input_points A 3D or 4D nested array of coordinates. A 3D input is
     * interpreted as having an implicit `batch_size` of 1 (unless it is a bounding box,
     * whose natural shape is already 3D).
     * @param {import("../../base/image_processors_utils.js").HeightWidth[]} original_sizes The (height, width) of each original image.
     * @param {import("../../base/image_processors_utils.js").HeightWidth[]} reshaped_input_sizes The (height, width) of each image as fed to the model.
     * @param {boolean} [is_bounding_box=false] Whether the coordinates are (x1, y1, x2, y2) boxes rather than (x, y) points.
     * @returns {Tensor} A float32 tensor of the rescaled coordinates.
     */
    reshape_input_points(input_points, original_sizes, reshaped_input_sizes, is_bounding_box = false) {

        // Make deep copy to avoid altering user's input
        input_points = structuredClone(input_points);
        let shape = calculateDimensions(input_points);

        // TODO: add support for 2D input_points
        if (shape.length === 3) {
            // Correct user's input: add the implicit batch dimension.
            // For bounding boxes, the element count is unchanged by the wrap,
            // so the original 3D shape is kept for the output tensor.
            if (!is_bounding_box) {
                shape = [1, ...shape];
            }
            input_points = [input_points];
        } else if (shape.length !== 4) {
            throw Error("The input_points must be a 4D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.")
        }

        // Reshape input points: scale x-coordinates by width ratio and y-coordinates by height ratio.
        for (let i = 0; i < input_points.length; ++i) { // batch_size
            let originalImageSize = original_sizes[i];
            let reshapedImageSize = reshaped_input_sizes[i];

            let resizeFactors = [
                reshapedImageSize[0] / originalImageSize[0],
                reshapedImageSize[1] / originalImageSize[1]
            ]

            for (let j = 0; j < input_points[i].length; ++j) { // point_batch_size
                for (let k = 0; k < input_points[i][j].length; ++k) { // nb_points_per_image
                    for (let w = 0; w < input_points[i][j][k].length; ++w) { // 2 or 4
                        // `w % 2` alternates between the two resize factors, which also
                        // handles the 4-coordinate (x1, y1, x2, y2) bounding-box case.
                        input_points[i][j][k][w] *= resizeFactors[w % 2];
                    }
                }
            }
        }

        return new Tensor(
            'float32',
            Float32Array.from(input_points.flat(Infinity)),
            shape
        )

    }

    /**
     * Validates user-provided point labels against the already-reshaped input points
     * and packs them into an int64 tensor.
     * @param {any} input_labels A 2D or 3D nested array of integer labels. A 2D input is
     * interpreted as having an implicit `batch_size` of 1.
     * @param {Tensor} input_points The tensor produced by {@link SamImageProcessor#reshape_input_points}.
     * @returns {Tensor} An int64 tensor of the labels.
     * @throws {Error} If the labels are not 2D/3D, or their dimensions do not match `input_points`.
     */
    add_input_labels(input_labels, input_points) {
        let shape = calculateDimensions(input_labels);
        if (shape.length === 2) {
            // Correct user's input: add the implicit batch dimension.
            shape = [1, ...shape];
            input_labels = [input_labels];
        } else if (shape.length !== 3) {
            // FIX: previous message was copy-pasted from `reshape_input_points` and
            // incorrectly described `input_points`/4D; this check is for `input_labels`/3D.
            throw Error("The input_labels must be a 3D tensor of shape `batch_size`, `point_batch_size`, `nb_points_per_image`.")
        }

        if (shape.some((x, i) => x !== input_points.dims[i])) {
            throw Error(`The first ${shape.length} dimensions of 'input_points' and 'input_labels' must be the same.`)
        }
        return new Tensor(
            'int64',
            input_labels.flat(Infinity).map(BigInt),
            shape,
        )
    }
    /**
     * @param {any[]} images The URL(s) of the image(s) to extract features from.
     * @param {Object} [options] Additional options for the processor.
     * @param {any} [options.input_points=null] A 3D or 4D array, representing the input points provided by the user.
     * - 3D: `[point_batch_size, nb_points_per_image, 2]`. In this case, `batch_size` is assumed to be 1.
     * - 4D: `[batch_size, point_batch_size, nb_points_per_image, 2]`.
     * @param {any} [options.input_labels=null] A 2D or 3D array, representing the input labels for the points, used by the prompt encoder to encode the prompt.
     * - 2D: `[point_batch_size, nb_points_per_image]`. In this case, `batch_size` is assumed to be 1.
     * - 3D: `[batch_size, point_batch_size, nb_points_per_image]`.
     * @param {number[][][]} [options.input_boxes=null] A 3D array of shape `(batch_size, num_boxes, 4)`, representing the input boxes provided by the user.
     * This is used by the prompt encoder to encode the prompt. Generally yields much better generated masks.
     * The processor will generate a tensor, with each dimension corresponding respectively to the image batch size,
     * the number of boxes per image and the coordinates of the top left and bottom right point of the box.
     * In the order (`x1`, `y1`, `x2`, `y2`):
     * - `x1`: the x coordinate of the top left point of the input box
     * - `y1`: the y coordinate of the top left point of the input box
     * - `x2`: the x coordinate of the bottom right point of the input box
     * - `y2`: the y coordinate of the bottom right point of the input box
     * @returns {Promise<SamImageProcessorResult>}
     */
    async _call(images, {
        input_points = null,
        input_labels = null,
        input_boxes = null
    } = {}) {
        // TODO allow user to use preprocessed images
        /** @type {SamImageProcessorResult} */
        const processed = await super._call(images);

        if (input_points) {
            processed.input_points = this.reshape_input_points(
                input_points, processed.original_sizes, processed.reshaped_input_sizes
            );
        }

        if (input_labels) {
            // Labels are only meaningful relative to their points.
            if (!processed.input_points) {
                throw Error("`input_points` must be provided if `input_labels` are provided.")
            }
            processed.input_labels = this.add_input_labels(input_labels, processed.input_points);
        }

        if (input_boxes) {
            // Boxes reuse the point-reshaping logic with `is_bounding_box = true`.
            processed.input_boxes = this.reshape_input_points(
                input_boxes, processed.original_sizes, processed.reshaped_input_sizes, true,
            );
        }

        return processed;
    }

    /**
     * Remove padding and upscale masks to the original image size.
     * @param {Tensor} masks Batched masks from the mask_decoder in (batch_size, num_channels, height, width) format.
     * @param {[number, number][]} original_sizes The original sizes of each image before it was resized to the model's expected input shape, in (height, width) format.
     * @param {[number, number][]} reshaped_input_sizes The size of each image as it is fed to the model, in (height, width) format. Used to remove padding.
     * @param {Object} options Optional parameters for post-processing.
     * @param {number} [options.mask_threshold] The threshold to use for binarizing the masks.
     * @param {boolean} [options.binarize] Whether to binarize the masks.
     * @param {Object} [options.pad_size] The target size the images were padded to before being passed to the model. If `null`, the target size is assumed to be the processor's `pad_size`.
     * @param {number} [options.pad_size.height] The height the images were padded to.
     * @param {number} [options.pad_size.width] The width the images were padded to.
     * @returns {Promise<Tensor[]>} Batched masks in (batch_size, num_channels, height, width) format, where (height, width) is given by original_size.
     */
    async post_process_masks(masks, original_sizes, reshaped_input_sizes, {
        mask_threshold = 0.0,
        binarize = true,
        pad_size = null,
    } = {}) {
        // masks: [1, 1, 3, 256, 256]

        const output_masks = [];

        pad_size = pad_size ?? this.pad_size;

        /** @type {[number, number]} */
        const target_image_size = [pad_size.height, pad_size.width];

        for (let i = 0; i < original_sizes.length; ++i) {
            const original_size = original_sizes[i];
            const reshaped_input_size = reshaped_input_sizes[i];

            // Upscale mask to padded size
            let interpolated_mask = (await interpolate_4d(
                masks[i],
                { mode: 'bilinear', size: target_image_size }
            ));

            // Crop mask (remove the padding region)
            interpolated_mask = interpolated_mask.slice(null, null, [0, reshaped_input_size[0]], [0, reshaped_input_size[1]]);

            // Downscale mask back to the original image size
            interpolated_mask = (await interpolate_4d(
                interpolated_mask,
                { mode: 'bilinear', size: original_size }
            ));

            if (binarize) {
                const data = interpolated_mask.data;
                const binarizedMaskData = new Uint8Array(data.length);
                // NOTE: loop index renamed from `i` (which shadowed the batch index above).
                for (let j = 0; j < data.length; ++j) {
                    if (data[j] > mask_threshold) {
                        binarizedMaskData[j] = 1;
                    }
                }
                interpolated_mask = new Tensor(
                    'bool',
                    binarizedMaskData,
                    interpolated_mask.dims
                )
            }

            output_masks.push(interpolated_mask);
        }

        return output_masks;
    }

    /**
     * Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer.
     * @param {import("../../utils/image.js").RawImage} image Input original image
     * @param {number} target_size Target size of the resized image
     * @param {Object} options Options for generating crop boxes
     * @param {number} [options.crop_n_layers] If >0, mask prediction will be run again on crops of the image.
     * Sets the number of layers to run, where each layer has 2**i_layer number of image crops.
     * @param {number} [options.overlap_ratio] Sets the degree to which crops overlap. In the first crop layer,
     * crops will overlap by this fraction of the image length. Later layers with more crops scale down this overlap.
     * @param {number} [options.points_per_crop] Number of points to sample from each crop.
     * @param {number} [options.crop_n_points_downscale_factor] The number of points-per-side sampled in layer n is
     * scaled down by crop_n_points_downscale_factor**n.
     * @returns {Object} An object containing the crop boxes, number of points per crop, cropped images, and input labels.
     */
    generate_crop_boxes(image, target_size, {
        crop_n_layers = 0,
        overlap_ratio = 512 / 1500,
        points_per_crop = 32,
        crop_n_points_downscale_factor = 1,
    } = {}) {
        // TODO: Implement (currently a stub that returns `undefined`)
        // return { crop_boxes, points_per_crop, cropped_images, input_labels }
    }
}
|
|
242
|
+
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import { Processor } from "../../base/processing_utils.js";
|
|
2
|
+
import { AutoImageProcessor } from "../auto/image_processing_auto.js";
|
|
3
|
+
|
|
4
|
+
/**
 * Processor for the Segment Anything Model (SAM).
 *
 * A thin wrapper that delegates every operation to the underlying image
 * processor resolved via {@link AutoImageProcessor}.
 */
export class SamProcessor extends Processor {
    static image_processor_class = AutoImageProcessor

    /**
     * Preprocesses the input(s) by forwarding them to the image processor.
     * @param {...any} args Arguments forwarded verbatim to the image processor.
     * @returns {Promise<any>} The image processor's output.
     */
    async _call(...args) {
        return await this.image_processor(...args);
    }

    /**
     * Delegates mask post-processing to the image processor.
     * @param {...any} args Arguments forwarded verbatim.
     */
    post_process_masks(...args) {
        // @ts-ignore
        return this.image_processor.post_process_masks(...args);
    }

    /**
     * Delegates point reshaping to the image processor.
     * @param {...any} args Arguments forwarded verbatim.
     */
    reshape_input_points(...args) {
        // @ts-ignore
        return this.image_processor.reshape_input_points(...args);
    }
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import {
|
|
2
|
+
ImageProcessor,
|
|
3
|
+
post_process_semantic_segmentation,
|
|
4
|
+
} from "../../base/image_processors_utils.js";
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
/**
 * Image processor for Sapiens models.
 *
 * Adds semantic-segmentation post-processing on top of the shared
 * {@link ImageProcessor} preprocessing pipeline by delegating to the
 * free function from `image_processors_utils.js`.
 */
export class SapiensImageProcessor extends ImageProcessor {
    /** @type {typeof post_process_semantic_segmentation} */
    post_process_semantic_segmentation(...args) {
        // Forward verbatim to the shared implementation.
        return post_process_semantic_segmentation(...args);
    }
}
|
|
13
|
+
// Alias of SapiensImageProcessor — presumably kept so configs that name
// `SapiensFeatureExtractor` still resolve; verify against processing_auto mappings.
export class SapiensFeatureExtractor extends SapiensImageProcessor { }
|
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
import { FeatureExtractor, validate_audio_inputs } from '../../base/feature_extraction_utils.js';
|
|
2
|
+
import { Tensor } from '../../utils/tensor.js';
|
|
3
|
+
import { mel_filter_bank, spectrogram, window_function } from '../../utils/audio.js';
|
|
4
|
+
|
|
5
|
+
export class SeamlessM4TFeatureExtractor extends FeatureExtractor {
|
|
6
|
+
|
|
7
|
+
constructor(config) {
|
|
8
|
+
super(config);
|
|
9
|
+
|
|
10
|
+
const sampling_rate = this.config.sampling_rate;
|
|
11
|
+
const mel_filters = mel_filter_bank(
|
|
12
|
+
256, // num_frequency_bins
|
|
13
|
+
this.config.num_mel_bins, // num_mel_filters
|
|
14
|
+
20, // min_frequency
|
|
15
|
+
Math.floor(sampling_rate / 2), // max_frequency
|
|
16
|
+
sampling_rate, // sampling_rate
|
|
17
|
+
null, // norm
|
|
18
|
+
"kaldi", // mel_scale
|
|
19
|
+
true, // triangularize_in_mel_space
|
|
20
|
+
);
|
|
21
|
+
|
|
22
|
+
// Do padding:
|
|
23
|
+
for (let i = 0; i < mel_filters.length; ++i) {
|
|
24
|
+
mel_filters[i].push(0);
|
|
25
|
+
}
|
|
26
|
+
this.mel_filters = mel_filters;
|
|
27
|
+
|
|
28
|
+
this.window = window_function(400, 'povey', {
|
|
29
|
+
periodic: false,
|
|
30
|
+
})
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
/**
|
|
34
|
+
* Computes the log-Mel spectrogram of the provided audio waveform.
|
|
35
|
+
* @param {Float32Array|Float64Array} waveform The audio waveform to process.
|
|
36
|
+
* @param {number} max_length The maximum number of frames to return.
|
|
37
|
+
* @returns {Promise<Tensor>} An object containing the log-Mel spectrogram data as a Float32Array and its dimensions as an array of numbers.
|
|
38
|
+
*/
|
|
39
|
+
async _extract_fbank_features(waveform, max_length) {
|
|
40
|
+
// NOTE: We don't pad/truncate since that is passed in as `max_num_frames`
|
|
41
|
+
|
|
42
|
+
// Kaldi compliance: 16-bit signed integers
|
|
43
|
+
// 32768 == 2 ** 15
|
|
44
|
+
waveform = waveform.map((/** @type {number} */ x) => x * 32768)
|
|
45
|
+
|
|
46
|
+
return spectrogram(
|
|
47
|
+
waveform,
|
|
48
|
+
this.window, // window
|
|
49
|
+
400, // frame_length
|
|
50
|
+
160, // hop_length
|
|
51
|
+
{
|
|
52
|
+
fft_length: 512,
|
|
53
|
+
power: 2.0,
|
|
54
|
+
center: false,
|
|
55
|
+
preemphasis: 0.97,
|
|
56
|
+
mel_filters: this.mel_filters,
|
|
57
|
+
log_mel: 'log',
|
|
58
|
+
mel_floor: 1.192092955078125e-07,
|
|
59
|
+
remove_dc_offset: true,
|
|
60
|
+
|
|
61
|
+
// Custom
|
|
62
|
+
max_num_frames: max_length,
|
|
63
|
+
transpose: true,
|
|
64
|
+
}
|
|
65
|
+
)
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
/**
|
|
69
|
+
* Asynchronously extracts features from a given audio using the provided configuration.
|
|
70
|
+
* @param {Float32Array|Float64Array} audio The audio data as a Float32Array/Float64Array.
|
|
71
|
+
* @param {Object} options Optional parameters for feature extraction.
|
|
72
|
+
* @param {boolean} [options.padding=true] Whether to pad the sequence to a multiple of `pad_to_multiple_of`.
|
|
73
|
+
* @param {number} [options.pad_to_multiple_of=2] The number to pad the sequence to a multiple of.
|
|
74
|
+
* @param {boolean} [options.do_normalize_per_mel_bins=true] Whether or not to zero-mean unit-variance normalize the input per mel-channel.
|
|
75
|
+
* @param {boolean} [options.return_attention_mask=true] Whether to return the attention mask.
|
|
76
|
+
* @returns {Promise<{ input_features: Tensor, attention_mask?: Tensor }>} A Promise resolving to an object containing the extracted input features and attention masks as Tensors.
|
|
77
|
+
*/
|
|
78
|
+
async _call(audio, {
|
|
79
|
+
padding = true,
|
|
80
|
+
pad_to_multiple_of = 2,
|
|
81
|
+
do_normalize_per_mel_bins = true,
|
|
82
|
+
return_attention_mask = true,
|
|
83
|
+
} = {}) {
|
|
84
|
+
validate_audio_inputs(audio, 'SeamlessM4TFeatureExtractor');
|
|
85
|
+
|
|
86
|
+
let features = await this._extract_fbank_features(audio, this.config.max_length);
|
|
87
|
+
|
|
88
|
+
if (do_normalize_per_mel_bins) {
|
|
89
|
+
const [num_features, feature_size] = features.dims;
|
|
90
|
+
const data = features.data;
|
|
91
|
+
for (let i = 0; i < feature_size; ++i) {
|
|
92
|
+
let sum = 0;
|
|
93
|
+
for (let j = 0; j < num_features; ++j) {
|
|
94
|
+
sum += data[j * feature_size + i];
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
const mean = sum / num_features;
|
|
98
|
+
|
|
99
|
+
let variance = 0;
|
|
100
|
+
for (let j = 0; j < num_features; ++j) {
|
|
101
|
+
variance += (data[j * feature_size + i] - mean) ** 2;
|
|
102
|
+
}
|
|
103
|
+
variance /= num_features - 1; // NOTE: We use ddof=1
|
|
104
|
+
|
|
105
|
+
const std = Math.sqrt(variance + 1e-7);
|
|
106
|
+
for (let j = 0; j < num_features; ++j) {
|
|
107
|
+
const index = j * feature_size + i;
|
|
108
|
+
data[index] = (data[index] - mean) / std;
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
let padded_attention_mask;
|
|
114
|
+
if (padding) {
|
|
115
|
+
const [num_frames, num_channels] = features.dims;
|
|
116
|
+
const data = /** @type {Float32Array} */(features.data);
|
|
117
|
+
|
|
118
|
+
const pad_size = num_frames % pad_to_multiple_of;
|
|
119
|
+
if (pad_size > 0) {
|
|
120
|
+
const padded_data = new Float32Array(num_channels * (num_frames + pad_size));
|
|
121
|
+
padded_data.set(data)
|
|
122
|
+
padded_data.fill(this.config.padding_value, data.length)
|
|
123
|
+
|
|
124
|
+
const numPaddedFrames = num_frames + pad_size;
|
|
125
|
+
features = new Tensor(
|
|
126
|
+
features.type,
|
|
127
|
+
padded_data,
|
|
128
|
+
[numPaddedFrames, num_channels],
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
if (return_attention_mask) {
|
|
132
|
+
padded_attention_mask = new Tensor(
|
|
133
|
+
'int64',
|
|
134
|
+
new BigInt64Array(numPaddedFrames),
|
|
135
|
+
[1, numPaddedFrames],
|
|
136
|
+
)
|
|
137
|
+
padded_attention_mask.data.fill(1n, 0, num_frames);
|
|
138
|
+
}
|
|
139
|
+
}
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
const [num_frames, num_channels] = features.dims;
|
|
143
|
+
|
|
144
|
+
const stride = this.config.stride;
|
|
145
|
+
const remainder = num_frames % stride;
|
|
146
|
+
if (remainder !== 0) {
|
|
147
|
+
throw new Error(`The number of frames (${num_frames}) must be a multiple of the stride (${stride}).`)
|
|
148
|
+
}
|
|
149
|
+
|
|
150
|
+
const input_features = features.view(
|
|
151
|
+
1,
|
|
152
|
+
Math.floor(num_frames / stride),
|
|
153
|
+
num_channels * stride,
|
|
154
|
+
);
|
|
155
|
+
|
|
156
|
+
const result = { input_features }
|
|
157
|
+
|
|
158
|
+
if (return_attention_mask) {
|
|
159
|
+
const reshapedNumFrames = input_features.dims[1];
|
|
160
|
+
|
|
161
|
+
const attention_mask_data = new BigInt64Array(reshapedNumFrames);
|
|
162
|
+
|
|
163
|
+
if (padded_attention_mask) {
|
|
164
|
+
const padded_attention_mask_data = padded_attention_mask.data;
|
|
165
|
+
for (let i = 1, j = 0; i < num_frames; i += stride, ++j) {
|
|
166
|
+
attention_mask_data[j] = padded_attention_mask_data[i];
|
|
167
|
+
}
|
|
168
|
+
} else {
|
|
169
|
+
attention_mask_data.fill(1n);
|
|
170
|
+
}
|
|
171
|
+
result.attention_mask = new Tensor(
|
|
172
|
+
'int64',
|
|
173
|
+
attention_mask_data,
|
|
174
|
+
[1, reshapedNumFrames],
|
|
175
|
+
);
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
return result;
|
|
179
|
+
}
|
|
180
|
+
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import {
|
|
2
|
+
ImageProcessor,
|
|
3
|
+
post_process_semantic_segmentation,
|
|
4
|
+
} from "../../base/image_processors_utils.js";
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
export class SegformerImageProcessor extends ImageProcessor {
    /**
     * Post-processes semantic-segmentation model outputs.
     * Thin wrapper around the shared helper from `image_processors_utils`.
     * @type {typeof post_process_semantic_segmentation}
     */
    post_process_semantic_segmentation(...args) {
        const segmentation = post_process_semantic_segmentation(...args);
        return segmentation;
    }
}
|
|
13
|
+
// Alias of `SegformerImageProcessor` (adds no behavior of its own) — presumably
// kept so configs referencing the older `SegformerFeatureExtractor` name still
// resolve; TODO(review): confirm against processing_auto mappings.
export class SegformerFeatureExtractor extends SegformerImageProcessor { }
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import { Processor } from "../../base/processing_utils.js";
|
|
2
|
+
import { AutoTokenizer } from "../../tokenizers.js";
|
|
3
|
+
import { AutoFeatureExtractor } from "../auto/feature_extraction_auto.js";
|
|
4
|
+
|
|
5
|
+
export class SpeechT5Processor extends Processor {
    static tokenizer_class = AutoTokenizer
    static feature_extractor_class = AutoFeatureExtractor

    /**
     * Runs the processor's feature extractor on the given input.
     * @param {any} input The input to extract features from.
     * @returns {Promise<any>} A Promise that resolves with the extracted features.
     */
    async _call(input) {
        const features = await this.feature_extractor(input);
        return features;
    }
}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import {
|
|
2
|
+
ImageProcessor,
|
|
3
|
+
} from "../../base/image_processors_utils.js";
|
|
4
|
+
|
|
5
|
+
export class Swin2SRImageProcessor extends ImageProcessor {
    /**
     * Pads the image so that its height and width become multiples of `padSize`.
     * Here `padSize` is the size of the sliding window used by the model's local attention.
     * @param {any} pixelData The pixel data to pad.
     * @param {number[]} imgDims The dimensions of the image as [height, width, channels].
     * @param {number} padSize The window size to pad the spatial dimensions to a multiple of.
     * @param {Object} [options={}] Extra options forwarded to `ImageProcessor.pad_image`.
     * @returns The padded image data.
     */
    pad_image(pixelData, imgDims, padSize, options = {}) {
        const [imageHeight, imageWidth] = imgDims;

        // Smallest multiple of `padSize` that is >= `size`.
        // NOTE: The original python implementation adds padding even when the dimension is
        // already a multiple of `pad_size` — most likely a bug (PR: https://github.com/mv-lab/swin2sr/pull/19) —
        // so here no padding is added in that case.
        const roundUp = (size) => size + (padSize - size % padSize) % padSize;

        return super.pad_image(pixelData, imgDims, {
            width: roundUp(imageWidth),
            height: roundUp(imageHeight),
        }, {
            mode: 'symmetric',
            center: false,
            constant_values: -1,
            ...options,
        })
    }
}
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import {
|
|
2
|
+
ImageProcessor,
|
|
3
|
+
} from "../../base/image_processors_utils.js";
|
|
4
|
+
|
|
5
|
+
import {
|
|
6
|
+
stack,
|
|
7
|
+
cat,
|
|
8
|
+
} from "../../utils/tensor.js";
|
|
9
|
+
|
|
10
|
+
export class VitMatteImageProcessor extends ImageProcessor {
    /**
     * Calls the feature extraction process on an array of images, preprocesses
     * each image, and concatenates the resulting features into a single Tensor.
     * @param {import("../../utils/image.js").RawImage[]} images The image(s) to extract features from.
     * @param {import("../../utils/image.js").RawImage[]} trimaps The trimaps(s) to extract features from.
     * @returns {Promise<import("../../base/image_processors_utils.js").ImageProcessorResult>} An object containing the concatenated pixel values of the preprocessed images.
     * @throws {Error} If the number of images does not match the number of trimaps.
     */
    async _call(images, trimaps) {
        if (!Array.isArray(images)) {
            images = [images];
        }
        if (!Array.isArray(trimaps)) {
            trimaps = [trimaps];
        }

        // FIX: validate the pairing up-front. A mismatch previously surfaced as a
        // cryptic `TypeError` when reading `trimapData[i].pixel_values` below.
        if (images.length !== trimaps.length) {
            throw new Error(`The number of images (${images.length}) must match the number of trimaps (${trimaps.length}).`);
        }

        const imageData = await Promise.all(images.map(x => this.preprocess(x)));
        // Trimaps are single-channel masks: skip normalization and RGB conversion.
        const trimapData = await Promise.all(trimaps.map(x => this.preprocess(x, {
            do_normalize: false,
            do_convert_rgb: false,
            do_convert_grayscale: true,
        })));


        // Stack pixel values
        const pixel_values = stack(imageData.map(
            // Concatenate images and trimaps
            (x, i) => cat([x.pixel_values, trimapData[i].pixel_values], 0)
        ), 0);

        return {
            pixel_values,

            // Original sizes of images
            original_sizes: imageData.map(x => x.original_size),

            // Reshaped sizes of images, before padding or cropping
            reshaped_input_sizes: imageData.map(x => x.reshaped_input_size),
        }
    }
}
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import {
|
|
2
|
+
ImageProcessor,
|
|
3
|
+
} from "../../base/image_processors_utils.js";
|
|
4
|
+
|
|
5
|
+
export class VitPoseImageProcessor extends ImageProcessor {

    /**
     * Transform the heatmaps into keypoint predictions and transform them back to the image.
     * NOTE: This is a naive implementation and does not include advanced post-processing techniques,
     * so the results may not be as accurate as the original implementation.
     * @param {import('../../utils/tensor.js').Tensor} outputs The model outputs.
     * @param {[number, number, number, number][][]} boxes List or array of bounding boxes for each image.
     * Each box should be a list of 4 floats representing the bounding box coordinates in COCO format (top_left_x, top_left_y, width, height).
     * @returns {{
     *   bbox: [number, number, number, number],
     *   scores: number[],
     *   labels: number[],
     *   keypoints: [number, number][]
     * }[][]} List of keypoints predictions for each image.
     */
    post_process_pose_estimation(outputs, boxes, {
        threshold = null,
        // TODO:
        // kernel_size = 11,
        // target_sizes = null,
    } = {}) {
        // NOTE: boxes are 3D (batch_size, num_boxes, 4)
        const heatmaps = outputs.tolist();
        const [batch_size, num_classes, height, width] = outputs.dims;

        const all_results = [];
        for (let b = 0; b < batch_size; ++b) {
            const image_heatmaps = heatmaps[b];
            const image_boxes = boxes[b];

            const image_results = [];
            for (const bbox of image_boxes) {
                const keypoints = [];
                const scores = [];
                const labels = [];

                // Scale factors from heatmap space to the box's width/height
                // (the box's last two entries in COCO format).
                const x_scale = bbox.at(-2) / width;
                const y_scale = bbox.at(-1) / height;

                // One heatmap channel per keypoint class.
                for (let c = 0; c < image_heatmaps.length; ++c) {
                    const channel = image_heatmaps[c];

                    let weighted_x = 0;
                    let weighted_y = 0;
                    let total = 0;
                    let best = -Infinity;
                    for (let y = 0; y < channel.length; ++y) {
                        const map_row = channel[y];
                        for (let x = 0; x < map_row.length; ++x) {
                            const value = map_row[x];
                            total += value;

                            best = Math.max(best, value);

                            // Accumulate weighted sum of positions (soft-argmax).
                            // TODO: Determine best offsets
                            weighted_x += (x + 0.5) * value;
                            weighted_y += (y) * value;
                        }
                    }

                    // Ignore low scores, if threshold is set
                    if (threshold != null && best < threshold) continue;

                    /** @type {[number, number]} */
                    const keypoint = [
                        x_scale * weighted_x / total,
                        y_scale * weighted_y / total,
                    ]
                    keypoints.push(keypoint);
                    labels.push(c);
                    scores.push(best);
                }
                image_results.push({ bbox, scores, labels, keypoints });
            }
            all_results.push(image_results);
        }
        return all_results;
    }
}
|