@huggingface/transformers 3.0.1 → 3.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (239) hide show
  1. package/README.md +14 -4
  2. package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
  3. package/dist/transformers.cjs +16607 -13472
  4. package/dist/transformers.cjs.map +1 -1
  5. package/dist/transformers.js +16601 -13451
  6. package/dist/transformers.js.map +1 -1
  7. package/dist/transformers.min.cjs +238 -52
  8. package/dist/transformers.min.cjs.map +1 -1
  9. package/dist/transformers.min.js +229 -43
  10. package/dist/transformers.min.js.map +1 -1
  11. package/dist/transformers.min.mjs +240 -54
  12. package/dist/transformers.min.mjs.map +1 -1
  13. package/dist/transformers.mjs +16017 -12878
  14. package/dist/transformers.mjs.map +1 -1
  15. package/package.json +7 -7
  16. package/src/base/feature_extraction_utils.js +54 -0
  17. package/src/base/image_processors_utils.js +1089 -0
  18. package/src/base/processing_utils.js +145 -0
  19. package/src/configs.js +15 -3
  20. package/src/env.js +15 -4
  21. package/src/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.js +90 -0
  22. package/src/models/auto/feature_extraction_auto.js +41 -0
  23. package/src/models/auto/image_processing_auto.js +29 -0
  24. package/src/models/auto/processing_auto.js +100 -0
  25. package/src/models/beit/image_processing_beit.js +5 -0
  26. package/src/models/bit/image_processing_bit.js +5 -0
  27. package/src/models/chinese_clip/image_processing_chinese_clip.js +5 -0
  28. package/src/models/clap/feature_extraction_clap.js +159 -0
  29. package/src/models/clip/image_processing_clip.js +6 -0
  30. package/src/models/convnext/image_processing_convnext.js +45 -0
  31. package/src/models/deit/image_processing_deit.js +6 -0
  32. package/src/models/detr/image_processing_detr.js +52 -0
  33. package/src/models/donut/image_processing_donut.js +31 -0
  34. package/src/models/dpt/image_processing_dpt.js +6 -0
  35. package/src/models/efficientnet/image_processing_efficientnet.js +13 -0
  36. package/src/models/feature_extractors.js +12 -0
  37. package/src/models/florence2/processing_florence2.js +128 -0
  38. package/src/models/glpn/image_processing_glpn.js +5 -0
  39. package/src/models/image_processors.js +36 -0
  40. package/src/models/janus/image_processing_janus.js +26 -0
  41. package/src/models/janus/processing_janus.js +123 -0
  42. package/src/models/jina_clip/image_processing_jina_clip.js +26 -0
  43. package/src/models/jina_clip/processing_jina_clip.js +24 -0
  44. package/src/models/llava_onevision/image_processing_llava_onevision.js +5 -0
  45. package/src/models/mask2former/image_processing_mask2former.js +5 -0
  46. package/src/models/maskformer/image_processing_maskformer.js +18 -0
  47. package/src/models/mgp_str/processing_mgp_str.js +170 -0
  48. package/src/models/mobilenet_v1/image_processing_mobilenet_v1.js +7 -0
  49. package/src/models/mobilenet_v2/image_processing_mobilenet_v2.js +7 -0
  50. package/src/models/mobilenet_v3/image_processing_mobilenet_v3.js +7 -0
  51. package/src/models/mobilenet_v4/image_processing_mobilenet_v4.js +7 -0
  52. package/src/models/mobilevit/image_processing_mobilevit.js +6 -0
  53. package/src/models/nougat/image_processing_nougat.js +5 -0
  54. package/src/models/owlv2/image_processing_owlv2.js +5 -0
  55. package/src/models/owlvit/image_processing_owlvit.js +12 -0
  56. package/src/models/owlvit/processing_owlvit.js +7 -0
  57. package/src/models/processors.js +11 -0
  58. package/src/models/pvt/image_processing_pvt.js +5 -0
  59. package/src/models/pyannote/feature_extraction_pyannote.js +28 -0
  60. package/src/models/pyannote/processing_pyannote.js +71 -0
  61. package/src/models/qwen2_vl/image_processing_qwen2_vl.js +52 -0
  62. package/src/models/qwen2_vl/processing_qwen2_vl.js +52 -0
  63. package/src/models/rt_detr/image_processing_rt_detr.js +12 -0
  64. package/src/models/sam/image_processing_sam.js +242 -0
  65. package/src/models/sam/processing_sam.js +20 -0
  66. package/src/models/sapiens/image_processing_sapiens.js +13 -0
  67. package/src/models/seamless_m4t/feature_extraction_seamless_m4t.js +180 -0
  68. package/src/models/segformer/image_processing_segformer.js +13 -0
  69. package/src/models/siglip/image_processing_siglip.js +5 -0
  70. package/src/models/speecht5/feature_extraction_speecht5.js +4 -0
  71. package/src/models/speecht5/processing_speecht5.js +17 -0
  72. package/src/models/swin2sr/image_processing_swin2sr.js +24 -0
  73. package/src/models/vit/image_processing_vit.js +7 -0
  74. package/src/models/vitmatte/image_processing_vitmatte.js +50 -0
  75. package/src/models/vitpose/image_processing_vitpose.js +89 -0
  76. package/src/models/wav2vec2/feature_extraction_wav2vec2.js +44 -0
  77. package/src/models/wav2vec2/processing_wav2vec2.js +15 -0
  78. package/src/models/wespeaker/feature_extraction_wespeaker.js +100 -0
  79. package/src/models/whisper/feature_extraction_whisper.js +84 -0
  80. package/src/models/whisper/processing_whisper.js +21 -0
  81. package/src/models/yolos/image_processing_yolos.js +12 -0
  82. package/src/models.js +695 -32
  83. package/src/pipelines.js +8 -8
  84. package/src/tokenizers.js +5 -0
  85. package/src/transformers.js +15 -2
  86. package/src/utils/constants.js +8 -1
  87. package/src/utils/core.js +37 -9
  88. package/src/utils/hub.js +2 -1
  89. package/src/utils/image.js +68 -17
  90. package/src/utils/tensor.js +33 -1
  91. package/types/base/feature_extraction_utils.d.ts +41 -0
  92. package/types/base/feature_extraction_utils.d.ts.map +1 -0
  93. package/types/base/image_processors_utils.d.ts +323 -0
  94. package/types/base/image_processors_utils.d.ts.map +1 -0
  95. package/types/base/processing_utils.d.ts +80 -0
  96. package/types/base/processing_utils.d.ts.map +1 -0
  97. package/types/configs.d.ts +4 -1
  98. package/types/configs.d.ts.map +1 -1
  99. package/types/env.d.ts.map +1 -1
  100. package/types/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.d.ts +25 -0
  101. package/types/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.d.ts.map +1 -0
  102. package/types/models/auto/feature_extraction_auto.d.ts +5 -0
  103. package/types/models/auto/feature_extraction_auto.d.ts.map +1 -0
  104. package/types/models/auto/image_processing_auto.d.ts +5 -0
  105. package/types/models/auto/image_processing_auto.d.ts.map +1 -0
  106. package/types/models/auto/processing_auto.d.ts +35 -0
  107. package/types/models/auto/processing_auto.d.ts.map +1 -0
  108. package/types/models/beit/image_processing_beit.d.ts +4 -0
  109. package/types/models/beit/image_processing_beit.d.ts.map +1 -0
  110. package/types/models/bit/image_processing_bit.d.ts +4 -0
  111. package/types/models/bit/image_processing_bit.d.ts.map +1 -0
  112. package/types/models/chinese_clip/image_processing_chinese_clip.d.ts +4 -0
  113. package/types/models/chinese_clip/image_processing_chinese_clip.d.ts.map +1 -0
  114. package/types/models/clap/feature_extraction_clap.d.ts +57 -0
  115. package/types/models/clap/feature_extraction_clap.d.ts.map +1 -0
  116. package/types/models/clip/image_processing_clip.d.ts +6 -0
  117. package/types/models/clip/image_processing_clip.d.ts.map +1 -0
  118. package/types/models/convnext/image_processing_convnext.d.ts +12 -0
  119. package/types/models/convnext/image_processing_convnext.d.ts.map +1 -0
  120. package/types/models/deit/image_processing_deit.d.ts +6 -0
  121. package/types/models/deit/image_processing_deit.d.ts.map +1 -0
  122. package/types/models/detr/image_processing_detr.d.ts +42 -0
  123. package/types/models/detr/image_processing_detr.d.ts.map +1 -0
  124. package/types/models/donut/image_processing_donut.d.ts +7 -0
  125. package/types/models/donut/image_processing_donut.d.ts.map +1 -0
  126. package/types/models/dpt/image_processing_dpt.d.ts +6 -0
  127. package/types/models/dpt/image_processing_dpt.d.ts.map +1 -0
  128. package/types/models/efficientnet/image_processing_efficientnet.d.ts +6 -0
  129. package/types/models/efficientnet/image_processing_efficientnet.d.ts.map +1 -0
  130. package/types/models/feature_extractors.d.ts +10 -0
  131. package/types/models/feature_extractors.d.ts.map +1 -0
  132. package/types/models/florence2/processing_florence2.d.ts +39 -0
  133. package/types/models/florence2/processing_florence2.d.ts.map +1 -0
  134. package/types/models/glpn/image_processing_glpn.d.ts +4 -0
  135. package/types/models/glpn/image_processing_glpn.d.ts.map +1 -0
  136. package/types/models/image_processors.d.ts +36 -0
  137. package/types/models/image_processors.d.ts.map +1 -0
  138. package/types/models/janus/image_processing_janus.d.ts +7 -0
  139. package/types/models/janus/image_processing_janus.d.ts.map +1 -0
  140. package/types/models/janus/processing_janus.d.ts +77 -0
  141. package/types/models/janus/processing_janus.d.ts.map +1 -0
  142. package/types/models/jina_clip/image_processing_jina_clip.d.ts +5 -0
  143. package/types/models/jina_clip/image_processing_jina_clip.d.ts.map +1 -0
  144. package/types/models/jina_clip/processing_jina_clip.d.ts +9 -0
  145. package/types/models/jina_clip/processing_jina_clip.d.ts.map +1 -0
  146. package/types/models/llava_onevision/image_processing_llava_onevision.d.ts +4 -0
  147. package/types/models/llava_onevision/image_processing_llava_onevision.d.ts.map +1 -0
  148. package/types/models/mask2former/image_processing_mask2former.d.ts +4 -0
  149. package/types/models/mask2former/image_processing_mask2former.d.ts.map +1 -0
  150. package/types/models/maskformer/image_processing_maskformer.d.ts +22 -0
  151. package/types/models/maskformer/image_processing_maskformer.d.ts.map +1 -0
  152. package/types/models/mgp_str/processing_mgp_str.d.ts +64 -0
  153. package/types/models/mgp_str/processing_mgp_str.d.ts.map +1 -0
  154. package/types/models/mobilenet_v1/image_processing_mobilenet_v1.d.ts +6 -0
  155. package/types/models/mobilenet_v1/image_processing_mobilenet_v1.d.ts.map +1 -0
  156. package/types/models/mobilenet_v2/image_processing_mobilenet_v2.d.ts +6 -0
  157. package/types/models/mobilenet_v2/image_processing_mobilenet_v2.d.ts.map +1 -0
  158. package/types/models/mobilenet_v3/image_processing_mobilenet_v3.d.ts +6 -0
  159. package/types/models/mobilenet_v3/image_processing_mobilenet_v3.d.ts.map +1 -0
  160. package/types/models/mobilenet_v4/image_processing_mobilenet_v4.d.ts +6 -0
  161. package/types/models/mobilenet_v4/image_processing_mobilenet_v4.d.ts.map +1 -0
  162. package/types/models/mobilevit/image_processing_mobilevit.d.ts +6 -0
  163. package/types/models/mobilevit/image_processing_mobilevit.d.ts.map +1 -0
  164. package/types/models/nougat/image_processing_nougat.d.ts +4 -0
  165. package/types/models/nougat/image_processing_nougat.d.ts.map +1 -0
  166. package/types/models/owlv2/image_processing_owlv2.d.ts +4 -0
  167. package/types/models/owlv2/image_processing_owlv2.d.ts.map +1 -0
  168. package/types/models/owlvit/image_processing_owlvit.d.ts +10 -0
  169. package/types/models/owlvit/image_processing_owlvit.d.ts.map +1 -0
  170. package/types/models/owlvit/processing_owlvit.d.ts +8 -0
  171. package/types/models/owlvit/processing_owlvit.d.ts.map +1 -0
  172. package/types/models/processors.d.ts +12 -0
  173. package/types/models/processors.d.ts.map +1 -0
  174. package/types/models/pvt/image_processing_pvt.d.ts +4 -0
  175. package/types/models/pvt/image_processing_pvt.d.ts.map +1 -0
  176. package/types/models/pyannote/feature_extraction_pyannote.d.ts +13 -0
  177. package/types/models/pyannote/feature_extraction_pyannote.d.ts.map +1 -0
  178. package/types/models/pyannote/processing_pyannote.d.ts +30 -0
  179. package/types/models/pyannote/processing_pyannote.d.ts.map +1 -0
  180. package/types/models/qwen2_vl/image_processing_qwen2_vl.d.ts +11 -0
  181. package/types/models/qwen2_vl/image_processing_qwen2_vl.d.ts.map +1 -0
  182. package/types/models/qwen2_vl/processing_qwen2_vl.d.ts +17 -0
  183. package/types/models/qwen2_vl/processing_qwen2_vl.d.ts.map +1 -0
  184. package/types/models/rt_detr/image_processing_rt_detr.d.ts +8 -0
  185. package/types/models/rt_detr/image_processing_rt_detr.d.ts.map +1 -0
  186. package/types/models/sam/image_processing_sam.d.ts +103 -0
  187. package/types/models/sam/image_processing_sam.d.ts.map +1 -0
  188. package/types/models/sam/processing_sam.d.ts +9 -0
  189. package/types/models/sam/processing_sam.d.ts.map +1 -0
  190. package/types/models/seamless_m4t/feature_extraction_seamless_m4t.d.ts +34 -0
  191. package/types/models/seamless_m4t/feature_extraction_seamless_m4t.d.ts.map +1 -0
  192. package/types/models/segformer/image_processing_segformer.d.ts +10 -0
  193. package/types/models/segformer/image_processing_segformer.d.ts.map +1 -0
  194. package/types/models/siglip/image_processing_siglip.d.ts +4 -0
  195. package/types/models/siglip/image_processing_siglip.d.ts.map +1 -0
  196. package/types/models/speecht5/feature_extraction_speecht5.d.ts +4 -0
  197. package/types/models/speecht5/feature_extraction_speecht5.d.ts.map +1 -0
  198. package/types/models/speecht5/processing_speecht5.d.ts +14 -0
  199. package/types/models/speecht5/processing_speecht5.d.ts.map +1 -0
  200. package/types/models/swin2sr/image_processing_swin2sr.d.ts +5 -0
  201. package/types/models/swin2sr/image_processing_swin2sr.d.ts.map +1 -0
  202. package/types/models/vit/image_processing_vit.d.ts +6 -0
  203. package/types/models/vit/image_processing_vit.d.ts.map +1 -0
  204. package/types/models/vitmatte/image_processing_vitmatte.d.ts +12 -0
  205. package/types/models/vitmatte/image_processing_vitmatte.d.ts.map +1 -0
  206. package/types/models/vitpose/image_processing_vitpose.d.ts +26 -0
  207. package/types/models/vitpose/image_processing_vitpose.d.ts.map +1 -0
  208. package/types/models/wav2vec2/feature_extraction_wav2vec2.d.ts +19 -0
  209. package/types/models/wav2vec2/feature_extraction_wav2vec2.d.ts.map +1 -0
  210. package/types/models/wav2vec2/processing_wav2vec2.d.ts +12 -0
  211. package/types/models/wav2vec2/processing_wav2vec2.d.ts.map +1 -0
  212. package/types/models/wespeaker/feature_extraction_wespeaker.d.ts +23 -0
  213. package/types/models/wespeaker/feature_extraction_wespeaker.d.ts.map +1 -0
  214. package/types/models/whisper/feature_extraction_whisper.d.ts +21 -0
  215. package/types/models/whisper/feature_extraction_whisper.d.ts.map +1 -0
  216. package/types/models/whisper/processing_whisper.d.ts +17 -0
  217. package/types/models/whisper/processing_whisper.d.ts.map +1 -0
  218. package/types/models/yolos/image_processing_yolos.d.ts +10 -0
  219. package/types/models/yolos/image_processing_yolos.d.ts.map +1 -0
  220. package/types/models.d.ts +152 -0
  221. package/types/models.d.ts.map +1 -1
  222. package/types/pipelines.d.ts +2 -3
  223. package/types/pipelines.d.ts.map +1 -1
  224. package/types/tokenizers.d.ts +3 -0
  225. package/types/tokenizers.d.ts.map +1 -1
  226. package/types/transformers.d.ts +10 -1
  227. package/types/utils/constants.d.ts +6 -0
  228. package/types/utils/constants.d.ts.map +1 -1
  229. package/types/utils/core.d.ts +58 -3
  230. package/types/utils/core.d.ts.map +1 -1
  231. package/types/utils/hub.d.ts +1 -1
  232. package/types/utils/hub.d.ts.map +1 -1
  233. package/types/utils/image.d.ts +10 -2
  234. package/types/utils/image.d.ts.map +1 -1
  235. package/types/utils/tensor.d.ts +34 -1
  236. package/types/utils/tensor.d.ts.map +1 -1
  237. package/src/processors.js +0 -2655
  238. package/types/processors.d.ts +0 -924
  239. package/types/processors.d.ts.map +0 -1
package/src/models.js CHANGED
@@ -61,7 +61,6 @@ import {
61
61
  } from './utils/generic.js';
62
62
 
63
63
  import {
64
- isIntegralNumber,
65
64
  mergeArrays,
66
65
  pick,
67
66
  } from './utils/core.js';
@@ -99,17 +98,20 @@ import {
99
98
 
100
99
  import {
101
100
  cat,
102
- full_like,
103
101
  mean,
102
+ zeros,
103
+ zeros_like,
104
104
  ones,
105
105
  ones_like,
106
+ full,
107
+ full_like,
106
108
  stack,
107
109
  std_mean,
108
110
  Tensor,
109
- zeros_like,
110
111
  } from './utils/tensor.js';
112
+ import { RawImage } from './utils/image.js';
111
113
 
112
- import { dynamic_time_warping, medianFilter } from './utils/maths.js';
114
+ import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
113
115
  import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
114
116
  import { LogitsSampler } from './generation/logits_sampler.js';
115
117
  import { apis } from './env.js';
@@ -128,6 +130,7 @@ const MODEL_TYPES = {
128
130
  MaskGeneration: 5,
129
131
  ImageTextToText: 6,
130
132
  Musicgen: 7,
133
+ MultiModality: 8,
131
134
  }
132
135
  //////////////////////////////////////////////////
133
136
 
@@ -386,7 +389,7 @@ async function sessionRun(session, inputs) {
386
389
  } catch (e) {
387
390
  // This usually occurs when the inputs are of the wrong type.
388
391
  console.error(`An error occurred during model execution: "${e}".`);
389
- console.error('Inputs given to model:', checkedInputs);
392
+ console.error('Inputs given to model:', checkedInputs)
390
393
  throw e;
391
394
  }
392
395
  }
@@ -579,11 +582,11 @@ async function imageTextToTextForward(self, {
579
582
 
580
583
  if (!inputs_embeds) {
581
584
  // 1. Extract the input embeddings
582
- inputs_embeds = await self.encode_text({ input_ids });
585
+ inputs_embeds = await self.encode_text({ input_ids, ...kwargs });
583
586
 
584
587
  // 2. Possibly, merge text and images
585
588
  if (pixel_values && input_ids.dims[1] !== 1) {
586
- const image_features = await self.encode_image({ pixel_values });
589
+ const image_features = await self.encode_image({ pixel_values, ...kwargs });
587
590
 
588
591
  ({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
589
592
  image_features,
@@ -604,6 +607,16 @@ async function imageTextToTextForward(self, {
604
607
  }
605
608
  }
606
609
 
610
+ if (!position_ids) {
611
+
612
+ if (self.config.model_type === 'qwen2_vl') {
613
+ // Special case for qwen2_vl models
614
+ // @ts-ignore
615
+ const { image_grid_thw, video_grid_thw } = kwargs;
616
+ [position_ids] = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
617
+ }
618
+ }
619
+
607
620
  const outputs = await decoderForward(self, {
608
621
  inputs_embeds,
609
622
  past_key_values,
@@ -615,34 +628,54 @@ async function imageTextToTextForward(self, {
615
628
  return outputs;
616
629
  }
617
630
 
618
- function createPositionIds(model_inputs, past_key_values = null) {
619
- // If the model supports providing position_ids, we create position_ids on the fly for batch generation,
620
- // by computing the cumulative sum of the attention mask along the sequence length dimension.
621
- //
622
- // Equivalent to:
623
- // position_ids = attention_mask.long().cumsum(-1) - 1
624
- // position_ids.masked_fill_(attention_mask == 0, 1)
625
- // if past_key_values:
626
- // position_ids = position_ids[:, -input_ids.shape[1] :]
627
- const { input_ids, inputs_embeds, attention_mask } = model_inputs;
631
+ /**
632
+ * Helper function to perform the following:
633
+ * ```python
634
+ * x = attention_mask.long().cumsum(-1) - 1
635
+ * x.masked_fill_(attention_mask == 0, 1)
636
+ * ```
637
+ * @param {Tensor} attention_mask
638
+ * @returns {{data: BigInt64Array, dims: number[]}}
639
+ */
640
+ function cumsum_masked_fill(attention_mask) {
628
641
  const [bz, seq_len] = attention_mask.dims;
642
+ const attn_mask_data = attention_mask.data;
629
643
 
630
- const data = new BigInt64Array(attention_mask.data.length);
644
+ const data = new BigInt64Array(attn_mask_data.length);
631
645
  for (let i = 0; i < bz; ++i) {
632
646
  const start = i * seq_len;
633
647
  let sum = BigInt(0);
634
648
  for (let j = 0; j < seq_len; ++j) {
635
649
  const index = start + j;
636
- if (attention_mask.data[index] === 0n) {
650
+ if (attn_mask_data[index] === 0n) {
637
651
  data[index] = BigInt(1);
638
652
  } else { // === 1n
639
653
  data[index] = sum;
640
- sum += attention_mask.data[index];
654
+ sum += attn_mask_data[index];
641
655
  }
642
656
  }
643
657
  }
658
+ return { data, dims: attention_mask.dims };
644
659
 
645
- let position_ids = new Tensor('int64', data, attention_mask.dims);
660
+ }
661
+
662
+ /**
663
+ * If the model supports providing position_ids, we create position_ids on the fly for batch generation,
664
+ * by computing the cumulative sum of the attention mask along the sequence length dimension.
665
+ *
666
+ * Equivalent to:
667
+ * ```python
668
+ * position_ids = attention_mask.long().cumsum(-1) - 1
669
+ * position_ids.masked_fill_(attention_mask == 0, 1)
670
+ * if past_key_values:
671
+ * position_ids = position_ids[:, -input_ids.shape[1] :]
672
+ * ```
673
+ */
674
+ function createPositionIds(model_inputs, past_key_values = null) {
675
+ const { input_ids, inputs_embeds, attention_mask } = model_inputs;
676
+
677
+ const { data, dims } = cumsum_masked_fill(attention_mask);
678
+ let position_ids = new Tensor('int64', data, dims);
646
679
  if (past_key_values) {
647
680
  const offset = -(input_ids ?? inputs_embeds).dims.at(1);
648
681
  position_ids = position_ids.slice(null, [offset, null]);
@@ -716,6 +749,52 @@ function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
716
749
  }
717
750
  }
718
751
 
752
+ function multimodality_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
753
+ const has_past_key_values = !!model_inputs.past_key_values;
754
+
755
+ if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
756
+ if (has_past_key_values) {
757
+ model_inputs.input_ids = cat([
758
+ model_inputs.input_ids,
759
+ model_inputs.input_ids,
760
+ ], 0)
761
+ // NOTE: attention_mask handled in generation
762
+ } else {
763
+ model_inputs.input_ids = cat([
764
+ model_inputs.input_ids,
765
+ full_like(model_inputs.input_ids, BigInt(generation_config.pad_token_id)),
766
+ ], 0);
767
+ model_inputs.attention_mask = cat([
768
+ model_inputs.attention_mask,
769
+ full_like(model_inputs.attention_mask, 0n),
770
+ ], 0);
771
+ }
772
+ }
773
+
774
+ if (has_past_key_values || !model_inputs.pixel_values) {
775
+ model_inputs.pixel_values = full([0, 0, 3, 384, 384], 1.0);
776
+ }
777
+
778
+ if (has_past_key_values) {
779
+ const num_img_tokens = 0;
780
+ const num_text_tokens = 1;
781
+ const has_image = num_img_tokens > 0 ? 1 : 0;
782
+
783
+ const batch_size = 1;
784
+ model_inputs.images_seq_mask = new Tensor(
785
+ 'bool',
786
+ new Array(num_img_tokens + num_text_tokens).fill(true).fill(false, 0, num_text_tokens),
787
+ [batch_size, num_img_tokens + num_text_tokens],
788
+ );
789
+ model_inputs.images_emb_mask = new Tensor(
790
+ 'bool',
791
+ new Array(num_img_tokens).fill(!!has_image),
792
+ [batch_size, 1, num_img_tokens],
793
+ );
794
+ }
795
+ return model_inputs;
796
+ }
797
+
719
798
  //////////////////////////////////////////////////
720
799
 
721
800
  //////////////////////////////////////////////////
@@ -769,6 +848,11 @@ export class PreTrainedModel extends Callable {
769
848
  this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
770
849
  break;
771
850
 
851
+ case MODEL_TYPES.MultiModality:
852
+ this.can_generate = true;
853
+ this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation;
854
+ break;
855
+
772
856
  default:
773
857
  // should be MODEL_TYPES.EncoderOnly
774
858
  this._forward = encoderForward;
@@ -912,6 +996,21 @@ export class PreTrainedModel extends Callable {
912
996
  }, options),
913
997
  ]);
914
998
 
999
+ } else if (modelType === MODEL_TYPES.MultiModality) {
1000
+ info = await Promise.all([
1001
+ constructSessions(pretrained_model_name_or_path, {
1002
+ prepare_inputs_embeds: 'prepare_inputs_embeds',
1003
+ model: 'language_model',
1004
+ lm_head: 'lm_head',
1005
+ gen_head: 'gen_head',
1006
+ gen_img_embeds: 'gen_img_embeds',
1007
+ image_decode: 'image_decode',
1008
+ }, options),
1009
+ getOptionalConfigs(pretrained_model_name_or_path, {
1010
+ generation_config: 'generation_config.json',
1011
+ }, options),
1012
+ ]);
1013
+
915
1014
  } else { // should be MODEL_TYPES.EncoderOnly
916
1015
  if (modelType !== MODEL_TYPES.EncoderOnly) {
917
1016
  console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
@@ -1658,7 +1757,8 @@ export class PreTrainedModel extends Callable {
1658
1757
  const dtype = session?.config?.kv_cache_dtype ?? 'float32';
1659
1758
  const empty = (dtype === 'float16') ? new Uint16Array() : [];
1660
1759
 
1661
- const shapes = getKeyValueShapes(this.config);
1760
+ const batch_size = (decoderFeeds[this.main_input_name] ?? decoderFeeds.attention_mask).dims?.[0] ?? 1;
1761
+ const shapes = getKeyValueShapes(this.config, { batch_size });
1662
1762
 
1663
1763
  for (const name in shapes) {
1664
1764
  decoderFeeds[name] = new Tensor(dtype, empty, shapes[name]);
@@ -3277,6 +3377,7 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel {
3277
3377
  }
3278
3378
  //////////////////////////////////////////////////
3279
3379
 
3380
+ export class LlavaOnevisionForConditionalGeneration extends LlavaForConditionalGeneration { } // NOTE: extends LlavaForConditionalGeneration
3280
3381
  export class Moondream1ForConditionalGeneration extends LlavaForConditionalGeneration { } // NOTE: extends LlavaForConditionalGeneration
3281
3382
 
3282
3383
  export class Florence2PreTrainedModel extends PreTrainedModel {
@@ -3437,7 +3538,7 @@ export class CLIPModel extends CLIPPreTrainedModel { }
3437
3538
  * The text model from CLIP without any head or projection on top.
3438
3539
  */
3439
3540
  export class CLIPTextModel extends CLIPPreTrainedModel {
3440
- /** @type {PreTrainedModel.from_pretrained} */
3541
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3441
3542
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3442
3543
  // Update default model file name if not provided
3443
3544
  options.model_file_name ??= 'text_model';
@@ -3472,7 +3573,7 @@ export class CLIPTextModel extends CLIPPreTrainedModel {
3472
3573
  * ```
3473
3574
  */
3474
3575
  export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
3475
- /** @type {PreTrainedModel.from_pretrained} */
3576
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3476
3577
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3477
3578
  // Update default model file name if not provided
3478
3579
  options.model_file_name ??= 'text_model';
@@ -3484,7 +3585,7 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
3484
3585
  * The vision model from CLIP without any head or projection on top.
3485
3586
  */
3486
3587
  export class CLIPVisionModel extends CLIPPreTrainedModel {
3487
- /** @type {PreTrainedModel.from_pretrained} */
3588
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3488
3589
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3489
3590
  // Update default model file name if not provided
3490
3591
  options.model_file_name ??= 'vision_model';
@@ -3519,7 +3620,7 @@ export class CLIPVisionModel extends CLIPPreTrainedModel {
3519
3620
  * ```
3520
3621
  */
3521
3622
  export class CLIPVisionModelWithProjection extends CLIPPreTrainedModel {
3522
- /** @type {PreTrainedModel.from_pretrained} */
3623
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3523
3624
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3524
3625
  // Update default model file name if not provided
3525
3626
  options.model_file_name ??= 'vision_model';
@@ -3605,8 +3706,7 @@ export class SiglipModel extends SiglipPreTrainedModel { }
3605
3706
  * ```
3606
3707
  */
3607
3708
  export class SiglipTextModel extends SiglipPreTrainedModel {
3608
-
3609
- /** @type {PreTrainedModel.from_pretrained} */
3709
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3610
3710
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3611
3711
  // Update default model file name if not provided
3612
3712
  options.model_file_name ??= 'text_model';
@@ -3641,7 +3741,7 @@ export class SiglipTextModel extends SiglipPreTrainedModel {
3641
3741
  * ```
3642
3742
  */
3643
3743
  export class SiglipVisionModel extends CLIPPreTrainedModel {
3644
- /** @type {PreTrainedModel.from_pretrained} */
3744
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3645
3745
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3646
3746
  // Update default model file name if not provided
3647
3747
  options.model_file_name ??= 'vision_model';
@@ -3655,6 +3755,67 @@ export class ChineseCLIPPreTrainedModel extends PreTrainedModel { }
3655
3755
  export class ChineseCLIPModel extends ChineseCLIPPreTrainedModel { }
3656
3756
  //////////////////////////////////////////////////
3657
3757
 
3758
+ //////////////////////////////////////////////////
3759
+ // JinaCLIP models
3760
+ export class JinaCLIPPreTrainedModel extends PreTrainedModel { }
3761
+
3762
+ export class JinaCLIPModel extends JinaCLIPPreTrainedModel {
3763
+ async forward(model_inputs) {
3764
+ const missing_text_inputs = !model_inputs.input_ids;
3765
+ const missing_image_inputs = !model_inputs.pixel_values;
3766
+
3767
+ if (missing_text_inputs && missing_image_inputs) {
3768
+ throw new Error('Either `input_ids` or `pixel_values` should be provided.');
3769
+ }
3770
+
3771
+ // If either `input_ids` or `pixel_values` aren't passed, we need to create dummy input since the model requires a value to be specified.
3772
+ if (missing_text_inputs) {
3773
+ // NOTE: We cannot pass zero-dimension tensor as input for input_ids.
3774
+ // Fortunately, the majority of time is spent in the vision encoder, so this shouldn't significantly impact performance.
3775
+ model_inputs.input_ids = ones([model_inputs.pixel_values.dims[0], 1]);
3776
+ }
3777
+
3778
+ if (missing_image_inputs) {
3779
+ // NOTE: Since we create a zero-sized tensor, this does not increase computation time.
3780
+ // @ts-ignore
3781
+ const { image_size } = this.config.vision_config;
3782
+ model_inputs.pixel_values = full([0, 3, image_size, image_size], 0.0); // (pass zero-dimension tensor)
3783
+ }
3784
+
3785
+ const { text_embeddings, image_embeddings, l2norm_text_embeddings, l2norm_image_embeddings } = await super.forward(model_inputs);
3786
+
3787
+ const result = {};
3788
+ if (!missing_text_inputs) {
3789
+ result.text_embeddings = text_embeddings;
3790
+ result.l2norm_text_embeddings = l2norm_text_embeddings;
3791
+ }
3792
+ if (!missing_image_inputs) {
3793
+ result.image_embeddings = image_embeddings;
3794
+ result.l2norm_image_embeddings = l2norm_image_embeddings;
3795
+ }
3796
+ return result
3797
+ }
3798
+ }
3799
+
3800
+ export class JinaCLIPTextModel extends JinaCLIPPreTrainedModel {
3801
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3802
+ static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3803
+ // Update default model file name if not provided
3804
+ options.model_file_name ??= 'text_model';
3805
+ return super.from_pretrained(pretrained_model_name_or_path, options);
3806
+ }
3807
+ }
3808
+
3809
+ export class JinaCLIPVisionModel extends JinaCLIPPreTrainedModel {
3810
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3811
+ static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3812
+ // Update default model file name if not provided
3813
+ options.model_file_name ??= 'vision_model';
3814
+ return super.from_pretrained(pretrained_model_name_or_path, options);
3815
+ }
3816
+ }
3817
+ //////////////////////////////////////////////////
3818
+
3658
3819
 
3659
3820
  //////////////////////////////////////////////////
3660
3821
  // CLIPSeg models
@@ -3810,6 +3971,22 @@ export class LlamaForCausalLM extends LlamaPreTrainedModel { }
3810
3971
  //////////////////////////////////////////////////
3811
3972
 
3812
3973
 
3974
+ //////////////////////////////////////////////////
3975
+ // MobileLLM models
3976
// Base class for MobileLLM models; weight loading/config handling comes from PreTrainedModel.
export class MobileLLMPreTrainedModel extends PreTrainedModel { }
// The bare MobileLLM model outputting raw hidden-states (no task-specific head).
export class MobileLLMModel extends MobileLLMPreTrainedModel { }
// MobileLLM variant registered for causal language modelling (see MODEL_FOR_CAUSAL_LM mapping).
export class MobileLLMForCausalLM extends MobileLLMPreTrainedModel { }
3979
+ //////////////////////////////////////////////////
3980
+
3981
+
3982
+ //////////////////////////////////////////////////
3983
+ // OLMo models
3984
// Base class for OLMo models; weight loading/config handling comes from PreTrainedModel.
export class OlmoPreTrainedModel extends PreTrainedModel { }
// The bare OLMo model outputting raw hidden-states (no task-specific head).
export class OlmoModel extends OlmoPreTrainedModel { }
// OLMo variant registered for causal language modelling (see MODEL_FOR_CAUSAL_LM mapping).
export class OlmoForCausalLM extends OlmoPreTrainedModel { }
3987
+ //////////////////////////////////////////////////
3988
+
3989
+
3813
3990
  //////////////////////////////////////////////////
3814
3991
  // Granite models
3815
3992
  export class GranitePreTrainedModel extends PreTrainedModel { }
@@ -3882,6 +4059,285 @@ export class Qwen2Model extends Qwen2PreTrainedModel { }
3882
4059
  export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { }
3883
4060
  //////////////////////////////////////////////////
3884
4061
 
4062
// Base class for Qwen2-VL models. Declares the ordered set of inputs the
// forward pass gathers and passes to the underlying session.
export class Qwen2VLPreTrainedModel extends PreTrainedModel {
    forward_params = [
        // Text inputs
        'input_ids',
        'attention_mask',
        'position_ids',
        'past_key_values',

        // Vision inputs
        'pixel_values',
        'image_grid_thw',
    ];
}
4075
export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {

    /**
     * Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
     *
     * Explanation:
     * Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
     *
     * For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
     * Examples:
     * input_ids: [T T T T T], here T is for text.
     * temporal position_ids: [0, 1, 2, 3, 4]
     * height position_ids: [0, 1, 2, 3, 4]
     * width position_ids: [0, 1, 2, 3, 4]
     *
     * For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
     * and 1D rotary position embedding for text part.
     * Examples:
     * Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
     * input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
     * vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
     * vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
     * vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
     * text temporal position_ids: [3, 4, 5, 6, 7]
     * text height position_ids: [3, 4, 5, 6, 7]
     * text width position_ids: [3, 4, 5, 6, 7]
     * Here we calculate the text start position_ids as the max vision position_ids plus 1.
     *
     * @param {Tensor} input_ids Indices of input sequence tokens in the vocabulary. Tensor of shape `(batch_size, sequence_length)`.
     * @param {Tensor} image_grid_thw (Optional) The temporal, height and width of feature shape of each image in LLM. Tensor of shape `(num_images, 3)`.
     * @param {Tensor} video_grid_thw (Optional) The temporal, height and width of feature shape of each video in LLM. Tensor of shape `(num_videos, 3)`.
     * @param {Tensor} attention_mask (Optional) Mask to avoid performing attention on padding token indices. Tensor of shape `(batch_size, sequence_length)`. Mask values selected in `[0, 1]`:
     * - 1 for tokens that are **not masked**,
     * - 0 for tokens that are **masked**.
     * @returns {[Tensor, Tensor]} [position_ids, mrope_position_deltas] with:
     * - position_ids: Tensor of shape `(3, batch_size, sequence_length)`.
     * - mrope_position_deltas: Tensor of shape `(batch_size)`.
     */
    get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) {
        // @ts-ignore
        const { vision_config, image_token_id, video_token_id, vision_start_token_id } = this.config;
        const spatial_merge_size = vision_config.spatial_merge_size ?? 2;

        const mrope_position_deltas = [];
        if (image_grid_thw || video_grid_thw) {
            let total_input_ids = input_ids.tolist();
            if (!attention_mask) {
                attention_mask = ones_like(input_ids);
            }

            const attention_mask_list = attention_mask.tolist();
            // position_ids_list has shape (3, batch_size, sequence_length), initialized to 1.
            const position_ids_list = Array.from({ length: 3 }, _ => Array.from({ length: input_ids.dims[0] }, _ => Array.from({ length: input_ids.dims[1] }, _ => 1)));

            const image_grid_thw_list = image_grid_thw ? image_grid_thw.tolist() : [];
            const video_grid_thw_list = video_grid_thw ? video_grid_thw.tolist() : [];

            let image_index = 0;
            let video_index = 0;
            for (let i = 0; i < total_input_ids.length; ++i) {
                // Only consider non-padding tokens for this batch item.
                const ids = total_input_ids[i].filter((_, j) => attention_mask_list[i][j] == 1);

                // Indices of every vision-start token; the token right after it tells
                // us whether the following block is an image or a video.
                const vision_start_indices = ids.reduce((acc, x, idx) => {
                    if (x == vision_start_token_id) acc.push(idx);
                    return acc;
                }, []);

                const vision_tokens = vision_start_indices.map(x => ids[x + 1]);
                const image_nums = vision_tokens.filter(x => x == image_token_id).length;
                const video_nums = vision_tokens.filter(x => x == video_token_id).length;

                let llm_pos_ids_list = [];
                let st = 0;
                let remain_images = image_nums;
                let remain_videos = video_nums;
                // Walk through the sequence, alternating text spans and vision blocks.
                for (let j = 0; j < vision_tokens.length; ++j) {
                    const next_image_token = ids.findIndex((x, i) => i > st && x == image_token_id);
                    const next_video_token = ids.findIndex((x, i) => i > st && x == video_token_id);

                    const ed_image = (remain_images > 0 && next_image_token !== -1)
                        ? next_image_token
                        : ids.length + 1;

                    const ed_video = (remain_videos > 0 && next_video_token !== -1)
                        ? next_video_token
                        : ids.length + 1;

                    let ed;
                    let t, h, w;
                    if (ed_image < ed_video) {
                        ([t, h, w] = image_grid_thw_list[image_index]);
                        ++image_index;
                        --remain_images;
                        ed = ed_image;
                    } else {
                        ([t, h, w] = video_grid_thw_list[video_index]);
                        ++video_index;
                        --remain_videos;
                        ed = ed_video;
                    }

                    // Grid dims after spatial merging of patches.
                    const [llm_grid_t, llm_grid_h, llm_grid_w] = [
                        Number(t),
                        Math.floor(Number(h) / spatial_merge_size),
                        Math.floor(Number(w) / spatial_merge_size)
                    ]
                    const text_len = ed - st;
                    const st_idx = llm_pos_ids_list.length > 0
                        ? max(llm_pos_ids_list.at(-1))[0] + 1
                        : 0;

                    // 1D positions for the text span, replicated for the 3 rope axes.
                    llm_pos_ids_list.push(
                        Array.from({ length: 3 * text_len }, (_, i) => st_idx + (i % text_len))
                    )

                    // 3D positions for the vision block (temporal, height, width).
                    const offset = text_len + st_idx;
                    const grid_size = llm_grid_t * llm_grid_h * llm_grid_w;
                    const t_index = Array.from({ length: grid_size }, (_, i) => offset + Math.floor(i / (llm_grid_h * llm_grid_w)))
                    const h_index = Array.from({ length: grid_size }, (_, i) => offset + Math.floor(i / llm_grid_w) % llm_grid_h)
                    const w_index = Array.from({ length: grid_size }, (_, i) => offset + i % llm_grid_w)

                    llm_pos_ids_list.push([t_index, h_index, w_index].flat())

                    st = ed + grid_size;
                }

                // Trailing text after the last vision block.
                if (st < ids.length) {
                    const st_idx = llm_pos_ids_list.length > 0
                        ? max(llm_pos_ids_list.at(-1))[0] + 1
                        : 0;
                    const text_len = ids.length - st;

                    llm_pos_ids_list.push(
                        Array.from({ length: 3 * text_len }, (_, i) => (st_idx + (i % text_len)))
                    )
                }

                // NOTE: Each item in llm_pos_ids_list is an array of shape (3, text_len),
                // meaning to perform concatenation along dim=1, we can do the following:
                const num_items = llm_pos_ids_list.reduce((acc, x) => acc + x.length, 0);
                const llm_positions = new Array(num_items);
                let index = 0;
                for (let x = 0; x < 3; ++x) {
                    for (let y = 0; y < llm_pos_ids_list.length; ++y) {
                        const val = llm_pos_ids_list[y];
                        const text_len = val.length / 3;
                        for (let z = x * text_len; z < (x + 1) * text_len; ++z) {
                            llm_positions[index++] = val[z];
                        }
                    }
                }

                // Scatter the computed positions back into the (3, batch, seq) grid,
                // skipping padding positions (which keep their initial value of 1).
                let count = 0;
                const attn_mask = attention_mask_list[i];
                for (let y = 0; y < attn_mask.length; ++y) {
                    if (attn_mask[y] == 1) {
                        for (let x = 0; x < 3; ++x) {
                            position_ids_list[x][i][y] = llm_positions[x * num_items / 3 + count];
                        }
                        ++count;
                    }
                }

                const max_llm_positions = max(llm_positions)[0];
                mrope_position_deltas.push(max_llm_positions + 1 - total_input_ids[i].length);
            }

            return [
                new Tensor('int64', position_ids_list.flat(Infinity), [3, input_ids.dims[0], input_ids.dims[1]]),
                new Tensor('int64', mrope_position_deltas, [mrope_position_deltas.length, 1]),
            ];

        } else { // Text-only
            if (attention_mask) {
                const { data, dims } = cumsum_masked_fill(attention_mask);

                // Same 1D positions on all 3 rope axes.
                const position_ids = BigInt64Array.from(
                    { length: 3 * data.length },
                    (_, i) => data[i % data.length]
                );
                const mrope_position_deltas = Array.from(
                    { length: dims[0] },
                    (_, i) => max(data.subarray(dims[1] * i, dims[1] * (i + 1)))[0] + 1 + dims[1]
                );

                return [
                    new Tensor('int64', position_ids, [3, ...dims]),
                    new Tensor('int64', mrope_position_deltas, [mrope_position_deltas.length, 1]),
                ]
            } else {
                const [batch_size, seq_length] = input_ids.dims;
                // Each (axis, batch) row is simply [0, 1, ..., seq_length - 1].
                // FIX: for a row-major (3, batch_size, seq_length) layout the position of
                // flat index `i` is `i % seq_length`; the previous expression
                // `Math.floor(i % seq_length / batch_size)` was only correct for batch_size === 1.
                const position_ids = BigInt64Array.from(
                    { length: 3 * batch_size * seq_length },
                    (_, i) => BigInt(i % seq_length),
                );

                return [
                    new Tensor('int64', position_ids, [3, ...input_ids.dims]),
                    zeros([batch_size, 1]),
                ]
            }
        }
    }

    /**
     * Run the vision encoder session over the image patches.
     * @returns image feature embeddings produced by the `vision_encoder` session.
     */
    async encode_image({ pixel_values, image_grid_thw }) {
        const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values, grid_thw: image_grid_thw })).image_features;
        return features;
    }

    /**
     * Write image features into the embedding positions occupied by image tokens
     * (equivalent to a masked_scatter). Mutates `inputs_embeds` in place.
     */
    _merge_input_ids_with_image_features({
        inputs_embeds,
        image_features,
        input_ids,
        attention_mask,
    }) {
        // @ts-ignore
        const { image_token_id } = this.config;
        const image_tokens = input_ids.tolist().map(ids =>
            ids.reduce((acc, x, idx) => {
                if (x == image_token_id) acc.push(idx);
                return acc;
            }, [])
        );
        const n_image_tokens = image_tokens.reduce((acc, x) => acc + x.length, 0);
        const n_image_features = image_features.dims[0];
        if (n_image_tokens !== n_image_features) {
            throw new Error(`Image features and image tokens do not match: tokens: ${n_image_tokens}, features ${n_image_features}`);
        }

        // Equivalent to performing a masked_scatter
        let img = 0;
        for (let i = 0; i < image_tokens.length; ++i) {
            const tokens = image_tokens[i];
            const embeds = inputs_embeds[i];
            for (let j = 0; j < tokens.length; ++j) {
                embeds[tokens[j]].data.set(image_features[img++].data)
            }
        }
        return { inputs_embeds, attention_mask }
    }

    prepare_inputs_for_generation(input_ids, model_inputs, generation_config) {
        // Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        if (model_inputs.attention_mask && !model_inputs.position_ids) {
            // Calculate position_ids and rope_deltas
            if (!model_inputs.past_key_values) {
                // Prefill: compute the full 3D rope index.
                ([model_inputs.position_ids, model_inputs.rope_deltas] = this.get_rope_index(
                    model_inputs.input_ids,
                    model_inputs.image_grid_thw,
                    model_inputs.video_grid_thw,
                    model_inputs.attention_mask,
                ));

            } else {
                // Decode step: vision inputs were already consumed during prefill.
                model_inputs.pixel_values = null;
                // model_inputs.pixel_values_videos = null;

                // New position = cached sequence length + per-sequence rope delta.
                const delta = BigInt(Object.values(model_inputs.past_key_values)[0].dims.at(-2));
                const rope_deltas_list = model_inputs.rope_deltas.map(x => delta + x);
                model_inputs.position_ids = stack([rope_deltas_list, rope_deltas_list, rope_deltas_list], 0)
            }
        }

        return model_inputs;
    }
}
4340
+
3885
4341
 
3886
4342
  //////////////////////////////////////////////////
3887
4343
  // Phi models
@@ -3969,6 +4425,17 @@ export class ViTForImageClassification extends ViTPreTrainedModel {
3969
4425
  }
3970
4426
  //////////////////////////////////////////////////
3971
4427
 
4428
+
4429
+ //////////////////////////////////////////////////
4430
// Base class for VitPose models; weight loading/config handling comes from PreTrainedModel.
export class VitPosePreTrainedModel extends PreTrainedModel { }

/**
 * The VitPose model with a pose estimation head on top.
 */
export class VitPoseForPoseEstimation extends VitPosePreTrainedModel { }
4436
+ //////////////////////////////////////////////////
4437
+
4438
+
3972
4439
  //////////////////////////////////////////////////
3973
4440
  export class PvtPreTrainedModel extends PreTrainedModel { }
3974
4441
  export class PvtModel extends PvtPreTrainedModel { }
@@ -5567,8 +6034,7 @@ export class ClapModel extends ClapPreTrainedModel { }
5567
6034
  * ```
5568
6035
  */
5569
6036
  export class ClapTextModelWithProjection extends ClapPreTrainedModel {
5570
-
5571
- /** @type {PreTrainedModel.from_pretrained} */
6037
+ /** @type {typeof PreTrainedModel.from_pretrained} */
5572
6038
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
5573
6039
  // Update default model file name if not provided
5574
6040
  options.model_file_name ??= 'text_model';
@@ -5603,7 +6069,7 @@ export class ClapTextModelWithProjection extends ClapPreTrainedModel {
5603
6069
  * ```
5604
6070
  */
5605
6071
  export class ClapAudioModelWithProjection extends ClapPreTrainedModel {
5606
- /** @type {PreTrainedModel.from_pretrained} */
6072
+ /** @type {typeof PreTrainedModel.from_pretrained} */
5607
6073
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
5608
6074
  // Update default model file name if not provided
5609
6075
  options.model_file_name ??= 'audio_model';
@@ -5954,6 +6420,170 @@ export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel
5954
6420
 
5955
6421
  //////////////////////////////////////////////////
5956
6422
 
6423
// Base class for multi-modality (any-to-any) models; weight loading comes from PreTrainedModel.
export class MultiModalityPreTrainedModel extends PreTrainedModel { }
6424
/**
 * Multi-modality causal language model that can generate either text or images,
 * switching between the `lm_head` and `gen_head` sessions depending on the
 * current generation mode.
 */
export class MultiModalityCausalLM extends MultiModalityPreTrainedModel {
    forward_params = [
        // prepare_inputs_embeds
        'input_ids',
        'pixel_values',
        'images_seq_mask',
        'images_emb_mask',

        // language_model
        'attention_mask',
        'position_ids',
        'past_key_values',
    ];

    constructor(...args) {
        super(...args);

        // State-based approach to switch out which heads to use during generation
        this._generation_mode = 'text';
    }

    /**
     * Runs the three-stage pipeline: embedding preparation, decoder, then the
     * mode-specific generation head. Returns the merged outputs of all stages.
     * @param {Object} model_inputs Inputs named in `forward_params`.
     */
    async forward(model_inputs) {
        const mode = this._generation_mode ?? 'text';

        // TODO support re-using PKVs for input_ids.dims[1] !== 1
        // if (model_inputs.past_key_values) {
        //     // && model_inputs.input_ids.dims[1] === 1
        // }

        let output_1;
        if (mode === 'text' || !model_inputs.past_key_values) {
            // First pass (or text mode): embed the full multimodal prompt.
            const session = this.sessions['prepare_inputs_embeds'];
            const prep_inputs = pick(model_inputs, session.inputNames);
            output_1 = await sessionRun(session, prep_inputs);
        } else {
            // Image-generation decode step: embed previously generated image tokens.
            const session = this.sessions['gen_img_embeds'];
            const prep_inputs = pick({
                image_ids: model_inputs.input_ids,
            }, session.inputNames);
            output_1 = await sessionRun(session, prep_inputs);
        }

        const input_2 = { ...model_inputs, ...output_1 }
        const output_2 = await decoderForward(this, input_2);

        // FIX: resolve the head's name first so the error message reports which
        // session is missing (previously it interpolated the undefined session object).
        const head_name = mode === 'text' ? 'lm_head' : 'gen_head';
        const head = this.sessions[head_name];
        if (!head) {
            throw new Error(`Unable to find "${head_name}" generation head`);
        }

        const output_3 = await sessionRun(head, pick(output_2, head.inputNames))

        return {
            ...output_1,
            ...output_2,
            ...output_3,
        };
    }

    /**
     * Standard text generation (uses the `lm_head` session).
     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
     */
    async generate(options) {
        this._generation_mode = 'text';
        return super.generate(options);
    }

    /**
     * Image generation: autoregressively samples image tokens (via `gen_head`),
     * then decodes them into RGB images with the `image_decode` session.
     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
     * @returns {Promise<RawImage[]>} the generated images.
     */
    async generate_images(options) {
        this._generation_mode = 'image';

        const start_num_tokens = (options.inputs ?? options[this.main_input_name]).dims[1];
        const all_tokens = await super.generate(options);

        // Strip the prompt; keep only newly generated tokens.
        const generated_tokens = (/** @type {Tensor} */(all_tokens)).slice(null, [start_num_tokens, null])

        const image_decode = this.sessions['image_decode'];
        const { decoded_image } = await sessionRun(image_decode, {
            generated_tokens,
        });

        // Equivalent to `np.clip((dec + 1) / 2 * 255, 0, 255)`
        const clamped = decoded_image
            .add_(1)
            .mul_(255 / 2)
            .clamp_(0, 255)
            .to('uint8');

        // Return as a list of images
        const images = [];
        for (const tensor of clamped) {
            const img = RawImage.fromTensor(tensor);
            images.push(img);
        }
        return images;
    }
}
6527
+
6528
/**
 * Output container for MGP-STR: logits from the character, BPE and WordPiece heads.
 */
export class MgpstrModelOutput extends ModelOutput {
    constructor({ char_logits, bpe_logits, wp_logits }) {
        super();
        // Store each head's logits on the output object.
        Object.assign(this, { char_logits, bpe_logits, wp_logits });
    }

    /** All three heads' logits, in [char, bpe, wp] order. */
    get logits() {
        const { char_logits, bpe_logits, wp_logits } = this;
        return [char_logits, bpe_logits, wp_logits];
    }
}
6540
+
6541
// Base class for MGP-STR models; weight loading/config handling comes from PreTrainedModel.
export class MgpstrPreTrainedModel extends PreTrainedModel { }
6542
+
6543
/**
 * MGP-STR Model transformer with three classification heads on top
 * (three A^3 modules and three linear layer on top of the transformer encoder output) for scene text recognition (STR).
 */
export class MgpstrForSceneTextRecognition extends MgpstrPreTrainedModel {
    /**
     * Wraps the raw model outputs in a {@link MgpstrModelOutput}.
     * @param {any} model_inputs
     */
    async _call(model_inputs) {
        const outputs = await super._call(model_inputs);
        return new MgpstrModelOutput(outputs);
    }
}
6555
+
6556
+ //////////////////////////////////////////////////
6557
+ // PatchTST Transformer models
6558
// Base class for PatchTST models; weight loading/config handling comes from PreTrainedModel.
export class PatchTSTPreTrainedModel extends PreTrainedModel { }

/**
 * The bare PatchTST Model outputting raw hidden-states without any specific head.
 */
export class PatchTSTModel extends PatchTSTPreTrainedModel { }

/**
 * The PatchTST for prediction model.
 */
export class PatchTSTForPrediction extends PatchTSTPreTrainedModel { }
6569
+ //////////////////////////////////////////////////
6570
+
6571
+ //////////////////////////////////////////////////
6572
+ // PatchTSMixer Transformer models
6573
// Base class for PatchTSMixer models; weight loading/config handling comes from PreTrainedModel.
export class PatchTSMixerPreTrainedModel extends PreTrainedModel { }

/**
 * The bare PatchTSMixer Model outputting raw hidden-states without any specific head.
 */
export class PatchTSMixerModel extends PatchTSMixerPreTrainedModel { }

/**
 * The PatchTSMixer for prediction model.
 */
export class PatchTSMixerForPrediction extends PatchTSMixerPreTrainedModel { }
6584
+ //////////////////////////////////////////////////
6585
+
6586
+
5957
6587
  //////////////////////////////////////////////////
5958
6588
  // AutoModels, used to simplify construction of PreTrainedModels
5959
6589
  // (uses config to instantiate correct class)
@@ -6048,6 +6678,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
6048
6678
  ['clipseg', ['CLIPSegModel', CLIPSegModel]],
6049
6679
  ['chinese_clip', ['ChineseCLIPModel', ChineseCLIPModel]],
6050
6680
  ['siglip', ['SiglipModel', SiglipModel]],
6681
+ ['jina_clip', ['JinaCLIPModel', JinaCLIPModel]],
6051
6682
  ['mobilebert', ['MobileBertModel', MobileBertModel]],
6052
6683
  ['squeezebert', ['SqueezeBertModel', SqueezeBertModel]],
6053
6684
  ['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]],
@@ -6092,6 +6723,8 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
6092
6723
  ['efficientnet', ['EfficientNetModel', EfficientNetModel]],
6093
6724
 
6094
6725
  ['decision_transformer', ['DecisionTransformerModel', DecisionTransformerModel]],
6726
+ ['patchtst', ['PatchTSTForPrediction', PatchTSTModel]],
6727
+ ['patchtsmixer', ['PatchTSMixerForPrediction', PatchTSMixerModel]],
6095
6728
 
6096
6729
  ['mobilenet_v1', ['MobileNetV1Model', MobileNetV1Model]],
6097
6730
  ['mobilenet_v2', ['MobileNetV2Model', MobileNetV2Model]],
@@ -6099,6 +6732,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
6099
6732
  ['mobilenet_v4', ['MobileNetV4Model', MobileNetV4Model]],
6100
6733
 
6101
6734
  ['maskformer', ['MaskFormerModel', MaskFormerModel]],
6735
+ ['mgp-str', ['MgpstrForSceneTextRecognition', MgpstrForSceneTextRecognition]],
6102
6736
  ]);
6103
6737
 
6104
6738
  const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
@@ -6125,6 +6759,8 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
6125
6759
  ['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]],
6126
6760
  ['codegen', ['CodeGenModel', CodeGenModel]],
6127
6761
  ['llama', ['LlamaModel', LlamaModel]],
6762
+ ['olmo', ['OlmoModel', OlmoModel]],
6763
+ ['mobilellm', ['MobileLLMModel', MobileLLMModel]],
6128
6764
  ['granite', ['GraniteModel', GraniteModel]],
6129
6765
  ['cohere', ['CohereModel', CohereModel]],
6130
6766
  ['gemma', ['GemmaModel', GemmaModel]],
@@ -6214,6 +6850,8 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
6214
6850
  ['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]],
6215
6851
  ['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]],
6216
6852
  ['llama', ['LlamaForCausalLM', LlamaForCausalLM]],
6853
+ ['olmo', ['OlmoForCausalLM', OlmoForCausalLM]],
6854
+ ['mobilellm', ['MobileLLMForCausalLM', MobileLLMForCausalLM]],
6217
6855
  ['granite', ['GraniteForCausalLM', GraniteForCausalLM]],
6218
6856
  ['cohere', ['CohereForCausalLM', CohereForCausalLM]],
6219
6857
  ['gemma', ['GemmaForCausalLM', GemmaForCausalLM]],
@@ -6232,6 +6870,11 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
6232
6870
  ['stablelm', ['StableLmForCausalLM', StableLmForCausalLM]],
6233
6871
  ]);
6234
6872
 
6873
// Multi-modality (any-to-any) causal LMs: maps model type → [class name, class].
const MODEL_FOR_MULTIMODALITY_MAPPING_NAMES = new Map([
    ['multi_modality', ['MultiModalityCausalLM', MultiModalityCausalLM]],
]);
6876
+
6877
+
6235
6878
  const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
6236
6879
  ['bert', ['BertForMaskedLM', BertForMaskedLM]],
6237
6880
  ['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]],
@@ -6275,8 +6918,10 @@ const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([
6275
6918
 
6276
6919
  const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
6277
6920
  ['llava', ['LlavaForConditionalGeneration', LlavaForConditionalGeneration]],
6921
+ ['llava_onevision', ['LlavaOnevisionForConditionalGeneration', LlavaOnevisionForConditionalGeneration]],
6278
6922
  ['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]],
6279
6923
  ['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]],
6924
+ ['qwen2-vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]],
6280
6925
  ]);
6281
6926
 
6282
6927
  const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
@@ -6372,6 +7017,11 @@ const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
6372
7017
  ['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]],
6373
7018
  ]);
6374
7019
 
7020
// Time-series prediction models: maps model type → [class name, class].
const MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = new Map([
    ['patchtst', ['PatchTSTForPrediction', PatchTSTForPrediction]],
    ['patchtsmixer', ['PatchTSMixerForPrediction', PatchTSMixerForPrediction]],
])
7024
+
6375
7025
  const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([
6376
7026
  ['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]],
6377
7027
  ])
@@ -6388,11 +7038,16 @@ const MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES = new Map([
6388
7038
  ['sapiens', ['SapiensForNormalEstimation', SapiensForNormalEstimation]],
6389
7039
  ])
6390
7040
 
7041
// Pose-estimation models: maps model type → [class name, class].
const MODEL_FOR_POSE_ESTIMATION_MAPPING_NAMES = new Map([
    ['vitpose', ['VitPoseForPoseEstimation', VitPoseForPoseEstimation]],
])
7044
+
6391
7045
  // NOTE: This is custom to Transformers.js, and is necessary because certain models
6392
7046
  // (e.g., CLIP) are split into vision and text components
6393
7047
  const MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = new Map([
6394
7048
  ['clip', ['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection]],
6395
7049
  ['siglip', ['SiglipVisionModel', SiglipVisionModel]],
7050
+ ['jina_clip', ['JinaCLIPVisionModel', JinaCLIPVisionModel]],
6396
7051
  ])
6397
7052
 
6398
7053
  const MODEL_CLASS_TYPE_MAPPING = [
@@ -6404,6 +7059,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
6404
7059
  [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
6405
7060
  [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
6406
7061
  [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
7062
+ [MODEL_FOR_MULTIMODALITY_MAPPING_NAMES, MODEL_TYPES.MultiModality],
6407
7063
  [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6408
7064
  [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6409
7065
  [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
@@ -6413,9 +7069,11 @@ const MODEL_CLASS_TYPE_MAPPING = [
6413
7069
  [MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6414
7070
  [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6415
7071
  [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
7072
+ [MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6416
7073
  [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6417
7074
  [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6418
7075
  [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
7076
+ [MODEL_FOR_POSE_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6419
7077
  [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6420
7078
  [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6421
7079
  [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
@@ -6446,6 +7104,7 @@ const CUSTOM_MAPPING = [
6446
7104
 
6447
7105
  ['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly],
6448
7106
  ['SiglipTextModel', SiglipTextModel, MODEL_TYPES.EncoderOnly],
7107
+ ['JinaCLIPTextModel', JinaCLIPTextModel, MODEL_TYPES.EncoderOnly],
6449
7108
  ['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly],
6450
7109
  ['ClapAudioModelWithProjection', ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly],
6451
7110
  ]
@@ -6687,6 +7346,10 @@ export class AutoModelForNormalEstimation extends PretrainedMixin {
6687
7346
  static MODEL_CLASS_MAPPINGS = [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES];
6688
7347
  }
6689
7348
 
7349
// AutoModel helper that dispatches `from_pretrained` to the classes registered
// in MODEL_FOR_POSE_ESTIMATION_MAPPING_NAMES.
export class AutoModelForPoseEstimation extends PretrainedMixin {
    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_POSE_ESTIMATION_MAPPING_NAMES];
}
7352
+
6690
7353
  export class AutoModelForImageFeatureExtraction extends PretrainedMixin {
6691
7354
  static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES];
6692
7355
  }