@huggingface/transformers 3.0.2 → 3.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +12 -4
- package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
- package/dist/transformers.cjs +16235 -13145
- package/dist/transformers.cjs.map +1 -1
- package/dist/transformers.js +16536 -13437
- package/dist/transformers.js.map +1 -1
- package/dist/transformers.min.cjs +238 -52
- package/dist/transformers.min.cjs.map +1 -1
- package/dist/transformers.min.js +229 -43
- package/dist/transformers.min.js.map +1 -1
- package/dist/transformers.min.mjs +240 -54
- package/dist/transformers.min.mjs.map +1 -1
- package/dist/transformers.mjs +15259 -12171
- package/dist/transformers.mjs.map +1 -1
- package/package.json +4 -4
- package/src/base/feature_extraction_utils.js +54 -0
- package/src/base/image_processors_utils.js +1089 -0
- package/src/base/processing_utils.js +145 -0
- package/src/configs.js +13 -3
- package/src/env.js +1 -1
- package/src/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.js +90 -0
- package/src/models/auto/feature_extraction_auto.js +41 -0
- package/src/models/auto/image_processing_auto.js +29 -0
- package/src/models/auto/processing_auto.js +100 -0
- package/src/models/beit/image_processing_beit.js +5 -0
- package/src/models/bit/image_processing_bit.js +5 -0
- package/src/models/chinese_clip/image_processing_chinese_clip.js +5 -0
- package/src/models/clap/feature_extraction_clap.js +159 -0
- package/src/models/clip/image_processing_clip.js +6 -0
- package/src/models/convnext/image_processing_convnext.js +45 -0
- package/src/models/deit/image_processing_deit.js +6 -0
- package/src/models/detr/image_processing_detr.js +52 -0
- package/src/models/donut/image_processing_donut.js +31 -0
- package/src/models/dpt/image_processing_dpt.js +6 -0
- package/src/models/efficientnet/image_processing_efficientnet.js +13 -0
- package/src/models/feature_extractors.js +12 -0
- package/src/models/florence2/processing_florence2.js +128 -0
- package/src/models/glpn/image_processing_glpn.js +5 -0
- package/src/models/image_processors.js +36 -0
- package/src/models/janus/image_processing_janus.js +26 -0
- package/src/models/janus/processing_janus.js +123 -0
- package/src/models/jina_clip/image_processing_jina_clip.js +26 -0
- package/src/models/jina_clip/processing_jina_clip.js +24 -0
- package/src/models/llava_onevision/image_processing_llava_onevision.js +5 -0
- package/src/models/mask2former/image_processing_mask2former.js +5 -0
- package/src/models/maskformer/image_processing_maskformer.js +18 -0
- package/src/models/mgp_str/processing_mgp_str.js +170 -0
- package/src/models/mobilenet_v1/image_processing_mobilenet_v1.js +7 -0
- package/src/models/mobilenet_v2/image_processing_mobilenet_v2.js +7 -0
- package/src/models/mobilenet_v3/image_processing_mobilenet_v3.js +7 -0
- package/src/models/mobilenet_v4/image_processing_mobilenet_v4.js +7 -0
- package/src/models/mobilevit/image_processing_mobilevit.js +6 -0
- package/src/models/nougat/image_processing_nougat.js +5 -0
- package/src/models/owlv2/image_processing_owlv2.js +5 -0
- package/src/models/owlvit/image_processing_owlvit.js +12 -0
- package/src/models/owlvit/processing_owlvit.js +7 -0
- package/src/models/processors.js +11 -0
- package/src/models/pvt/image_processing_pvt.js +5 -0
- package/src/models/pyannote/feature_extraction_pyannote.js +28 -0
- package/src/models/pyannote/processing_pyannote.js +71 -0
- package/src/models/qwen2_vl/image_processing_qwen2_vl.js +52 -0
- package/src/models/qwen2_vl/processing_qwen2_vl.js +52 -0
- package/src/models/rt_detr/image_processing_rt_detr.js +12 -0
- package/src/models/sam/image_processing_sam.js +242 -0
- package/src/models/sam/processing_sam.js +20 -0
- package/src/models/sapiens/image_processing_sapiens.js +13 -0
- package/src/models/seamless_m4t/feature_extraction_seamless_m4t.js +180 -0
- package/src/models/segformer/image_processing_segformer.js +13 -0
- package/src/models/siglip/image_processing_siglip.js +5 -0
- package/src/models/speecht5/feature_extraction_speecht5.js +4 -0
- package/src/models/speecht5/processing_speecht5.js +17 -0
- package/src/models/swin2sr/image_processing_swin2sr.js +24 -0
- package/src/models/vit/image_processing_vit.js +7 -0
- package/src/models/vitmatte/image_processing_vitmatte.js +50 -0
- package/src/models/vitpose/image_processing_vitpose.js +89 -0
- package/src/models/wav2vec2/feature_extraction_wav2vec2.js +44 -0
- package/src/models/wav2vec2/processing_wav2vec2.js +15 -0
- package/src/models/wespeaker/feature_extraction_wespeaker.js +100 -0
- package/src/models/whisper/feature_extraction_whisper.js +84 -0
- package/src/models/whisper/processing_whisper.js +21 -0
- package/src/models/yolos/image_processing_yolos.js +12 -0
- package/src/models.js +675 -32
- package/src/pipelines.js +8 -8
- package/src/tokenizers.js +5 -0
- package/src/transformers.js +15 -2
- package/src/utils/constants.js +8 -1
- package/src/utils/core.js +37 -9
- package/src/utils/hub.js +2 -1
- package/src/utils/image.js +68 -17
- package/src/utils/tensor.js +33 -1
- package/types/base/feature_extraction_utils.d.ts +41 -0
- package/types/base/feature_extraction_utils.d.ts.map +1 -0
- package/types/base/image_processors_utils.d.ts +323 -0
- package/types/base/image_processors_utils.d.ts.map +1 -0
- package/types/base/processing_utils.d.ts +80 -0
- package/types/base/processing_utils.d.ts.map +1 -0
- package/types/configs.d.ts +4 -1
- package/types/configs.d.ts.map +1 -1
- package/types/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.d.ts +25 -0
- package/types/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.d.ts.map +1 -0
- package/types/models/auto/feature_extraction_auto.d.ts +5 -0
- package/types/models/auto/feature_extraction_auto.d.ts.map +1 -0
- package/types/models/auto/image_processing_auto.d.ts +5 -0
- package/types/models/auto/image_processing_auto.d.ts.map +1 -0
- package/types/models/auto/processing_auto.d.ts +35 -0
- package/types/models/auto/processing_auto.d.ts.map +1 -0
- package/types/models/beit/image_processing_beit.d.ts +4 -0
- package/types/models/beit/image_processing_beit.d.ts.map +1 -0
- package/types/models/bit/image_processing_bit.d.ts +4 -0
- package/types/models/bit/image_processing_bit.d.ts.map +1 -0
- package/types/models/chinese_clip/image_processing_chinese_clip.d.ts +4 -0
- package/types/models/chinese_clip/image_processing_chinese_clip.d.ts.map +1 -0
- package/types/models/clap/feature_extraction_clap.d.ts +57 -0
- package/types/models/clap/feature_extraction_clap.d.ts.map +1 -0
- package/types/models/clip/image_processing_clip.d.ts +6 -0
- package/types/models/clip/image_processing_clip.d.ts.map +1 -0
- package/types/models/convnext/image_processing_convnext.d.ts +12 -0
- package/types/models/convnext/image_processing_convnext.d.ts.map +1 -0
- package/types/models/deit/image_processing_deit.d.ts +6 -0
- package/types/models/deit/image_processing_deit.d.ts.map +1 -0
- package/types/models/detr/image_processing_detr.d.ts +42 -0
- package/types/models/detr/image_processing_detr.d.ts.map +1 -0
- package/types/models/donut/image_processing_donut.d.ts +7 -0
- package/types/models/donut/image_processing_donut.d.ts.map +1 -0
- package/types/models/dpt/image_processing_dpt.d.ts +6 -0
- package/types/models/dpt/image_processing_dpt.d.ts.map +1 -0
- package/types/models/efficientnet/image_processing_efficientnet.d.ts +6 -0
- package/types/models/efficientnet/image_processing_efficientnet.d.ts.map +1 -0
- package/types/models/feature_extractors.d.ts +10 -0
- package/types/models/feature_extractors.d.ts.map +1 -0
- package/types/models/florence2/processing_florence2.d.ts +39 -0
- package/types/models/florence2/processing_florence2.d.ts.map +1 -0
- package/types/models/glpn/image_processing_glpn.d.ts +4 -0
- package/types/models/glpn/image_processing_glpn.d.ts.map +1 -0
- package/types/models/image_processors.d.ts +36 -0
- package/types/models/image_processors.d.ts.map +1 -0
- package/types/models/janus/image_processing_janus.d.ts +7 -0
- package/types/models/janus/image_processing_janus.d.ts.map +1 -0
- package/types/models/janus/processing_janus.d.ts +77 -0
- package/types/models/janus/processing_janus.d.ts.map +1 -0
- package/types/models/jina_clip/image_processing_jina_clip.d.ts +5 -0
- package/types/models/jina_clip/image_processing_jina_clip.d.ts.map +1 -0
- package/types/models/jina_clip/processing_jina_clip.d.ts +9 -0
- package/types/models/jina_clip/processing_jina_clip.d.ts.map +1 -0
- package/types/models/llava_onevision/image_processing_llava_onevision.d.ts +4 -0
- package/types/models/llava_onevision/image_processing_llava_onevision.d.ts.map +1 -0
- package/types/models/mask2former/image_processing_mask2former.d.ts +4 -0
- package/types/models/mask2former/image_processing_mask2former.d.ts.map +1 -0
- package/types/models/maskformer/image_processing_maskformer.d.ts +22 -0
- package/types/models/maskformer/image_processing_maskformer.d.ts.map +1 -0
- package/types/models/mgp_str/processing_mgp_str.d.ts +64 -0
- package/types/models/mgp_str/processing_mgp_str.d.ts.map +1 -0
- package/types/models/mobilenet_v1/image_processing_mobilenet_v1.d.ts +6 -0
- package/types/models/mobilenet_v1/image_processing_mobilenet_v1.d.ts.map +1 -0
- package/types/models/mobilenet_v2/image_processing_mobilenet_v2.d.ts +6 -0
- package/types/models/mobilenet_v2/image_processing_mobilenet_v2.d.ts.map +1 -0
- package/types/models/mobilenet_v3/image_processing_mobilenet_v3.d.ts +6 -0
- package/types/models/mobilenet_v3/image_processing_mobilenet_v3.d.ts.map +1 -0
- package/types/models/mobilenet_v4/image_processing_mobilenet_v4.d.ts +6 -0
- package/types/models/mobilenet_v4/image_processing_mobilenet_v4.d.ts.map +1 -0
- package/types/models/mobilevit/image_processing_mobilevit.d.ts +6 -0
- package/types/models/mobilevit/image_processing_mobilevit.d.ts.map +1 -0
- package/types/models/nougat/image_processing_nougat.d.ts +4 -0
- package/types/models/nougat/image_processing_nougat.d.ts.map +1 -0
- package/types/models/owlv2/image_processing_owlv2.d.ts +4 -0
- package/types/models/owlv2/image_processing_owlv2.d.ts.map +1 -0
- package/types/models/owlvit/image_processing_owlvit.d.ts +10 -0
- package/types/models/owlvit/image_processing_owlvit.d.ts.map +1 -0
- package/types/models/owlvit/processing_owlvit.d.ts +8 -0
- package/types/models/owlvit/processing_owlvit.d.ts.map +1 -0
- package/types/models/processors.d.ts +12 -0
- package/types/models/processors.d.ts.map +1 -0
- package/types/models/pvt/image_processing_pvt.d.ts +4 -0
- package/types/models/pvt/image_processing_pvt.d.ts.map +1 -0
- package/types/models/pyannote/feature_extraction_pyannote.d.ts +13 -0
- package/types/models/pyannote/feature_extraction_pyannote.d.ts.map +1 -0
- package/types/models/pyannote/processing_pyannote.d.ts +30 -0
- package/types/models/pyannote/processing_pyannote.d.ts.map +1 -0
- package/types/models/qwen2_vl/image_processing_qwen2_vl.d.ts +11 -0
- package/types/models/qwen2_vl/image_processing_qwen2_vl.d.ts.map +1 -0
- package/types/models/qwen2_vl/processing_qwen2_vl.d.ts +17 -0
- package/types/models/qwen2_vl/processing_qwen2_vl.d.ts.map +1 -0
- package/types/models/rt_detr/image_processing_rt_detr.d.ts +8 -0
- package/types/models/rt_detr/image_processing_rt_detr.d.ts.map +1 -0
- package/types/models/sam/image_processing_sam.d.ts +103 -0
- package/types/models/sam/image_processing_sam.d.ts.map +1 -0
- package/types/models/sam/processing_sam.d.ts +9 -0
- package/types/models/sam/processing_sam.d.ts.map +1 -0
- package/types/models/seamless_m4t/feature_extraction_seamless_m4t.d.ts +34 -0
- package/types/models/seamless_m4t/feature_extraction_seamless_m4t.d.ts.map +1 -0
- package/types/models/segformer/image_processing_segformer.d.ts +10 -0
- package/types/models/segformer/image_processing_segformer.d.ts.map +1 -0
- package/types/models/siglip/image_processing_siglip.d.ts +4 -0
- package/types/models/siglip/image_processing_siglip.d.ts.map +1 -0
- package/types/models/speecht5/feature_extraction_speecht5.d.ts +4 -0
- package/types/models/speecht5/feature_extraction_speecht5.d.ts.map +1 -0
- package/types/models/speecht5/processing_speecht5.d.ts +14 -0
- package/types/models/speecht5/processing_speecht5.d.ts.map +1 -0
- package/types/models/swin2sr/image_processing_swin2sr.d.ts +5 -0
- package/types/models/swin2sr/image_processing_swin2sr.d.ts.map +1 -0
- package/types/models/vit/image_processing_vit.d.ts +6 -0
- package/types/models/vit/image_processing_vit.d.ts.map +1 -0
- package/types/models/vitmatte/image_processing_vitmatte.d.ts +12 -0
- package/types/models/vitmatte/image_processing_vitmatte.d.ts.map +1 -0
- package/types/models/vitpose/image_processing_vitpose.d.ts +26 -0
- package/types/models/vitpose/image_processing_vitpose.d.ts.map +1 -0
- package/types/models/wav2vec2/feature_extraction_wav2vec2.d.ts +19 -0
- package/types/models/wav2vec2/feature_extraction_wav2vec2.d.ts.map +1 -0
- package/types/models/wav2vec2/processing_wav2vec2.d.ts +12 -0
- package/types/models/wav2vec2/processing_wav2vec2.d.ts.map +1 -0
- package/types/models/wespeaker/feature_extraction_wespeaker.d.ts +23 -0
- package/types/models/wespeaker/feature_extraction_wespeaker.d.ts.map +1 -0
- package/types/models/whisper/feature_extraction_whisper.d.ts +21 -0
- package/types/models/whisper/feature_extraction_whisper.d.ts.map +1 -0
- package/types/models/whisper/processing_whisper.d.ts +17 -0
- package/types/models/whisper/processing_whisper.d.ts.map +1 -0
- package/types/models/yolos/image_processing_yolos.d.ts +10 -0
- package/types/models/yolos/image_processing_yolos.d.ts.map +1 -0
- package/types/models.d.ts +140 -0
- package/types/models.d.ts.map +1 -1
- package/types/pipelines.d.ts +2 -3
- package/types/pipelines.d.ts.map +1 -1
- package/types/tokenizers.d.ts +3 -0
- package/types/tokenizers.d.ts.map +1 -1
- package/types/transformers.d.ts +10 -1
- package/types/utils/constants.d.ts +6 -0
- package/types/utils/constants.d.ts.map +1 -1
- package/types/utils/core.d.ts +58 -3
- package/types/utils/core.d.ts.map +1 -1
- package/types/utils/hub.d.ts +1 -1
- package/types/utils/hub.d.ts.map +1 -1
- package/types/utils/image.d.ts +10 -2
- package/types/utils/image.d.ts.map +1 -1
- package/types/utils/tensor.d.ts +34 -1
- package/types/utils/tensor.d.ts.map +1 -1
- package/src/processors.js +0 -2655
- package/types/processors.d.ts +0 -924
- package/types/processors.d.ts.map +0 -1
package/src/models.js
CHANGED
|
@@ -61,7 +61,6 @@ import {
|
|
|
61
61
|
} from './utils/generic.js';
|
|
62
62
|
|
|
63
63
|
import {
|
|
64
|
-
isIntegralNumber,
|
|
65
64
|
mergeArrays,
|
|
66
65
|
pick,
|
|
67
66
|
} from './utils/core.js';
|
|
@@ -99,17 +98,20 @@ import {
|
|
|
99
98
|
|
|
100
99
|
import {
|
|
101
100
|
cat,
|
|
102
|
-
full_like,
|
|
103
101
|
mean,
|
|
102
|
+
zeros,
|
|
103
|
+
zeros_like,
|
|
104
104
|
ones,
|
|
105
105
|
ones_like,
|
|
106
|
+
full,
|
|
107
|
+
full_like,
|
|
106
108
|
stack,
|
|
107
109
|
std_mean,
|
|
108
110
|
Tensor,
|
|
109
|
-
zeros_like,
|
|
110
111
|
} from './utils/tensor.js';
|
|
112
|
+
import { RawImage } from './utils/image.js';
|
|
111
113
|
|
|
112
|
-
import { dynamic_time_warping, medianFilter } from './utils/maths.js';
|
|
114
|
+
import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
|
|
113
115
|
import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
|
|
114
116
|
import { LogitsSampler } from './generation/logits_sampler.js';
|
|
115
117
|
import { apis } from './env.js';
|
|
@@ -128,6 +130,7 @@ const MODEL_TYPES = {
|
|
|
128
130
|
MaskGeneration: 5,
|
|
129
131
|
ImageTextToText: 6,
|
|
130
132
|
Musicgen: 7,
|
|
133
|
+
MultiModality: 8,
|
|
131
134
|
}
|
|
132
135
|
//////////////////////////////////////////////////
|
|
133
136
|
|
|
@@ -386,7 +389,7 @@ async function sessionRun(session, inputs) {
|
|
|
386
389
|
} catch (e) {
|
|
387
390
|
// This usually occurs when the inputs are of the wrong type.
|
|
388
391
|
console.error(`An error occurred during model execution: "${e}".`);
|
|
389
|
-
console.error('Inputs given to model:', checkedInputs)
|
|
392
|
+
console.error('Inputs given to model:', checkedInputs)
|
|
390
393
|
throw e;
|
|
391
394
|
}
|
|
392
395
|
}
|
|
@@ -579,11 +582,11 @@ async function imageTextToTextForward(self, {
|
|
|
579
582
|
|
|
580
583
|
if (!inputs_embeds) {
|
|
581
584
|
// 1. Extract the input embeddings
|
|
582
|
-
inputs_embeds = await self.encode_text({ input_ids });
|
|
585
|
+
inputs_embeds = await self.encode_text({ input_ids, ...kwargs });
|
|
583
586
|
|
|
584
587
|
// 2. Possibly, merge text and images
|
|
585
588
|
if (pixel_values && input_ids.dims[1] !== 1) {
|
|
586
|
-
const image_features = await self.encode_image({ pixel_values });
|
|
589
|
+
const image_features = await self.encode_image({ pixel_values, ...kwargs });
|
|
587
590
|
|
|
588
591
|
({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
|
|
589
592
|
image_features,
|
|
@@ -604,6 +607,16 @@ async function imageTextToTextForward(self, {
|
|
|
604
607
|
}
|
|
605
608
|
}
|
|
606
609
|
|
|
610
|
+
if (!position_ids) {
|
|
611
|
+
|
|
612
|
+
if (self.config.model_type === 'qwen2_vl') {
|
|
613
|
+
// Special case for qwen2_vl models
|
|
614
|
+
// @ts-ignore
|
|
615
|
+
const { image_grid_thw, video_grid_thw } = kwargs;
|
|
616
|
+
[position_ids] = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
|
|
617
|
+
}
|
|
618
|
+
}
|
|
619
|
+
|
|
607
620
|
const outputs = await decoderForward(self, {
|
|
608
621
|
inputs_embeds,
|
|
609
622
|
past_key_values,
|
|
@@ -615,34 +628,54 @@ async function imageTextToTextForward(self, {
|
|
|
615
628
|
return outputs;
|
|
616
629
|
}
|
|
617
630
|
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
631
|
+
/**
|
|
632
|
+
* Helper function to perform the following:
|
|
633
|
+
* ```python
|
|
634
|
+
* x = attention_mask.long().cumsum(-1) - 1
|
|
635
|
+
* x.masked_fill_(attention_mask == 0, 1)
|
|
636
|
+
* ```
|
|
637
|
+
* @param {Tensor} attention_mask
|
|
638
|
+
* @returns {{data: BigInt64Array, dims: number[]}}
|
|
639
|
+
*/
|
|
640
|
+
function cumsum_masked_fill(attention_mask) {
|
|
628
641
|
const [bz, seq_len] = attention_mask.dims;
|
|
642
|
+
const attn_mask_data = attention_mask.data;
|
|
629
643
|
|
|
630
|
-
const data = new BigInt64Array(
|
|
644
|
+
const data = new BigInt64Array(attn_mask_data.length);
|
|
631
645
|
for (let i = 0; i < bz; ++i) {
|
|
632
646
|
const start = i * seq_len;
|
|
633
647
|
let sum = BigInt(0);
|
|
634
648
|
for (let j = 0; j < seq_len; ++j) {
|
|
635
649
|
const index = start + j;
|
|
636
|
-
if (
|
|
650
|
+
if (attn_mask_data[index] === 0n) {
|
|
637
651
|
data[index] = BigInt(1);
|
|
638
652
|
} else { // === 1n
|
|
639
653
|
data[index] = sum;
|
|
640
|
-
sum +=
|
|
654
|
+
sum += attn_mask_data[index];
|
|
641
655
|
}
|
|
642
656
|
}
|
|
643
657
|
}
|
|
658
|
+
return { data, dims: attention_mask.dims };
|
|
659
|
+
|
|
660
|
+
}
|
|
661
|
+
|
|
662
|
+
/**
|
|
663
|
+
* If the model supports providing position_ids, we create position_ids on the fly for batch generation,
|
|
664
|
+
* by computing the cumulative sum of the attention mask along the sequence length dimension.
|
|
665
|
+
*
|
|
666
|
+
* Equivalent to:
|
|
667
|
+
* ```python
|
|
668
|
+
* position_ids = attention_mask.long().cumsum(-1) - 1
|
|
669
|
+
* position_ids.masked_fill_(attention_mask == 0, 1)
|
|
670
|
+
* if past_key_values:
|
|
671
|
+
* position_ids = position_ids[:, -input_ids.shape[1] :]
|
|
672
|
+
* ```
|
|
673
|
+
*/
|
|
674
|
+
function createPositionIds(model_inputs, past_key_values = null) {
|
|
675
|
+
const { input_ids, inputs_embeds, attention_mask } = model_inputs;
|
|
644
676
|
|
|
645
|
-
|
|
677
|
+
const { data, dims } = cumsum_masked_fill(attention_mask);
|
|
678
|
+
let position_ids = new Tensor('int64', data, dims);
|
|
646
679
|
if (past_key_values) {
|
|
647
680
|
const offset = -(input_ids ?? inputs_embeds).dims.at(1);
|
|
648
681
|
position_ids = position_ids.slice(null, [offset, null]);
|
|
@@ -716,6 +749,52 @@ function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
|
|
|
716
749
|
}
|
|
717
750
|
}
|
|
718
751
|
|
|
752
|
+
function multimodality_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
|
|
753
|
+
const has_past_key_values = !!model_inputs.past_key_values;
|
|
754
|
+
|
|
755
|
+
if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
|
|
756
|
+
if (has_past_key_values) {
|
|
757
|
+
model_inputs.input_ids = cat([
|
|
758
|
+
model_inputs.input_ids,
|
|
759
|
+
model_inputs.input_ids,
|
|
760
|
+
], 0)
|
|
761
|
+
// NOTE: attention_mask handled in generation
|
|
762
|
+
} else {
|
|
763
|
+
model_inputs.input_ids = cat([
|
|
764
|
+
model_inputs.input_ids,
|
|
765
|
+
full_like(model_inputs.input_ids, BigInt(generation_config.pad_token_id)),
|
|
766
|
+
], 0);
|
|
767
|
+
model_inputs.attention_mask = cat([
|
|
768
|
+
model_inputs.attention_mask,
|
|
769
|
+
full_like(model_inputs.attention_mask, 0n),
|
|
770
|
+
], 0);
|
|
771
|
+
}
|
|
772
|
+
}
|
|
773
|
+
|
|
774
|
+
if (has_past_key_values || !model_inputs.pixel_values) {
|
|
775
|
+
model_inputs.pixel_values = full([0, 0, 3, 384, 384], 1.0);
|
|
776
|
+
}
|
|
777
|
+
|
|
778
|
+
if (has_past_key_values) {
|
|
779
|
+
const num_img_tokens = 0;
|
|
780
|
+
const num_text_tokens = 1;
|
|
781
|
+
const has_image = num_img_tokens > 0 ? 1 : 0;
|
|
782
|
+
|
|
783
|
+
const batch_size = 1;
|
|
784
|
+
model_inputs.images_seq_mask = new Tensor(
|
|
785
|
+
'bool',
|
|
786
|
+
new Array(num_img_tokens + num_text_tokens).fill(true).fill(false, 0, num_text_tokens),
|
|
787
|
+
[batch_size, num_img_tokens + num_text_tokens],
|
|
788
|
+
);
|
|
789
|
+
model_inputs.images_emb_mask = new Tensor(
|
|
790
|
+
'bool',
|
|
791
|
+
new Array(num_img_tokens).fill(!!has_image),
|
|
792
|
+
[batch_size, 1, num_img_tokens],
|
|
793
|
+
);
|
|
794
|
+
}
|
|
795
|
+
return model_inputs;
|
|
796
|
+
}
|
|
797
|
+
|
|
719
798
|
//////////////////////////////////////////////////
|
|
720
799
|
|
|
721
800
|
//////////////////////////////////////////////////
|
|
@@ -769,6 +848,11 @@ export class PreTrainedModel extends Callable {
|
|
|
769
848
|
this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
|
|
770
849
|
break;
|
|
771
850
|
|
|
851
|
+
case MODEL_TYPES.MultiModality:
|
|
852
|
+
this.can_generate = true;
|
|
853
|
+
this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation;
|
|
854
|
+
break;
|
|
855
|
+
|
|
772
856
|
default:
|
|
773
857
|
// should be MODEL_TYPES.EncoderOnly
|
|
774
858
|
this._forward = encoderForward;
|
|
@@ -912,6 +996,21 @@ export class PreTrainedModel extends Callable {
|
|
|
912
996
|
}, options),
|
|
913
997
|
]);
|
|
914
998
|
|
|
999
|
+
} else if (modelType === MODEL_TYPES.MultiModality) {
|
|
1000
|
+
info = await Promise.all([
|
|
1001
|
+
constructSessions(pretrained_model_name_or_path, {
|
|
1002
|
+
prepare_inputs_embeds: 'prepare_inputs_embeds',
|
|
1003
|
+
model: 'language_model',
|
|
1004
|
+
lm_head: 'lm_head',
|
|
1005
|
+
gen_head: 'gen_head',
|
|
1006
|
+
gen_img_embeds: 'gen_img_embeds',
|
|
1007
|
+
image_decode: 'image_decode',
|
|
1008
|
+
}, options),
|
|
1009
|
+
getOptionalConfigs(pretrained_model_name_or_path, {
|
|
1010
|
+
generation_config: 'generation_config.json',
|
|
1011
|
+
}, options),
|
|
1012
|
+
]);
|
|
1013
|
+
|
|
915
1014
|
} else { // should be MODEL_TYPES.EncoderOnly
|
|
916
1015
|
if (modelType !== MODEL_TYPES.EncoderOnly) {
|
|
917
1016
|
console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
|
|
@@ -1658,7 +1757,8 @@ export class PreTrainedModel extends Callable {
|
|
|
1658
1757
|
const dtype = session?.config?.kv_cache_dtype ?? 'float32';
|
|
1659
1758
|
const empty = (dtype === 'float16') ? new Uint16Array() : [];
|
|
1660
1759
|
|
|
1661
|
-
const
|
|
1760
|
+
const batch_size = (decoderFeeds[this.main_input_name] ?? decoderFeeds.attention_mask).dims?.[0] ?? 1;
|
|
1761
|
+
const shapes = getKeyValueShapes(this.config, { batch_size });
|
|
1662
1762
|
|
|
1663
1763
|
for (const name in shapes) {
|
|
1664
1764
|
decoderFeeds[name] = new Tensor(dtype, empty, shapes[name]);
|
|
@@ -3277,6 +3377,7 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel {
|
|
|
3277
3377
|
}
|
|
3278
3378
|
//////////////////////////////////////////////////
|
|
3279
3379
|
|
|
3380
|
+
export class LlavaOnevisionForConditionalGeneration extends LlavaForConditionalGeneration { } // NOTE: extends LlavaForConditionalGeneration
|
|
3280
3381
|
export class Moondream1ForConditionalGeneration extends LlavaForConditionalGeneration { } // NOTE: extends LlavaForConditionalGeneration
|
|
3281
3382
|
|
|
3282
3383
|
export class Florence2PreTrainedModel extends PreTrainedModel {
|
|
@@ -3437,7 +3538,7 @@ export class CLIPModel extends CLIPPreTrainedModel { }
|
|
|
3437
3538
|
* The text model from CLIP without any head or projection on top.
|
|
3438
3539
|
*/
|
|
3439
3540
|
export class CLIPTextModel extends CLIPPreTrainedModel {
|
|
3440
|
-
/** @type {PreTrainedModel.from_pretrained} */
|
|
3541
|
+
/** @type {typeof PreTrainedModel.from_pretrained} */
|
|
3441
3542
|
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3442
3543
|
// Update default model file name if not provided
|
|
3443
3544
|
options.model_file_name ??= 'text_model';
|
|
@@ -3472,7 +3573,7 @@ export class CLIPTextModel extends CLIPPreTrainedModel {
|
|
|
3472
3573
|
* ```
|
|
3473
3574
|
*/
|
|
3474
3575
|
export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
|
|
3475
|
-
/** @type {PreTrainedModel.from_pretrained} */
|
|
3576
|
+
/** @type {typeof PreTrainedModel.from_pretrained} */
|
|
3476
3577
|
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3477
3578
|
// Update default model file name if not provided
|
|
3478
3579
|
options.model_file_name ??= 'text_model';
|
|
@@ -3484,7 +3585,7 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
|
|
|
3484
3585
|
* The vision model from CLIP without any head or projection on top.
|
|
3485
3586
|
*/
|
|
3486
3587
|
export class CLIPVisionModel extends CLIPPreTrainedModel {
|
|
3487
|
-
/** @type {PreTrainedModel.from_pretrained} */
|
|
3588
|
+
/** @type {typeof PreTrainedModel.from_pretrained} */
|
|
3488
3589
|
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3489
3590
|
// Update default model file name if not provided
|
|
3490
3591
|
options.model_file_name ??= 'vision_model';
|
|
@@ -3519,7 +3620,7 @@ export class CLIPVisionModel extends CLIPPreTrainedModel {
|
|
|
3519
3620
|
* ```
|
|
3520
3621
|
*/
|
|
3521
3622
|
export class CLIPVisionModelWithProjection extends CLIPPreTrainedModel {
|
|
3522
|
-
/** @type {PreTrainedModel.from_pretrained} */
|
|
3623
|
+
/** @type {typeof PreTrainedModel.from_pretrained} */
|
|
3523
3624
|
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3524
3625
|
// Update default model file name if not provided
|
|
3525
3626
|
options.model_file_name ??= 'vision_model';
|
|
@@ -3605,8 +3706,7 @@ export class SiglipModel extends SiglipPreTrainedModel { }
|
|
|
3605
3706
|
* ```
|
|
3606
3707
|
*/
|
|
3607
3708
|
export class SiglipTextModel extends SiglipPreTrainedModel {
|
|
3608
|
-
|
|
3609
|
-
/** @type {PreTrainedModel.from_pretrained} */
|
|
3709
|
+
/** @type {typeof PreTrainedModel.from_pretrained} */
|
|
3610
3710
|
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3611
3711
|
// Update default model file name if not provided
|
|
3612
3712
|
options.model_file_name ??= 'text_model';
|
|
@@ -3641,7 +3741,7 @@ export class SiglipTextModel extends SiglipPreTrainedModel {
|
|
|
3641
3741
|
* ```
|
|
3642
3742
|
*/
|
|
3643
3743
|
export class SiglipVisionModel extends CLIPPreTrainedModel {
|
|
3644
|
-
/** @type {PreTrainedModel.from_pretrained} */
|
|
3744
|
+
/** @type {typeof PreTrainedModel.from_pretrained} */
|
|
3645
3745
|
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3646
3746
|
// Update default model file name if not provided
|
|
3647
3747
|
options.model_file_name ??= 'vision_model';
|
|
@@ -3655,6 +3755,67 @@ export class ChineseCLIPPreTrainedModel extends PreTrainedModel { }
|
|
|
3655
3755
|
export class ChineseCLIPModel extends ChineseCLIPPreTrainedModel { }
|
|
3656
3756
|
//////////////////////////////////////////////////
|
|
3657
3757
|
|
|
3758
|
+
//////////////////////////////////////////////////
|
|
3759
|
+
// JinaCLIP models
|
|
3760
|
+
export class JinaCLIPPreTrainedModel extends PreTrainedModel { }
|
|
3761
|
+
|
|
3762
|
+
export class JinaCLIPModel extends JinaCLIPPreTrainedModel {
|
|
3763
|
+
async forward(model_inputs) {
|
|
3764
|
+
const missing_text_inputs = !model_inputs.input_ids;
|
|
3765
|
+
const missing_image_inputs = !model_inputs.pixel_values;
|
|
3766
|
+
|
|
3767
|
+
if (missing_text_inputs && missing_image_inputs) {
|
|
3768
|
+
throw new Error('Either `input_ids` or `pixel_values` should be provided.');
|
|
3769
|
+
}
|
|
3770
|
+
|
|
3771
|
+
// If either `input_ids` or `pixel_values` aren't passed, we need to create dummy input since the model requires a value to be specified.
|
|
3772
|
+
if (missing_text_inputs) {
|
|
3773
|
+
// NOTE: We cannot pass zero-dimension tensor as input for input_ids.
|
|
3774
|
+
// Fortunately, the majority of time is spent in the vision encoder, so this shouldn't significantly impact performance.
|
|
3775
|
+
model_inputs.input_ids = ones([model_inputs.pixel_values.dims[0], 1]);
|
|
3776
|
+
}
|
|
3777
|
+
|
|
3778
|
+
if (missing_image_inputs) {
|
|
3779
|
+
// NOTE: Since we create a zero-sized tensor, this does not increase computation time.
|
|
3780
|
+
// @ts-ignore
|
|
3781
|
+
const { image_size } = this.config.vision_config;
|
|
3782
|
+
model_inputs.pixel_values = full([0, 3, image_size, image_size], 0.0); // (pass zero-dimension tensor)
|
|
3783
|
+
}
|
|
3784
|
+
|
|
3785
|
+
const { text_embeddings, image_embeddings, l2norm_text_embeddings, l2norm_image_embeddings } = await super.forward(model_inputs);
|
|
3786
|
+
|
|
3787
|
+
const result = {};
|
|
3788
|
+
if (!missing_text_inputs) {
|
|
3789
|
+
result.text_embeddings = text_embeddings;
|
|
3790
|
+
result.l2norm_text_embeddings = l2norm_text_embeddings;
|
|
3791
|
+
}
|
|
3792
|
+
if (!missing_image_inputs) {
|
|
3793
|
+
result.image_embeddings = image_embeddings;
|
|
3794
|
+
result.l2norm_image_embeddings = l2norm_image_embeddings;
|
|
3795
|
+
}
|
|
3796
|
+
return result
|
|
3797
|
+
}
|
|
3798
|
+
}
|
|
3799
|
+
|
|
3800
|
+
export class JinaCLIPTextModel extends JinaCLIPPreTrainedModel {
|
|
3801
|
+
/** @type {typeof PreTrainedModel.from_pretrained} */
|
|
3802
|
+
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
3803
|
+
// Update default model file name if not provided
|
|
3804
|
+
options.model_file_name ??= 'text_model';
|
|
3805
|
+
return super.from_pretrained(pretrained_model_name_or_path, options);
|
|
3806
|
+
}
|
|
3807
|
+
}
|
|
3808
|
+
|
|
3809
|
+
export class JinaCLIPVisionModel extends JinaCLIPPreTrainedModel {
    /**
     * Loads only the vision component of a Jina CLIP checkpoint.
     *
     * When the caller does not provide `options.model_file_name`, `'vision_model'` is used.
     * The default is applied to a shallow copy of `options`, so the caller's object is
     * never modified and can safely be reused for other `from_pretrained` calls.
     *
     * @type {typeof PreTrainedModel.from_pretrained}
     */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        return super.from_pretrained(pretrained_model_name_or_path, {
            ...options,
            // Only fall back when the caller left the file name unset (null/undefined).
            model_file_name: options.model_file_name ?? 'vision_model',
        });
    }
}
|
|
3817
|
+
//////////////////////////////////////////////////
|
|
3818
|
+
|
|
3658
3819
|
|
|
3659
3820
|
//////////////////////////////////////////////////
|
|
3660
3821
|
// CLIPSeg models
|
|
@@ -3898,6 +4059,285 @@ export class Qwen2Model extends Qwen2PreTrainedModel { }
|
|
|
3898
4059
|
export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { }
|
|
3899
4060
|
//////////////////////////////////////////////////
|
|
3900
4061
|
|
|
4062
|
+
export class Qwen2VLPreTrainedModel extends PreTrainedModel {
    /**
     * Input names listed for Qwen2-VL models: the four text inputs first,
     * followed by the two vision inputs.
     */
    forward_params = [
        ...['input_ids', 'attention_mask', 'position_ids', 'past_key_values'],
        ...['pixel_values', 'image_grid_thw'],
    ];
}
|
|
4075
|
+
export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {

    /**
     * Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
     *
     * Explanation:
     *     Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
     *
     *     For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
     *     Examples:
     *         input_ids: [T T T T T], here T is for text.
     *         temporal position_ids: [0, 1, 2, 3, 4]
     *         height position_ids: [0, 1, 2, 3, 4]
     *         width position_ids: [0, 1, 2, 3, 4]
     *
     *     For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
     *     and 1D rotary position embedding for text part.
     *     Examples:
     *         Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
     *         input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
     *         vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
     *         vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
     *         vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
     *         text temporal position_ids: [3, 4, 5, 6, 7]
     *         text height position_ids: [3, 4, 5, 6, 7]
     *         text width position_ids: [3, 4, 5, 6, 7]
     *         Here we calculate the text start position_ids as the max vision position_ids plus 1.
     *
     * In every branch the returned delta for a sequence is
     * `max(position_ids) + 1 - sequence_length`.
     *
     * @param {Tensor} input_ids Indices of input sequence tokens in the vocabulary. Tensor of shape `(batch_size, sequence_length)`.
     * @param {Tensor} image_grid_thw (Optional) The temporal, height and width of feature shape of each image in LLM. Tensor of shape `(num_images, 3)`.
     * @param {Tensor} video_grid_thw (Optional) The temporal, height and width of feature shape of each video in LLM. Tensor of shape `(num_videos, 3)`.
     * @param {Tensor} attention_mask (Optional) Mask to avoid performing attention on padding token indices. Tensor of shape `(batch_size, sequence_length)`. Mask values selected in `[0, 1]`:
     * - 1 for tokens that are **not masked**,
     * - 0 for tokens that are **masked**.
     * @returns {[Tensor, Tensor]} [position_ids, mrope_position_deltas] with:
     * - position_ids: Tensor of shape `(3, batch_size, sequence_length)`.
     * - mrope_position_deltas: Tensor of shape `(batch_size, 1)`.
     */
    get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) {
        // @ts-ignore
        const { vision_config, image_token_id, video_token_id, vision_start_token_id } = this.config;
        const spatial_merge_size = vision_config.spatial_merge_size ?? 2;

        if (image_grid_thw || video_grid_thw) {
            const mrope_position_deltas = [];
            const total_input_ids = input_ids.tolist();
            if (!attention_mask) {
                attention_mask = ones_like(input_ids);
            }

            const attention_mask_list = attention_mask.tolist();
            // Every slot starts at 1; only unmasked positions are overwritten further down.
            const position_ids_list = Array.from(
                { length: 3 },
                () => Array.from(
                    { length: input_ids.dims[0] },
                    () => Array.from({ length: input_ids.dims[1] }, () => 1),
                ),
            );

            const image_grid_thw_list = image_grid_thw ? image_grid_thw.tolist() : [];
            const video_grid_thw_list = video_grid_thw ? video_grid_thw.tolist() : [];

            // NOTE(review): the loose `==` comparisons below look deliberate. Values coming out of
            // `tolist()` are presumably BigInt while the config ids are Numbers, so `===` would
            // never match. Confirm against `Tensor.tolist()` before tightening them.
            let image_index = 0;
            let video_index = 0;
            for (let i = 0; i < total_input_ids.length; ++i) {
                // Drop masked (padding) tokens of this sequence.
                const ids = total_input_ids[i].filter((_, j) => attention_mask_list[i][j] == 1);

                const vision_start_indices = ids.reduce((acc, x, idx) => {
                    if (x == vision_start_token_id) acc.push(idx);
                    return acc;
                }, []);

                // The token right after each vision-start marker tells image and video apart.
                const vision_tokens = vision_start_indices.map(x => ids[x + 1]);
                const image_nums = vision_tokens.filter(x => x == image_token_id).length;
                const video_nums = vision_tokens.filter(x => x == video_token_id).length;

                const llm_pos_ids_list = [];
                let st = 0;
                let remain_images = image_nums;
                let remain_videos = video_nums;
                for (let j = 0; j < vision_tokens.length; ++j) {
                    const next_image_token = ids.findIndex((x, k) => k > st && x == image_token_id);
                    const next_video_token = ids.findIndex((x, k) => k > st && x == video_token_id);

                    // `ids.length + 1` acts as an "after the end" sentinel when no such token remains.
                    const ed_image = (remain_images > 0 && next_image_token !== -1)
                        ? next_image_token
                        : ids.length + 1;

                    const ed_video = (remain_videos > 0 && next_video_token !== -1)
                        ? next_video_token
                        : ids.length + 1;

                    let ed;
                    let t, h, w;
                    if (ed_image < ed_video) {
                        ([t, h, w] = image_grid_thw_list[image_index]);
                        ++image_index;
                        --remain_images;
                        ed = ed_image;
                    } else {
                        ([t, h, w] = video_grid_thw_list[video_index]);
                        ++video_index;
                        --remain_videos;
                        ed = ed_video;
                    }

                    const [llm_grid_t, llm_grid_h, llm_grid_w] = [
                        Number(t),
                        Math.floor(Number(h) / spatial_merge_size),
                        Math.floor(Number(w) / spatial_merge_size),
                    ];
                    const text_len = ed - st;
                    const st_idx = llm_pos_ids_list.length > 0
                        ? max(llm_pos_ids_list.at(-1))[0] + 1
                        : 0;

                    // Text before the vision segment: the same 1D positions repeated for t, h and w.
                    llm_pos_ids_list.push(
                        Array.from({ length: 3 * text_len }, (_, k) => st_idx + (k % text_len)),
                    );

                    const offset = text_len + st_idx;
                    const grid_size = llm_grid_t * llm_grid_h * llm_grid_w;
                    const t_index = Array.from({ length: grid_size }, (_, k) => offset + Math.floor(k / (llm_grid_h * llm_grid_w)));
                    const h_index = Array.from({ length: grid_size }, (_, k) => offset + Math.floor(k / llm_grid_w) % llm_grid_h);
                    const w_index = Array.from({ length: grid_size }, (_, k) => offset + k % llm_grid_w);

                    llm_pos_ids_list.push([t_index, h_index, w_index].flat());

                    st = ed + grid_size;
                }

                // Trailing text after the last vision segment.
                if (st < ids.length) {
                    const st_idx = llm_pos_ids_list.length > 0
                        ? max(llm_pos_ids_list.at(-1))[0] + 1
                        : 0;
                    const text_len = ids.length - st;

                    llm_pos_ids_list.push(
                        Array.from({ length: 3 * text_len }, (_, k) => (st_idx + (k % text_len))),
                    );
                }

                // NOTE: Each item in llm_pos_ids_list is an array of shape (3, text_len),
                // meaning to perform concatenation along dim=1, we can do the following:
                const num_items = llm_pos_ids_list.reduce((acc, x) => acc + x.length, 0);
                const llm_positions = new Array(num_items);
                let index = 0;
                for (let x = 0; x < 3; ++x) {
                    for (let y = 0; y < llm_pos_ids_list.length; ++y) {
                        const val = llm_pos_ids_list[y];
                        const segment_len = val.length / 3;
                        for (let z = x * segment_len; z < (x + 1) * segment_len; ++z) {
                            llm_positions[index++] = val[z];
                        }
                    }
                }

                // Scatter the computed positions back onto the unmasked slots of this sequence.
                let count = 0;
                const attn_mask = attention_mask_list[i];
                for (let y = 0; y < attn_mask.length; ++y) {
                    if (attn_mask[y] == 1) {
                        for (let x = 0; x < 3; ++x) {
                            position_ids_list[x][i][y] = llm_positions[x * num_items / 3 + count];
                        }
                        ++count;
                    }
                }

                const max_llm_positions = max(llm_positions)[0];
                mrope_position_deltas.push(max_llm_positions + 1 - total_input_ids[i].length);
            }

            return [
                new Tensor('int64', position_ids_list.flat(Infinity), [3, input_ids.dims[0], input_ids.dims[1]]),
                new Tensor('int64', mrope_position_deltas, [mrope_position_deltas.length, 1]),
            ];
        }

        // Text-only
        if (attention_mask) {
            const { data, dims } = cumsum_masked_fill(attention_mask);

            // The same (batch, seq) positions are repeated for each of the 3 rope axes.
            const position_ids = BigInt64Array.from(
                { length: 3 * data.length },
                (_, i) => data[i % data.length],
            );
            // delta = max + 1 - seq_len, the same formula the vision branch uses.
            // `Number(...)` keeps the arithmetic valid whether `max` yields a BigInt or a Number.
            const text_position_deltas = Array.from(
                { length: dims[0] },
                (_, i) => Number(max(data.subarray(dims[1] * i, dims[1] * (i + 1)))[0]) + 1 - dims[1],
            );

            return [
                new Tensor('int64', position_ids, [3, ...dims]),
                new Tensor('int64', text_position_deltas, [text_position_deltas.length, 1]),
            ];
        }

        const [batch_size, seq_length] = input_ids.dims;
        // Flattened (3, batch, seq) layout: the position inside a sequence is `i % seq_length`.
        const position_ids = BigInt64Array.from(
            { length: 3 * batch_size * seq_length },
            (_, i) => BigInt(i % seq_length),
        );

        return [
            new Tensor('int64', position_ids, [3, ...input_ids.dims]),
            zeros([batch_size, 1]),
        ];
    }

    /**
     * Runs the `vision_encoder` session and returns its `image_features` output.
     * `image_grid_thw` is passed to the session under the input name `grid_thw`.
     */
    async encode_image({ pixel_values, image_grid_thw }) {
        const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values, grid_thw: image_grid_thw })).image_features;
        return features;
    }

    /**
     * Writes the image features into `inputs_embeds` at every position whose token id
     * equals `config.image_token_id`. `inputs_embeds` is modified in place.
     *
     * @throws {Error} If the number of image tokens differs from `image_features.dims[0]`.
     * @returns The same `inputs_embeds` and `attention_mask` that were passed in.
     */
    _merge_input_ids_with_image_features({
        inputs_embeds,
        image_features,
        input_ids,
        attention_mask,
    }) {
        // @ts-ignore
        const { image_token_id } = this.config;
        // For each sequence, the indices of its image placeholder tokens.
        const image_tokens = input_ids.tolist().map(ids =>
            ids.reduce((acc, x, idx) => {
                if (x == image_token_id) acc.push(idx);
                return acc;
            }, [])
        );
        const n_image_tokens = image_tokens.reduce((acc, x) => acc + x.length, 0);
        const n_image_features = image_features.dims[0];
        if (n_image_tokens !== n_image_features) {
            throw new Error(`Image features and image tokens do not match: tokens: ${n_image_tokens}, features ${n_image_features}`);
        }

        // Equivalent to performing a masked_scatter
        let img = 0;
        for (let i = 0; i < image_tokens.length; ++i) {
            const tokens = image_tokens[i];
            const embeds = inputs_embeds[i];
            for (let j = 0; j < tokens.length; ++j) {
                embeds[tokens[j]].data.set(image_features[img++].data);
            }
        }
        return { inputs_embeds, attention_mask };
    }

    /**
     * Fills in `position_ids` when an attention mask is present but positions are not.
     * Without cached keys/values, positions and rope deltas come from `get_rope_index`.
     * With a cache, `pixel_values` is cleared and the new position is the cached
     * length plus the stored rope delta, stacked once per rope axis.
     */
    prepare_inputs_for_generation(input_ids, model_inputs, generation_config) {
        // Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        if (model_inputs.attention_mask && !model_inputs.position_ids) {
            // Calculate position_ids and rope_deltas
            if (!model_inputs.past_key_values) {
                ([model_inputs.position_ids, model_inputs.rope_deltas] = this.get_rope_index(
                    model_inputs.input_ids,
                    model_inputs.image_grid_thw,
                    model_inputs.video_grid_thw,
                    model_inputs.attention_mask,
                ));

            } else {
                model_inputs.pixel_values = null;
                // model_inputs.pixel_values_videos = null;

                const delta = BigInt(Object.values(model_inputs.past_key_values)[0].dims.at(-2));
                const rope_deltas_list = model_inputs.rope_deltas.map(x => delta + x);
                model_inputs.position_ids = stack([rope_deltas_list, rope_deltas_list, rope_deltas_list], 0);
            }
        }

        return model_inputs;
    }
}
|
|
4340
|
+
|
|
3901
4341
|
|
|
3902
4342
|
//////////////////////////////////////////////////
|
|
3903
4343
|
// Phi models
|
|
@@ -3985,6 +4425,17 @@ export class ViTForImageClassification extends ViTPreTrainedModel {
|
|
|
3985
4425
|
}
|
|
3986
4426
|
//////////////////////////////////////////////////
|
|
3987
4427
|
|
|
4428
|
+
|
|
4429
|
+
//////////////////////////////////////////////////
|
|
4430
|
+
/**
 * Common base class for the VitPose model classes.
 */
export class VitPosePreTrainedModel extends PreTrainedModel {}

/**
 * VitPose with a pose estimation head on top.
 */
export class VitPoseForPoseEstimation extends VitPosePreTrainedModel {}
|
|
4436
|
+
//////////////////////////////////////////////////
|
|
4437
|
+
|
|
4438
|
+
|
|
3988
4439
|
//////////////////////////////////////////////////
|
|
3989
4440
|
export class PvtPreTrainedModel extends PreTrainedModel { }
|
|
3990
4441
|
export class PvtModel extends PvtPreTrainedModel { }
|
|
@@ -5583,8 +6034,7 @@ export class ClapModel extends ClapPreTrainedModel { }
|
|
|
5583
6034
|
* ```
|
|
5584
6035
|
*/
|
|
5585
6036
|
export class ClapTextModelWithProjection extends ClapPreTrainedModel {
|
|
5586
|
-
|
|
5587
|
-
/** @type {PreTrainedModel.from_pretrained} */
|
|
6037
|
+
/** @type {typeof PreTrainedModel.from_pretrained} */
|
|
5588
6038
|
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
5589
6039
|
// Update default model file name if not provided
|
|
5590
6040
|
options.model_file_name ??= 'text_model';
|
|
@@ -5619,7 +6069,7 @@ export class ClapTextModelWithProjection extends ClapPreTrainedModel {
|
|
|
5619
6069
|
* ```
|
|
5620
6070
|
*/
|
|
5621
6071
|
export class ClapAudioModelWithProjection extends ClapPreTrainedModel {
|
|
5622
|
-
/** @type {PreTrainedModel.from_pretrained} */
|
|
6072
|
+
/** @type {typeof PreTrainedModel.from_pretrained} */
|
|
5623
6073
|
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
|
|
5624
6074
|
// Update default model file name if not provided
|
|
5625
6075
|
options.model_file_name ??= 'audio_model';
|
|
@@ -5970,6 +6420,170 @@ export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel
|
|
|
5970
6420
|
|
|
5971
6421
|
//////////////////////////////////////////////////
|
|
5972
6422
|
|
|
6423
|
+
export class MultiModalityPreTrainedModel extends PreTrainedModel { }

/**
 * Causal LM whose final head is chosen by an internal generation mode:
 * `lm_head` in `'text'` mode and `gen_head` otherwise.
 */
export class MultiModalityCausalLM extends MultiModalityPreTrainedModel {
    forward_params = [
        // prepare_inputs_embeds
        'input_ids',
        'pixel_values',
        'images_seq_mask',
        'images_emb_mask',

        // language_model
        'attention_mask',
        'position_ids',
        'past_key_values',
    ];

    constructor(...args) {
        super(...args);

        // State-based approach to switch out which heads to use during generation
        this._generation_mode = 'text';
    }

    /**
     * Runs the three stages of a forward pass: input embedding, decoder, output head.
     *
     * @param {Object} model_inputs The model inputs.
     * @returns {Promise<Object>} The outputs of all three stages merged into one object;
     * on a key clash the later stage wins.
     * @throws {Error} If the session for the required head (`lm_head` or `gen_head`) is missing.
     */
    async forward(model_inputs) {
        const mode = this._generation_mode ?? 'text';

        // TODO support re-using PKVs for input_ids.dims[1] !== 1

        let output_1;
        if (mode === 'text' || !model_inputs.past_key_values) {
            // Text mode, or the first step of any mode: embed the full multimodal prompt.
            const session = this.sessions['prepare_inputs_embeds'];
            const prep_inputs = pick(model_inputs, session.inputNames);
            output_1 = await sessionRun(session, prep_inputs);
        } else {
            // Later steps of image mode: `input_ids` are fed to the session as `image_ids`.
            const session = this.sessions['gen_img_embeds'];
            const prep_inputs = pick({
                image_ids: model_inputs.input_ids,
            }, session.inputNames);
            output_1 = await sessionRun(session, prep_inputs);
        }

        const input_2 = { ...model_inputs, ...output_1 };
        const output_2 = await decoderForward(this, input_2);

        // Keep the session name around so the error below can report which head is missing.
        const head_name = mode === 'text'
            ? 'lm_head'
            : 'gen_head';
        const head = this.sessions[head_name];
        if (!head) {
            throw new Error(`Unable to find "${head_name}" generation head`);
        }

        const output_3 = await sessionRun(head, pick(output_2, head.inputNames));

        return {
            ...output_1,
            ...output_2,
            ...output_3,
        };
    }

    /**
     * Generates in text mode.
     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
     */
    async generate(options) {
        this._generation_mode = 'text';
        return super.generate(options);
    }

    /**
     * Generates in image mode and decodes the newly generated tokens into images.
     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
     * @returns A list of images built with `RawImage.fromTensor`, one per decoded tensor.
     */
    async generate_images(options) {
        this._generation_mode = 'image';

        const start_num_tokens = (options.inputs ?? options[this.main_input_name]).dims[1];
        const all_tokens = await super.generate(options);

        // Keep only the tokens produced after the prompt.
        const generated_tokens = (/** @type {Tensor} */(all_tokens)).slice(null, [start_num_tokens, null]);

        const image_decode = this.sessions['image_decode'];
        const { decoded_image } = await sessionRun(image_decode, {
            generated_tokens,
        });

        // Equivalent to `np.clip((dec + 1) / 2 * 255, 0, 255)`
        const clamped = decoded_image
            .add_(1)
            .mul_(255 / 2)
            .clamp_(0, 255)
            .to('uint8');

        // Return as a list of images
        const images = [];
        for (const tensor of clamped) {
            const img = RawImage.fromTensor(tensor);
            images.push(img);
        }
        return images;
    }
}
|
|
6527
|
+
|
|
6528
|
+
/**
 * Output container holding the three logits of an MGP-STR model.
 */
export class MgpstrModelOutput extends ModelOutput {
    /**
     * @param {Object} output
     * @param {*} output.char_logits Stored as `char_logits`.
     * @param {*} output.bpe_logits Stored as `bpe_logits`.
     * @param {*} output.wp_logits Stored as `wp_logits`.
     */
    constructor({ char_logits, bpe_logits, wp_logits }) {
        super();
        Object.assign(this, { char_logits, bpe_logits, wp_logits });
    }

    /** The three logits in the order `[char, bpe, wp]`. */
    get logits() {
        const { char_logits, bpe_logits, wp_logits } = this;
        return [char_logits, bpe_logits, wp_logits];
    }
}
|
|
6540
|
+
|
|
6541
|
+
export class MgpstrPreTrainedModel extends PreTrainedModel {}

/**
 * MGP-STR Model transformer with three classification heads on top
 * (three A^3 modules and three linear layer on top of the transformer encoder output) for scene text recognition (STR).
 */
export class MgpstrForSceneTextRecognition extends MgpstrPreTrainedModel {
    /**
     * Calls the parent `_call` and wraps its result in a `MgpstrModelOutput`.
     * @param {any} model_inputs
     */
    async _call(model_inputs) {
        const raw_outputs = await super._call(model_inputs);
        return new MgpstrModelOutput(raw_outputs);
    }
}
|
|
6555
|
+
|
|
6556
|
+
//////////////////////////////////////////////////
|
|
6557
|
+
// PatchTST Transformer models
|
|
6558
|
+
/**
 * Common base class for the PatchTST model classes.
 */
export class PatchTSTPreTrainedModel extends PreTrainedModel {}

/**
 * Bare PatchTST model: outputs raw hidden-states, with no task-specific head.
 */
export class PatchTSTModel extends PatchTSTPreTrainedModel {}

/**
 * PatchTST model for prediction.
 */
export class PatchTSTForPrediction extends PatchTSTPreTrainedModel {}
|
|
6569
|
+
//////////////////////////////////////////////////
|
|
6570
|
+
|
|
6571
|
+
//////////////////////////////////////////////////
|
|
6572
|
+
// PatchTSMixer Transformer models
|
|
6573
|
+
/**
 * Common base class for the PatchTSMixer model classes.
 */
export class PatchTSMixerPreTrainedModel extends PreTrainedModel {}

/**
 * Bare PatchTSMixer model: outputs raw hidden-states, with no task-specific head.
 */
export class PatchTSMixerModel extends PatchTSMixerPreTrainedModel {}

/**
 * PatchTSMixer model for prediction.
 */
export class PatchTSMixerForPrediction extends PatchTSMixerPreTrainedModel {}
|
|
6584
|
+
//////////////////////////////////////////////////
|
|
6585
|
+
|
|
6586
|
+
|
|
5973
6587
|
//////////////////////////////////////////////////
|
|
5974
6588
|
// AutoModels, used to simplify construction of PreTrainedModels
|
|
5975
6589
|
// (uses config to instantiate correct class)
|
|
@@ -6064,6 +6678,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
|
|
|
6064
6678
|
['clipseg', ['CLIPSegModel', CLIPSegModel]],
|
|
6065
6679
|
['chinese_clip', ['ChineseCLIPModel', ChineseCLIPModel]],
|
|
6066
6680
|
['siglip', ['SiglipModel', SiglipModel]],
|
|
6681
|
+
['jina_clip', ['JinaCLIPModel', JinaCLIPModel]],
|
|
6067
6682
|
['mobilebert', ['MobileBertModel', MobileBertModel]],
|
|
6068
6683
|
['squeezebert', ['SqueezeBertModel', SqueezeBertModel]],
|
|
6069
6684
|
['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]],
|
|
@@ -6108,6 +6723,8 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
|
|
|
6108
6723
|
['efficientnet', ['EfficientNetModel', EfficientNetModel]],
|
|
6109
6724
|
|
|
6110
6725
|
['decision_transformer', ['DecisionTransformerModel', DecisionTransformerModel]],
|
|
6726
|
+
['patchtst', ['PatchTSTForPrediction', PatchTSTModel]],
|
|
6727
|
+
['patchtsmixer', ['PatchTSMixerForPrediction', PatchTSMixerModel]],
|
|
6111
6728
|
|
|
6112
6729
|
['mobilenet_v1', ['MobileNetV1Model', MobileNetV1Model]],
|
|
6113
6730
|
['mobilenet_v2', ['MobileNetV2Model', MobileNetV2Model]],
|
|
@@ -6115,6 +6732,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
|
|
|
6115
6732
|
['mobilenet_v4', ['MobileNetV4Model', MobileNetV4Model]],
|
|
6116
6733
|
|
|
6117
6734
|
['maskformer', ['MaskFormerModel', MaskFormerModel]],
|
|
6735
|
+
['mgp-str', ['MgpstrForSceneTextRecognition', MgpstrForSceneTextRecognition]],
|
|
6118
6736
|
]);
|
|
6119
6737
|
|
|
6120
6738
|
const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
|
|
@@ -6252,6 +6870,11 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
|
|
|
6252
6870
|
['stablelm', ['StableLmForCausalLM', StableLmForCausalLM]],
|
|
6253
6871
|
]);
|
|
6254
6872
|
|
|
6873
|
+
// Maps a model type to the `[class name, class]` pair of its multi-modality model.
const MODEL_FOR_MULTIMODALITY_MAPPING_NAMES = new Map([
    [
        'multi_modality',
        ['MultiModalityCausalLM', MultiModalityCausalLM],
    ],
]);
|
|
6876
|
+
|
|
6877
|
+
|
|
6255
6878
|
const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
|
|
6256
6879
|
['bert', ['BertForMaskedLM', BertForMaskedLM]],
|
|
6257
6880
|
['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]],
|
|
@@ -6295,8 +6918,10 @@ const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([
|
|
|
6295
6918
|
|
|
6296
6919
|
const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
|
|
6297
6920
|
['llava', ['LlavaForConditionalGeneration', LlavaForConditionalGeneration]],
|
|
6921
|
+
['llava_onevision', ['LlavaOnevisionForConditionalGeneration', LlavaOnevisionForConditionalGeneration]],
|
|
6298
6922
|
['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]],
|
|
6299
6923
|
['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]],
|
|
6924
|
+
['qwen2-vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]],
|
|
6300
6925
|
]);
|
|
6301
6926
|
|
|
6302
6927
|
const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
|
|
@@ -6392,6 +7017,11 @@ const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
|
|
|
6392
7017
|
['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]],
|
|
6393
7018
|
]);
|
|
6394
7019
|
|
|
7020
|
+
// Maps a model type to the `[class name, class]` pair of its time-series prediction model.
const MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = new Map([
    [
        'patchtst',
        ['PatchTSTForPrediction', PatchTSTForPrediction],
    ],
    [
        'patchtsmixer',
        ['PatchTSMixerForPrediction', PatchTSMixerForPrediction],
    ],
]);
|
|
7024
|
+
|
|
6395
7025
|
const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([
|
|
6396
7026
|
['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]],
|
|
6397
7027
|
])
|
|
@@ -6408,11 +7038,16 @@ const MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES = new Map([
|
|
|
6408
7038
|
['sapiens', ['SapiensForNormalEstimation', SapiensForNormalEstimation]],
|
|
6409
7039
|
])
|
|
6410
7040
|
|
|
7041
|
+
// Maps a model type to the `[class name, class]` pair of its pose estimation model.
const MODEL_FOR_POSE_ESTIMATION_MAPPING_NAMES = new Map([
    [
        'vitpose',
        ['VitPoseForPoseEstimation', VitPoseForPoseEstimation],
    ],
]);
|
|
7044
|
+
|
|
6411
7045
|
// NOTE: This is custom to Transformers.js, and is necessary because certain models
|
|
6412
7046
|
// (e.g., CLIP) are split into vision and text components
|
|
6413
7047
|
const MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = new Map([
|
|
6414
7048
|
['clip', ['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection]],
|
|
6415
7049
|
['siglip', ['SiglipVisionModel', SiglipVisionModel]],
|
|
7050
|
+
['jina_clip', ['JinaCLIPVisionModel', JinaCLIPVisionModel]],
|
|
6416
7051
|
])
|
|
6417
7052
|
|
|
6418
7053
|
const MODEL_CLASS_TYPE_MAPPING = [
|
|
@@ -6424,6 +7059,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
|
|
|
6424
7059
|
[MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
|
|
6425
7060
|
[MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
|
|
6426
7061
|
[MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
|
|
7062
|
+
[MODEL_FOR_MULTIMODALITY_MAPPING_NAMES, MODEL_TYPES.MultiModality],
|
|
6427
7063
|
[MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6428
7064
|
[MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6429
7065
|
[MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
|
|
@@ -6433,9 +7069,11 @@ const MODEL_CLASS_TYPE_MAPPING = [
|
|
|
6433
7069
|
[MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6434
7070
|
[MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6435
7071
|
[MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
7072
|
+
[MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6436
7073
|
[MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6437
7074
|
[MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6438
7075
|
[MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
7076
|
+
[MODEL_FOR_POSE_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6439
7077
|
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6440
7078
|
[MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
6441
7079
|
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
|
|
@@ -6466,6 +7104,7 @@ const CUSTOM_MAPPING = [
|
|
|
6466
7104
|
|
|
6467
7105
|
['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly],
|
|
6468
7106
|
['SiglipTextModel', SiglipTextModel, MODEL_TYPES.EncoderOnly],
|
|
7107
|
+
['JinaCLIPTextModel', JinaCLIPTextModel, MODEL_TYPES.EncoderOnly],
|
|
6469
7108
|
['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly],
|
|
6470
7109
|
['ClapAudioModelWithProjection', ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly],
|
|
6471
7110
|
]
|
|
@@ -6707,6 +7346,10 @@ export class AutoModelForNormalEstimation extends PretrainedMixin {
|
|
|
6707
7346
|
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES];
|
|
6708
7347
|
}
|
|
6709
7348
|
|
|
7349
|
+
/**
 * Auto class whose lookup table is the pose estimation mapping.
 */
export class AutoModelForPoseEstimation extends PretrainedMixin {
    static MODEL_CLASS_MAPPINGS = [
        MODEL_FOR_POSE_ESTIMATION_MAPPING_NAMES,
    ];
}
|
|
7352
|
+
|
|
6710
7353
|
export class AutoModelForImageFeatureExtraction extends PretrainedMixin {
|
|
6711
7354
|
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES];
|
|
6712
7355
|
}
|