@huggingface/transformers 3.0.2 → 3.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. package/README.md +13 -4
  2. package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
  3. package/dist/transformers.cjs +16655 -13040
  4. package/dist/transformers.cjs.map +1 -1
  5. package/dist/transformers.js +17095 -13468
  6. package/dist/transformers.js.map +1 -1
  7. package/dist/transformers.min.cjs +244 -52
  8. package/dist/transformers.min.cjs.map +1 -1
  9. package/dist/transformers.min.js +235 -43
  10. package/dist/transformers.min.js.map +1 -1
  11. package/dist/transformers.min.mjs +246 -54
  12. package/dist/transformers.min.mjs.map +1 -1
  13. package/dist/transformers.mjs +16818 -13202
  14. package/dist/transformers.mjs.map +1 -1
  15. package/package.json +4 -4
  16. package/src/base/feature_extraction_utils.js +54 -0
  17. package/src/base/image_processors_utils.js +1089 -0
  18. package/src/base/processing_utils.js +145 -0
  19. package/src/configs.js +15 -4
  20. package/src/env.js +6 -6
  21. package/src/generation/configuration_utils.js +7 -0
  22. package/src/generation/logits_process.js +22 -16
  23. package/src/generation/streamers.js +7 -2
  24. package/src/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.js +90 -0
  25. package/src/models/auto/feature_extraction_auto.js +41 -0
  26. package/src/models/auto/image_processing_auto.js +29 -0
  27. package/src/models/auto/processing_auto.js +100 -0
  28. package/src/models/beit/image_processing_beit.js +5 -0
  29. package/src/models/bit/image_processing_bit.js +5 -0
  30. package/src/models/chinese_clip/image_processing_chinese_clip.js +5 -0
  31. package/src/models/clap/feature_extraction_clap.js +159 -0
  32. package/src/models/clip/image_processing_clip.js +6 -0
  33. package/src/models/convnext/image_processing_convnext.js +45 -0
  34. package/src/models/deit/image_processing_deit.js +6 -0
  35. package/src/models/detr/image_processing_detr.js +52 -0
  36. package/src/models/donut/image_processing_donut.js +31 -0
  37. package/src/models/dpt/image_processing_dpt.js +6 -0
  38. package/src/models/efficientnet/image_processing_efficientnet.js +13 -0
  39. package/src/models/feature_extractors.js +12 -0
  40. package/src/models/florence2/processing_florence2.js +128 -0
  41. package/src/models/glpn/image_processing_glpn.js +5 -0
  42. package/src/models/idefics3/image_processing_idefics3.js +219 -0
  43. package/src/models/idefics3/processing_idefics3.js +136 -0
  44. package/src/models/image_processors.js +37 -0
  45. package/src/models/janus/image_processing_janus.js +26 -0
  46. package/src/models/janus/processing_janus.js +123 -0
  47. package/src/models/jina_clip/image_processing_jina_clip.js +26 -0
  48. package/src/models/jina_clip/processing_jina_clip.js +24 -0
  49. package/src/models/llava_onevision/image_processing_llava_onevision.js +5 -0
  50. package/src/models/mask2former/image_processing_mask2former.js +5 -0
  51. package/src/models/maskformer/image_processing_maskformer.js +18 -0
  52. package/src/models/mgp_str/processing_mgp_str.js +170 -0
  53. package/src/models/mobilenet_v1/image_processing_mobilenet_v1.js +7 -0
  54. package/src/models/mobilenet_v2/image_processing_mobilenet_v2.js +7 -0
  55. package/src/models/mobilenet_v3/image_processing_mobilenet_v3.js +7 -0
  56. package/src/models/mobilenet_v4/image_processing_mobilenet_v4.js +7 -0
  57. package/src/models/mobilevit/image_processing_mobilevit.js +6 -0
  58. package/src/models/nougat/image_processing_nougat.js +5 -0
  59. package/src/models/owlv2/image_processing_owlv2.js +5 -0
  60. package/src/models/owlvit/image_processing_owlvit.js +12 -0
  61. package/src/models/owlvit/processing_owlvit.js +7 -0
  62. package/src/models/processors.js +12 -0
  63. package/src/models/pvt/image_processing_pvt.js +5 -0
  64. package/src/models/pyannote/feature_extraction_pyannote.js +28 -0
  65. package/src/models/pyannote/processing_pyannote.js +71 -0
  66. package/src/models/qwen2_vl/image_processing_qwen2_vl.js +52 -0
  67. package/src/models/qwen2_vl/processing_qwen2_vl.js +52 -0
  68. package/src/models/rt_detr/image_processing_rt_detr.js +12 -0
  69. package/src/models/sam/image_processing_sam.js +242 -0
  70. package/src/models/sam/processing_sam.js +20 -0
  71. package/src/models/sapiens/image_processing_sapiens.js +13 -0
  72. package/src/models/seamless_m4t/feature_extraction_seamless_m4t.js +180 -0
  73. package/src/models/segformer/image_processing_segformer.js +13 -0
  74. package/src/models/siglip/image_processing_siglip.js +5 -0
  75. package/src/models/speecht5/feature_extraction_speecht5.js +4 -0
  76. package/src/models/speecht5/processing_speecht5.js +17 -0
  77. package/src/models/swin2sr/image_processing_swin2sr.js +24 -0
  78. package/src/models/vit/image_processing_vit.js +7 -0
  79. package/src/models/vitmatte/image_processing_vitmatte.js +50 -0
  80. package/src/models/vitpose/image_processing_vitpose.js +89 -0
  81. package/src/models/wav2vec2/feature_extraction_wav2vec2.js +44 -0
  82. package/src/models/wav2vec2/processing_wav2vec2.js +15 -0
  83. package/src/models/wespeaker/feature_extraction_wespeaker.js +100 -0
  84. package/src/models/whisper/feature_extraction_whisper.js +84 -0
  85. package/src/models/whisper/processing_whisper.js +21 -0
  86. package/src/models/yolos/image_processing_yolos.js +12 -0
  87. package/src/models.js +755 -34
  88. package/src/pipelines.js +8 -8
  89. package/src/tokenizers.js +5 -0
  90. package/src/transformers.js +15 -2
  91. package/src/utils/constants.js +8 -1
  92. package/src/utils/core.js +51 -9
  93. package/src/utils/dtypes.js +2 -1
  94. package/src/utils/hub.js +2 -1
  95. package/src/utils/image.js +87 -33
  96. package/src/utils/tensor.js +39 -2
  97. package/types/base/feature_extraction_utils.d.ts +41 -0
  98. package/types/base/feature_extraction_utils.d.ts.map +1 -0
  99. package/types/base/image_processors_utils.d.ts +323 -0
  100. package/types/base/image_processors_utils.d.ts.map +1 -0
  101. package/types/base/processing_utils.d.ts +80 -0
  102. package/types/base/processing_utils.d.ts.map +1 -0
  103. package/types/configs.d.ts +5 -2
  104. package/types/configs.d.ts.map +1 -1
  105. package/types/env.d.ts +1 -1
  106. package/types/env.d.ts.map +1 -1
  107. package/types/generation/configuration_utils.d.ts +6 -0
  108. package/types/generation/configuration_utils.d.ts.map +1 -1
  109. package/types/generation/logits_process.d.ts +30 -20
  110. package/types/generation/logits_process.d.ts.map +1 -1
  111. package/types/generation/streamers.d.ts +13 -8
  112. package/types/generation/streamers.d.ts.map +1 -1
  113. package/types/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.d.ts +25 -0
  114. package/types/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.d.ts.map +1 -0
  115. package/types/models/auto/feature_extraction_auto.d.ts +5 -0
  116. package/types/models/auto/feature_extraction_auto.d.ts.map +1 -0
  117. package/types/models/auto/image_processing_auto.d.ts +5 -0
  118. package/types/models/auto/image_processing_auto.d.ts.map +1 -0
  119. package/types/models/auto/processing_auto.d.ts +35 -0
  120. package/types/models/auto/processing_auto.d.ts.map +1 -0
  121. package/types/models/beit/image_processing_beit.d.ts +4 -0
  122. package/types/models/beit/image_processing_beit.d.ts.map +1 -0
  123. package/types/models/bit/image_processing_bit.d.ts +4 -0
  124. package/types/models/bit/image_processing_bit.d.ts.map +1 -0
  125. package/types/models/chinese_clip/image_processing_chinese_clip.d.ts +4 -0
  126. package/types/models/chinese_clip/image_processing_chinese_clip.d.ts.map +1 -0
  127. package/types/models/clap/feature_extraction_clap.d.ts +57 -0
  128. package/types/models/clap/feature_extraction_clap.d.ts.map +1 -0
  129. package/types/models/clip/image_processing_clip.d.ts +6 -0
  130. package/types/models/clip/image_processing_clip.d.ts.map +1 -0
  131. package/types/models/convnext/image_processing_convnext.d.ts +12 -0
  132. package/types/models/convnext/image_processing_convnext.d.ts.map +1 -0
  133. package/types/models/deit/image_processing_deit.d.ts +6 -0
  134. package/types/models/deit/image_processing_deit.d.ts.map +1 -0
  135. package/types/models/detr/image_processing_detr.d.ts +42 -0
  136. package/types/models/detr/image_processing_detr.d.ts.map +1 -0
  137. package/types/models/donut/image_processing_donut.d.ts +7 -0
  138. package/types/models/donut/image_processing_donut.d.ts.map +1 -0
  139. package/types/models/dpt/image_processing_dpt.d.ts +6 -0
  140. package/types/models/dpt/image_processing_dpt.d.ts.map +1 -0
  141. package/types/models/efficientnet/image_processing_efficientnet.d.ts +6 -0
  142. package/types/models/efficientnet/image_processing_efficientnet.d.ts.map +1 -0
  143. package/types/models/feature_extractors.d.ts +10 -0
  144. package/types/models/feature_extractors.d.ts.map +1 -0
  145. package/types/models/florence2/processing_florence2.d.ts +39 -0
  146. package/types/models/florence2/processing_florence2.d.ts.map +1 -0
  147. package/types/models/glpn/image_processing_glpn.d.ts +4 -0
  148. package/types/models/glpn/image_processing_glpn.d.ts.map +1 -0
  149. package/types/models/idefics3/image_processing_idefics3.d.ts +40 -0
  150. package/types/models/idefics3/image_processing_idefics3.d.ts.map +1 -0
  151. package/types/models/idefics3/processing_idefics3.d.ts +19 -0
  152. package/types/models/idefics3/processing_idefics3.d.ts.map +1 -0
  153. package/types/models/image_processors.d.ts +37 -0
  154. package/types/models/image_processors.d.ts.map +1 -0
  155. package/types/models/janus/image_processing_janus.d.ts +7 -0
  156. package/types/models/janus/image_processing_janus.d.ts.map +1 -0
  157. package/types/models/janus/processing_janus.d.ts +77 -0
  158. package/types/models/janus/processing_janus.d.ts.map +1 -0
  159. package/types/models/jina_clip/image_processing_jina_clip.d.ts +5 -0
  160. package/types/models/jina_clip/image_processing_jina_clip.d.ts.map +1 -0
  161. package/types/models/jina_clip/processing_jina_clip.d.ts +9 -0
  162. package/types/models/jina_clip/processing_jina_clip.d.ts.map +1 -0
  163. package/types/models/llava_onevision/image_processing_llava_onevision.d.ts +4 -0
  164. package/types/models/llava_onevision/image_processing_llava_onevision.d.ts.map +1 -0
  165. package/types/models/mask2former/image_processing_mask2former.d.ts +4 -0
  166. package/types/models/mask2former/image_processing_mask2former.d.ts.map +1 -0
  167. package/types/models/maskformer/image_processing_maskformer.d.ts +22 -0
  168. package/types/models/maskformer/image_processing_maskformer.d.ts.map +1 -0
  169. package/types/models/mgp_str/processing_mgp_str.d.ts +64 -0
  170. package/types/models/mgp_str/processing_mgp_str.d.ts.map +1 -0
  171. package/types/models/mobilenet_v1/image_processing_mobilenet_v1.d.ts +6 -0
  172. package/types/models/mobilenet_v1/image_processing_mobilenet_v1.d.ts.map +1 -0
  173. package/types/models/mobilenet_v2/image_processing_mobilenet_v2.d.ts +6 -0
  174. package/types/models/mobilenet_v2/image_processing_mobilenet_v2.d.ts.map +1 -0
  175. package/types/models/mobilenet_v3/image_processing_mobilenet_v3.d.ts +6 -0
  176. package/types/models/mobilenet_v3/image_processing_mobilenet_v3.d.ts.map +1 -0
  177. package/types/models/mobilenet_v4/image_processing_mobilenet_v4.d.ts +6 -0
  178. package/types/models/mobilenet_v4/image_processing_mobilenet_v4.d.ts.map +1 -0
  179. package/types/models/mobilevit/image_processing_mobilevit.d.ts +6 -0
  180. package/types/models/mobilevit/image_processing_mobilevit.d.ts.map +1 -0
  181. package/types/models/nougat/image_processing_nougat.d.ts +4 -0
  182. package/types/models/nougat/image_processing_nougat.d.ts.map +1 -0
  183. package/types/models/owlv2/image_processing_owlv2.d.ts +4 -0
  184. package/types/models/owlv2/image_processing_owlv2.d.ts.map +1 -0
  185. package/types/models/owlvit/image_processing_owlvit.d.ts +10 -0
  186. package/types/models/owlvit/image_processing_owlvit.d.ts.map +1 -0
  187. package/types/models/owlvit/processing_owlvit.d.ts +8 -0
  188. package/types/models/owlvit/processing_owlvit.d.ts.map +1 -0
  189. package/types/models/processors.d.ts +13 -0
  190. package/types/models/processors.d.ts.map +1 -0
  191. package/types/models/pvt/image_processing_pvt.d.ts +4 -0
  192. package/types/models/pvt/image_processing_pvt.d.ts.map +1 -0
  193. package/types/models/pyannote/feature_extraction_pyannote.d.ts +13 -0
  194. package/types/models/pyannote/feature_extraction_pyannote.d.ts.map +1 -0
  195. package/types/models/pyannote/processing_pyannote.d.ts +30 -0
  196. package/types/models/pyannote/processing_pyannote.d.ts.map +1 -0
  197. package/types/models/qwen2_vl/image_processing_qwen2_vl.d.ts +11 -0
  198. package/types/models/qwen2_vl/image_processing_qwen2_vl.d.ts.map +1 -0
  199. package/types/models/qwen2_vl/processing_qwen2_vl.d.ts +17 -0
  200. package/types/models/qwen2_vl/processing_qwen2_vl.d.ts.map +1 -0
  201. package/types/models/rt_detr/image_processing_rt_detr.d.ts +8 -0
  202. package/types/models/rt_detr/image_processing_rt_detr.d.ts.map +1 -0
  203. package/types/models/sam/image_processing_sam.d.ts +103 -0
  204. package/types/models/sam/image_processing_sam.d.ts.map +1 -0
  205. package/types/models/sam/processing_sam.d.ts +9 -0
  206. package/types/models/sam/processing_sam.d.ts.map +1 -0
  207. package/types/models/seamless_m4t/feature_extraction_seamless_m4t.d.ts +34 -0
  208. package/types/models/seamless_m4t/feature_extraction_seamless_m4t.d.ts.map +1 -0
  209. package/types/models/segformer/image_processing_segformer.d.ts +10 -0
  210. package/types/models/segformer/image_processing_segformer.d.ts.map +1 -0
  211. package/types/models/siglip/image_processing_siglip.d.ts +4 -0
  212. package/types/models/siglip/image_processing_siglip.d.ts.map +1 -0
  213. package/types/models/speecht5/feature_extraction_speecht5.d.ts +4 -0
  214. package/types/models/speecht5/feature_extraction_speecht5.d.ts.map +1 -0
  215. package/types/models/speecht5/processing_speecht5.d.ts +14 -0
  216. package/types/models/speecht5/processing_speecht5.d.ts.map +1 -0
  217. package/types/models/swin2sr/image_processing_swin2sr.d.ts +5 -0
  218. package/types/models/swin2sr/image_processing_swin2sr.d.ts.map +1 -0
  219. package/types/models/vit/image_processing_vit.d.ts +6 -0
  220. package/types/models/vit/image_processing_vit.d.ts.map +1 -0
  221. package/types/models/vitmatte/image_processing_vitmatte.d.ts +12 -0
  222. package/types/models/vitmatte/image_processing_vitmatte.d.ts.map +1 -0
  223. package/types/models/vitpose/image_processing_vitpose.d.ts +26 -0
  224. package/types/models/vitpose/image_processing_vitpose.d.ts.map +1 -0
  225. package/types/models/wav2vec2/feature_extraction_wav2vec2.d.ts +19 -0
  226. package/types/models/wav2vec2/feature_extraction_wav2vec2.d.ts.map +1 -0
  227. package/types/models/wav2vec2/processing_wav2vec2.d.ts +12 -0
  228. package/types/models/wav2vec2/processing_wav2vec2.d.ts.map +1 -0
  229. package/types/models/wespeaker/feature_extraction_wespeaker.d.ts +23 -0
  230. package/types/models/wespeaker/feature_extraction_wespeaker.d.ts.map +1 -0
  231. package/types/models/whisper/feature_extraction_whisper.d.ts +21 -0
  232. package/types/models/whisper/feature_extraction_whisper.d.ts.map +1 -0
  233. package/types/models/whisper/processing_whisper.d.ts +17 -0
  234. package/types/models/whisper/processing_whisper.d.ts.map +1 -0
  235. package/types/models/yolos/image_processing_yolos.d.ts +10 -0
  236. package/types/models/yolos/image_processing_yolos.d.ts.map +1 -0
  237. package/types/models.d.ts +150 -0
  238. package/types/models.d.ts.map +1 -1
  239. package/types/pipelines.d.ts +2 -3
  240. package/types/pipelines.d.ts.map +1 -1
  241. package/types/tokenizers.d.ts +3 -0
  242. package/types/tokenizers.d.ts.map +1 -1
  243. package/types/transformers.d.ts +10 -1
  244. package/types/utils/constants.d.ts +6 -0
  245. package/types/utils/constants.d.ts.map +1 -1
  246. package/types/utils/core.d.ts +65 -3
  247. package/types/utils/core.d.ts.map +1 -1
  248. package/types/utils/dtypes.d.ts +3 -2
  249. package/types/utils/dtypes.d.ts.map +1 -1
  250. package/types/utils/hub.d.ts +1 -1
  251. package/types/utils/hub.d.ts.map +1 -1
  252. package/types/utils/image.d.ts +14 -2
  253. package/types/utils/image.d.ts.map +1 -1
  254. package/types/utils/tensor.d.ts +39 -4
  255. package/types/utils/tensor.d.ts.map +1 -1
  256. package/src/processors.js +0 -2655
  257. package/types/processors.d.ts +0 -924
  258. package/types/processors.d.ts.map +0 -1
package/src/models.js CHANGED
@@ -61,7 +61,6 @@ import {
61
61
  } from './utils/generic.js';
62
62
 
63
63
  import {
64
- isIntegralNumber,
65
64
  mergeArrays,
66
65
  pick,
67
66
  } from './utils/core.js';
@@ -99,17 +98,20 @@ import {
99
98
 
100
99
  import {
101
100
  cat,
102
- full_like,
103
101
  mean,
102
+ zeros,
103
+ zeros_like,
104
104
  ones,
105
105
  ones_like,
106
+ full,
107
+ full_like,
106
108
  stack,
107
109
  std_mean,
108
110
  Tensor,
109
- zeros_like,
110
111
  } from './utils/tensor.js';
112
+ import { RawImage } from './utils/image.js';
111
113
 
112
- import { dynamic_time_warping, medianFilter } from './utils/maths.js';
114
+ import { dynamic_time_warping, max, medianFilter } from './utils/maths.js';
113
115
  import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
114
116
  import { LogitsSampler } from './generation/logits_sampler.js';
115
117
  import { apis } from './env.js';
@@ -128,6 +130,7 @@ const MODEL_TYPES = {
128
130
  MaskGeneration: 5,
129
131
  ImageTextToText: 6,
130
132
  Musicgen: 7,
133
+ MultiModality: 8,
131
134
  }
132
135
  //////////////////////////////////////////////////
133
136
 
@@ -179,6 +182,22 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
179
182
  }
180
183
  }
181
184
 
185
+ if (dtype === DATA_TYPES.auto) {
186
+ // Try to choose the auto dtype based on the custom config
187
+ let config_dtype = custom_config.dtype;
188
+ if (typeof config_dtype !== 'string') {
189
+ config_dtype = config_dtype[fileName];
190
+ }
191
+
192
+ if (config_dtype && config_dtype !== DATA_TYPES.auto && DATA_TYPES.hasOwnProperty(config_dtype)) {
193
+ // Defined by the custom config, and is not "auto"
194
+ dtype = config_dtype;
195
+ } else {
196
+ // Choose default dtype based on device, falling back to fp32
197
+ dtype = DEFAULT_DEVICE_DTYPE_MAPPING[selectedDevice] ?? DATA_TYPES.fp32;
198
+ }
199
+ }
200
+
182
201
  const selectedDtype = /** @type {import("./utils/dtypes.js").DataType} */(dtype);
183
202
 
184
203
  if (!DEFAULT_DTYPE_SUFFIX_MAPPING.hasOwnProperty(selectedDtype)) {
@@ -384,9 +403,17 @@ async function sessionRun(session, inputs) {
384
403
  output = replaceTensors(output);
385
404
  return output;
386
405
  } catch (e) {
406
+ // Error messages can be long (nested) and uninformative. For this reason,
407
+ // we apply minor formatting to show the most important information
408
+ const formatted = Object.fromEntries(Object.entries(checkedInputs)
409
+ .map(([k, { type, dims, data }]) => [k, {
410
+ // Extract these properties from the underlying ORT tensor
411
+ type, dims, data,
412
+ }]));
413
+
387
414
  // This usually occurs when the inputs are of the wrong type.
388
415
  console.error(`An error occurred during model execution: "${e}".`);
389
- console.error('Inputs given to model:', checkedInputs);
416
+ console.error('Inputs given to model:', formatted);
390
417
  throw e;
391
418
  }
392
419
  }
@@ -543,6 +570,39 @@ async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
543
570
  }
544
571
 
545
572
 
573
+
574
+ function default_merge_input_ids_with_image_features({
575
+ image_token_id,
576
+ inputs_embeds,
577
+ image_features,
578
+ input_ids,
579
+ attention_mask,
580
+ }) {
581
+ const image_tokens = input_ids.tolist().map(ids =>
582
+ ids.reduce((acc, x, idx) => {
583
+ if (x == image_token_id) acc.push(idx);
584
+ return acc;
585
+ }, [])
586
+ );
587
+ const n_image_tokens = image_tokens.reduce((acc, x) => acc + x.length, 0);
588
+ const n_image_features = image_features.dims[0];
589
+ if (n_image_tokens !== n_image_features) {
590
+ throw new Error(`Image features and image tokens do not match: tokens: ${n_image_tokens}, features ${n_image_features}`);
591
+ }
592
+
593
+ // Equivalent to performing a masked_scatter
594
+ let img = 0;
595
+ for (let i = 0; i < image_tokens.length; ++i) {
596
+ const tokens = image_tokens[i];
597
+ const embeds = inputs_embeds[i];
598
+ for (let j = 0; j < tokens.length; ++j) {
599
+ embeds[tokens[j]].data.set(image_features[img++].data)
600
+ }
601
+ }
602
+ return { inputs_embeds, attention_mask }
603
+ }
604
+
605
+
546
606
  /**
547
607
  * Forward pass of an image-text-to-text model.
548
608
  * @param {Object} self The image-text-to-text model.
@@ -579,11 +639,11 @@ async function imageTextToTextForward(self, {
579
639
 
580
640
  if (!inputs_embeds) {
581
641
  // 1. Extract the input embeddings
582
- inputs_embeds = await self.encode_text({ input_ids });
642
+ inputs_embeds = await self.encode_text({ input_ids, ...kwargs });
583
643
 
584
644
  // 2. Possibly, merge text and images
585
645
  if (pixel_values && input_ids.dims[1] !== 1) {
586
- const image_features = await self.encode_image({ pixel_values });
646
+ const image_features = await self.encode_image({ pixel_values, ...kwargs });
587
647
 
588
648
  ({ inputs_embeds, attention_mask } = self._merge_input_ids_with_image_features({
589
649
  image_features,
@@ -604,6 +664,16 @@ async function imageTextToTextForward(self, {
604
664
  }
605
665
  }
606
666
 
667
+ if (!position_ids) {
668
+
669
+ if (self.config.model_type === 'qwen2_vl') {
670
+ // Special case for qwen2_vl models
671
+ // @ts-ignore
672
+ const { image_grid_thw, video_grid_thw } = kwargs;
673
+ [position_ids] = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
674
+ }
675
+ }
676
+
607
677
  const outputs = await decoderForward(self, {
608
678
  inputs_embeds,
609
679
  past_key_values,
@@ -615,34 +685,54 @@ async function imageTextToTextForward(self, {
615
685
  return outputs;
616
686
  }
617
687
 
618
- function createPositionIds(model_inputs, past_key_values = null) {
619
- // If the model supports providing position_ids, we create position_ids on the fly for batch generation,
620
- // by computing the cumulative sum of the attention mask along the sequence length dimension.
621
- //
622
- // Equivalent to:
623
- // position_ids = attention_mask.long().cumsum(-1) - 1
624
- // position_ids.masked_fill_(attention_mask == 0, 1)
625
- // if past_key_values:
626
- // position_ids = position_ids[:, -input_ids.shape[1] :]
627
- const { input_ids, inputs_embeds, attention_mask } = model_inputs;
688
+ /**
689
+ * Helper function to perform the following:
690
+ * ```python
691
+ * x = attention_mask.long().cumsum(-1) - 1
692
+ * x.masked_fill_(attention_mask == 0, 1)
693
+ * ```
694
+ * @param {Tensor} attention_mask
695
+ * @returns {{data: BigInt64Array, dims: number[]}}
696
+ */
697
+ function cumsum_masked_fill(attention_mask) {
628
698
  const [bz, seq_len] = attention_mask.dims;
699
+ const attn_mask_data = attention_mask.data;
629
700
 
630
- const data = new BigInt64Array(attention_mask.data.length);
701
+ const data = new BigInt64Array(attn_mask_data.length);
631
702
  for (let i = 0; i < bz; ++i) {
632
703
  const start = i * seq_len;
633
704
  let sum = BigInt(0);
634
705
  for (let j = 0; j < seq_len; ++j) {
635
706
  const index = start + j;
636
- if (attention_mask.data[index] === 0n) {
707
+ if (attn_mask_data[index] === 0n) {
637
708
  data[index] = BigInt(1);
638
709
  } else { // === 1n
639
710
  data[index] = sum;
640
- sum += attention_mask.data[index];
711
+ sum += attn_mask_data[index];
641
712
  }
642
713
  }
643
714
  }
715
+ return { data, dims: attention_mask.dims };
716
+
717
+ }
718
+
719
+ /**
720
+ * If the model supports providing position_ids, we create position_ids on the fly for batch generation,
721
+ * by computing the cumulative sum of the attention mask along the sequence length dimension.
722
+ *
723
+ * Equivalent to:
724
+ * ```python
725
+ * position_ids = attention_mask.long().cumsum(-1) - 1
726
+ * position_ids.masked_fill_(attention_mask == 0, 1)
727
+ * if past_key_values:
728
+ * position_ids = position_ids[:, -input_ids.shape[1] :]
729
+ * ```
730
+ */
731
+ function createPositionIds(model_inputs, past_key_values = null) {
732
+ const { input_ids, inputs_embeds, attention_mask } = model_inputs;
644
733
 
645
- let position_ids = new Tensor('int64', data, attention_mask.dims);
734
+ const { data, dims } = cumsum_masked_fill(attention_mask);
735
+ let position_ids = new Tensor('int64', data, dims);
646
736
  if (past_key_values) {
647
737
  const offset = -(input_ids ?? inputs_embeds).dims.at(1);
648
738
  position_ids = position_ids.slice(null, [offset, null]);
@@ -716,6 +806,52 @@ function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
716
806
  }
717
807
  }
718
808
 
809
+ function multimodality_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
810
+ const has_past_key_values = !!model_inputs.past_key_values;
811
+
812
+ if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
813
+ if (has_past_key_values) {
814
+ model_inputs.input_ids = cat([
815
+ model_inputs.input_ids,
816
+ model_inputs.input_ids,
817
+ ], 0)
818
+ // NOTE: attention_mask handled in generation
819
+ } else {
820
+ model_inputs.input_ids = cat([
821
+ model_inputs.input_ids,
822
+ full_like(model_inputs.input_ids, BigInt(generation_config.pad_token_id)),
823
+ ], 0);
824
+ model_inputs.attention_mask = cat([
825
+ model_inputs.attention_mask,
826
+ full_like(model_inputs.attention_mask, 0n),
827
+ ], 0);
828
+ }
829
+ }
830
+
831
+ if (has_past_key_values || !model_inputs.pixel_values) {
832
+ model_inputs.pixel_values = full([0, 0, 3, 384, 384], 1.0);
833
+ }
834
+
835
+ if (has_past_key_values) {
836
+ const num_img_tokens = 0;
837
+ const num_text_tokens = 1;
838
+ const has_image = num_img_tokens > 0 ? 1 : 0;
839
+
840
+ const batch_size = 1;
841
+ model_inputs.images_seq_mask = new Tensor(
842
+ 'bool',
843
+ new Array(num_img_tokens + num_text_tokens).fill(true).fill(false, 0, num_text_tokens),
844
+ [batch_size, num_img_tokens + num_text_tokens],
845
+ );
846
+ model_inputs.images_emb_mask = new Tensor(
847
+ 'bool',
848
+ new Array(num_img_tokens).fill(!!has_image),
849
+ [batch_size, 1, num_img_tokens],
850
+ );
851
+ }
852
+ return model_inputs;
853
+ }
854
+
719
855
  //////////////////////////////////////////////////
720
856
 
721
857
  //////////////////////////////////////////////////
@@ -769,6 +905,11 @@ export class PreTrainedModel extends Callable {
769
905
  this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
770
906
  break;
771
907
 
908
+ case MODEL_TYPES.MultiModality:
909
+ this.can_generate = true;
910
+ this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation;
911
+ break;
912
+
772
913
  default:
773
914
  // should be MODEL_TYPES.EncoderOnly
774
915
  this._forward = encoderForward;
@@ -912,9 +1053,27 @@ export class PreTrainedModel extends Callable {
912
1053
  }, options),
913
1054
  ]);
914
1055
 
1056
+ } else if (modelType === MODEL_TYPES.MultiModality) {
1057
+ info = await Promise.all([
1058
+ constructSessions(pretrained_model_name_or_path, {
1059
+ prepare_inputs_embeds: 'prepare_inputs_embeds',
1060
+ model: 'language_model',
1061
+ lm_head: 'lm_head',
1062
+ gen_head: 'gen_head',
1063
+ gen_img_embeds: 'gen_img_embeds',
1064
+ image_decode: 'image_decode',
1065
+ }, options),
1066
+ getOptionalConfigs(pretrained_model_name_or_path, {
1067
+ generation_config: 'generation_config.json',
1068
+ }, options),
1069
+ ]);
1070
+
915
1071
  } else { // should be MODEL_TYPES.EncoderOnly
916
1072
  if (modelType !== MODEL_TYPES.EncoderOnly) {
917
- console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
1073
+ const type = modelName ?? config?.model_type;
1074
+ if (type !== 'custom') {
1075
+ console.warn(`Model type for '${type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
1076
+ }
918
1077
  }
919
1078
  info = await Promise.all([
920
1079
  constructSessions(pretrained_model_name_or_path, {
@@ -1658,7 +1817,8 @@ export class PreTrainedModel extends Callable {
1658
1817
  const dtype = session?.config?.kv_cache_dtype ?? 'float32';
1659
1818
  const empty = (dtype === 'float16') ? new Uint16Array() : [];
1660
1819
 
1661
- const shapes = getKeyValueShapes(this.config);
1820
+ const batch_size = (decoderFeeds[this.main_input_name] ?? decoderFeeds.attention_mask)?.dims?.[0] ?? 1;
1821
+ const shapes = getKeyValueShapes(this.config, { batch_size });
1662
1822
 
1663
1823
  for (const name in shapes) {
1664
1824
  decoderFeeds[name] = new Tensor(dtype, empty, shapes[name]);
@@ -3204,8 +3364,8 @@ export class VisionEncoderDecoderModel extends PreTrainedModel {
3204
3364
  export class LlavaPreTrainedModel extends PreTrainedModel {
3205
3365
  forward_params = [
3206
3366
  'input_ids',
3207
- 'pixel_values',
3208
3367
  'attention_mask',
3368
+ 'pixel_values',
3209
3369
  'position_ids',
3210
3370
  'past_key_values',
3211
3371
  ];
@@ -3277,6 +3437,7 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel {
3277
3437
  }
3278
3438
  //////////////////////////////////////////////////
3279
3439
 
3440
+ export class LlavaOnevisionForConditionalGeneration extends LlavaForConditionalGeneration { } // NOTE: extends LlavaForConditionalGeneration
3280
3441
  export class Moondream1ForConditionalGeneration extends LlavaForConditionalGeneration { } // NOTE: extends LlavaForConditionalGeneration
3281
3442
 
3282
3443
  export class Florence2PreTrainedModel extends PreTrainedModel {
@@ -3386,6 +3547,46 @@ export class Florence2ForConditionalGeneration extends Florence2PreTrainedModel
3386
3547
  return decoder_outputs;
3387
3548
  }
3388
3549
  }
3550
+
3551
+
3552
+ //////////////////////////////////////////////////
3553
+ // Idefics3 Models
3554
+ export class Idefics3PreTrainedModel extends PreTrainedModel {
3555
+ forward_params = [
3556
+ 'input_ids',
3557
+ 'attention_mask',
3558
+ 'pixel_values',
3559
+ 'pixel_attention_mask',
3560
+ 'position_ids',
3561
+ 'past_key_values',
3562
+ ];
3563
+ }
3564
+
3565
+ /**
3566
+ * The Idefics3 model which consists of a vision backbone and a language model.
3567
+ */
3568
+ export class Idefics3ForConditionalGeneration extends Idefics3PreTrainedModel {
3569
+
3570
+ async encode_image({ pixel_values, pixel_attention_mask }) {
3571
+ const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values, pixel_attention_mask })).image_features;
3572
+ return features;
3573
+ }
3574
+
3575
+ _merge_input_ids_with_image_features(kwargs) {
3576
+ const vision_hidden_size = kwargs.image_features.dims.at(-1);
3577
+ const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);
3578
+
3579
+ return default_merge_input_ids_with_image_features({
3580
+ // @ts-ignore
3581
+ image_token_id: this.config.image_token_id,
3582
+ ...kwargs,
3583
+ image_features: reshaped_image_hidden_states,
3584
+ })
3585
+ }
3586
+ }
3587
+ //////////////////////////////////////////////////
3588
+
3589
+ //////////////////////////////////////////////////
3389
3590
  export class CLIPPreTrainedModel extends PreTrainedModel { }
3390
3591
 
3391
3592
  /**
@@ -3437,7 +3638,7 @@ export class CLIPModel extends CLIPPreTrainedModel { }
3437
3638
  * The text model from CLIP without any head or projection on top.
3438
3639
  */
3439
3640
  export class CLIPTextModel extends CLIPPreTrainedModel {
3440
- /** @type {PreTrainedModel.from_pretrained} */
3641
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3441
3642
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3442
3643
  // Update default model file name if not provided
3443
3644
  options.model_file_name ??= 'text_model';
@@ -3472,7 +3673,7 @@ export class CLIPTextModel extends CLIPPreTrainedModel {
3472
3673
  * ```
3473
3674
  */
3474
3675
  export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
3475
- /** @type {PreTrainedModel.from_pretrained} */
3676
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3476
3677
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3477
3678
  // Update default model file name if not provided
3478
3679
  options.model_file_name ??= 'text_model';
@@ -3484,7 +3685,7 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
3484
3685
  * The vision model from CLIP without any head or projection on top.
3485
3686
  */
3486
3687
  export class CLIPVisionModel extends CLIPPreTrainedModel {
3487
- /** @type {PreTrainedModel.from_pretrained} */
3688
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3488
3689
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3489
3690
  // Update default model file name if not provided
3490
3691
  options.model_file_name ??= 'vision_model';
@@ -3519,7 +3720,7 @@ export class CLIPVisionModel extends CLIPPreTrainedModel {
3519
3720
  * ```
3520
3721
  */
3521
3722
  export class CLIPVisionModelWithProjection extends CLIPPreTrainedModel {
3522
- /** @type {PreTrainedModel.from_pretrained} */
3723
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3523
3724
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3524
3725
  // Update default model file name if not provided
3525
3726
  options.model_file_name ??= 'vision_model';
@@ -3605,8 +3806,7 @@ export class SiglipModel extends SiglipPreTrainedModel { }
3605
3806
  * ```
3606
3807
  */
3607
3808
  export class SiglipTextModel extends SiglipPreTrainedModel {
3608
-
3609
- /** @type {PreTrainedModel.from_pretrained} */
3809
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3610
3810
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3611
3811
  // Update default model file name if not provided
3612
3812
  options.model_file_name ??= 'text_model';
@@ -3641,7 +3841,7 @@ export class SiglipTextModel extends SiglipPreTrainedModel {
3641
3841
  * ```
3642
3842
  */
3643
3843
  export class SiglipVisionModel extends CLIPPreTrainedModel {
3644
- /** @type {PreTrainedModel.from_pretrained} */
3844
+ /** @type {typeof PreTrainedModel.from_pretrained} */
3645
3845
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3646
3846
  // Update default model file name if not provided
3647
3847
  options.model_file_name ??= 'vision_model';
@@ -3655,6 +3855,67 @@ export class ChineseCLIPPreTrainedModel extends PreTrainedModel { }
3655
3855
  export class ChineseCLIPModel extends ChineseCLIPPreTrainedModel { }
3656
3856
  //////////////////////////////////////////////////
3657
3857
 
3858
//////////////////////////////////////////////////
// JinaCLIP models

// Common base class for the JinaCLIP model family below.
export class JinaCLIPPreTrainedModel extends PreTrainedModel { }
3861
+
3862
export class JinaCLIPModel extends JinaCLIPPreTrainedModel {
    /**
     * Runs the joint text/image model. At least one of `input_ids` or
     * `pixel_values` must be provided; a dummy tensor is substituted for the
     * missing modality (the ONNX graph requires every input), and that
     * modality's embeddings are omitted from the returned object.
     * @param {any} model_inputs may contain `input_ids` and/or `pixel_values`.
     * @returns {Promise<any>} object with `text_embeddings`/`l2norm_text_embeddings`
     *   and/or `image_embeddings`/`l2norm_image_embeddings`, depending on inputs.
     * @throws {Error} if neither `input_ids` nor `pixel_values` is provided.
     */
    async forward(model_inputs) {
        const has_text = !!model_inputs.input_ids;
        const has_images = !!model_inputs.pixel_values;

        if (!has_text && !has_images) {
            throw new Error('Either `input_ids` or `pixel_values` should be provided.');
        }

        if (!has_text) {
            // NOTE: We cannot pass zero-dimension tensor as input for input_ids.
            // Fortunately, the majority of time is spent in the vision encoder, so this shouldn't significantly impact performance.
            model_inputs.input_ids = ones([model_inputs.pixel_values.dims[0], 1]);
        }

        if (!has_images) {
            // NOTE: Since we create a zero-sized tensor, this does not increase computation time.
            // @ts-ignore
            const { image_size } = this.config.vision_config;
            model_inputs.pixel_values = full([0, 3, image_size, image_size], 0.0); // (pass zero-dimension tensor)
        }

        const outputs = await super.forward(model_inputs);

        // Only surface the embeddings for modalities the caller actually supplied.
        const result = {};
        if (has_text) {
            result.text_embeddings = outputs.text_embeddings;
            result.l2norm_text_embeddings = outputs.l2norm_text_embeddings;
        }
        if (has_images) {
            result.image_embeddings = outputs.image_embeddings;
            result.l2norm_image_embeddings = outputs.l2norm_image_embeddings;
        }
        return result;
    }
}
3899
+
3900
export class JinaCLIPTextModel extends JinaCLIPPreTrainedModel {
    /** @type {typeof PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Default to the text-only export unless the caller already chose a model file.
        if (options.model_file_name == null) {
            options.model_file_name = 'text_model';
        }
        return super.from_pretrained(pretrained_model_name_or_path, options);
    }
}
3908
+
3909
export class JinaCLIPVisionModel extends JinaCLIPPreTrainedModel {
    /** @type {typeof PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Default to the vision-only export unless the caller already chose a model file.
        if (options.model_file_name == null) {
            options.model_file_name = 'vision_model';
        }
        return super.from_pretrained(pretrained_model_name_or_path, options);
    }
}
3917
+ //////////////////////////////////////////////////
3918
+
3658
3919
 
3659
3920
  //////////////////////////////////////////////////
3660
3921
  // CLIPSeg models
@@ -3898,6 +4159,261 @@ export class Qwen2Model extends Qwen2PreTrainedModel { }
3898
4159
  export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { }
3899
4160
  //////////////////////////////////////////////////
3900
4161
 
4162
/**
 * Base class for Qwen2-VL models. `forward_params` declares the named inputs
 * the underlying ONNX session(s) accept for the forward pass.
 */
export class Qwen2VLPreTrainedModel extends PreTrainedModel {
    forward_params = [
        // Text inputs
        'input_ids',
        'attention_mask',
        'position_ids',
        'past_key_values',

        // Vision inputs
        'pixel_values',
        'image_grid_thw',
    ];
}
4175
export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {

    /**
     * Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
     *
     * Explanation:
     * Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
     *
     * For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
     * Examples:
     *     input_ids: [T T T T T], here T is for text.
     *     temporal position_ids: [0, 1, 2, 3, 4]
     *     height position_ids: [0, 1, 2, 3, 4]
     *     width position_ids: [0, 1, 2, 3, 4]
     *
     * For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
     * and 1D rotary position embedding for text part.
     * Examples:
     *     Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
     *     input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
     *     vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
     *     vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
     *     vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
     *     text temporal position_ids: [3, 4, 5, 6, 7]
     *     text height position_ids: [3, 4, 5, 6, 7]
     *     text width position_ids: [3, 4, 5, 6, 7]
     *     Here we calculate the text start position_ids as the max vision position_ids plus 1.
     *
     * @param {Tensor} input_ids Indices of input sequence tokens in the vocabulary. Tensor of shape `(batch_size, sequence_length)`.
     * @param {Tensor} image_grid_thw (Optional) The temporal, height and width of feature shape of each image in LLM. Tensor of shape `(num_images, 3)`.
     * @param {Tensor} video_grid_thw (Optional) The temporal, height and width of feature shape of each video in LLM. Tensor of shape `(num_videos, 3)`.
     * @param {Tensor} attention_mask (Optional) Mask to avoid performing attention on padding token indices. Tensor of shape `(batch_size, sequence_length)`. Mask values selected in `[0, 1]`:
     * - 1 for tokens that are **not masked**,
     * - 0 for tokens that are **masked**.
     * @returns {[Tensor, Tensor]} [position_ids, mrope_position_deltas] with:
     * - position_ids: Tensor of shape `(3, batch_size, sequence_length)`.
     * - mrope_position_deltas: Tensor of shape `(batch_size)`.
     */
    get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask) {
        // @ts-ignore
        const { vision_config, image_token_id, video_token_id, vision_start_token_id } = this.config;
        const spatial_merge_size = vision_config.spatial_merge_size ?? 2;

        const mrope_position_deltas = [];
        if (image_grid_thw || video_grid_thw) {
            // Vision + text path: build per-batch 3D (temporal/height/width) position ids.
            let total_input_ids = input_ids.tolist();
            if (!attention_mask) {
                attention_mask = ones_like(input_ids);
            }

            const attention_mask_list = attention_mask.tolist();
            // position_ids_list[axis][batch][token]; initialized to 1 (padding positions keep this value).
            const position_ids_list = Array.from({ length: 3 }, _ => Array.from({ length: input_ids.dims[0] }, _ => Array.from({ length: input_ids.dims[1] }, _ => 1)));

            const image_grid_thw_list = image_grid_thw ? image_grid_thw.tolist() : [];
            const video_grid_thw_list = video_grid_thw ? video_grid_thw.tolist() : [];

            let image_index = 0;
            let video_index = 0;
            for (let i = 0; i < total_input_ids.length; ++i) {
                // Keep only the unmasked tokens for this batch item.
                const ids = total_input_ids[i].filter((_, j) => attention_mask_list[i][j] == 1);

                // Positions of vision-start markers; the token immediately after each marker
                // tells us whether an image or video block follows.
                const vision_start_indices = ids.reduce((acc, x, idx) => {
                    if (x == vision_start_token_id) acc.push(idx);
                    return acc;
                }, []);

                const vision_tokens = vision_start_indices.map(x => ids[x + 1]);
                const image_nums = vision_tokens.filter(x => x == image_token_id).length;
                const video_nums = vision_tokens.filter(x => x == video_token_id).length;

                let llm_pos_ids_list = [];
                let st = 0;
                let remain_images = image_nums;
                let remain_videos = video_nums;
                // Walk the sequence, alternating text spans and vision grids.
                for (let j = 0; j < vision_tokens.length; ++j) {
                    const next_image_token = ids.findIndex((x, i) => i > st && x == image_token_id);
                    const next_video_token = ids.findIndex((x, i) => i > st && x == video_token_id);

                    // `ids.length + 1` acts as a sentinel "not found / exhausted" end position.
                    const ed_image = (remain_images > 0 && next_image_token !== -1)
                        ? next_image_token
                        : ids.length + 1;

                    const ed_video = (remain_videos > 0 && next_video_token !== -1)
                        ? next_video_token
                        : ids.length + 1;

                    let ed;
                    let t, h, w;
                    if (ed_image < ed_video) {
                        ([t, h, w] = image_grid_thw_list[image_index]);
                        ++image_index;
                        --remain_images;
                        ed = ed_image;
                    } else {
                        ([t, h, w] = video_grid_thw_list[video_index]);
                        ++video_index;
                        --remain_videos;
                        ed = ed_video;
                    }

                    // Patch grids are downsampled spatially by `spatial_merge_size` in each direction.
                    const [llm_grid_t, llm_grid_h, llm_grid_w] = [
                        Number(t),
                        Math.floor(Number(h) / spatial_merge_size),
                        Math.floor(Number(w) / spatial_merge_size)
                    ]
                    const text_len = ed - st;
                    // Text positions continue 1D from the max position emitted so far.
                    const st_idx = llm_pos_ids_list.length > 0
                        ? max(llm_pos_ids_list.at(-1))[0] + 1
                        : 0;

                    // Text span: same position on all 3 axes (flattened as 3 * text_len).
                    llm_pos_ids_list.push(
                        Array.from({ length: 3 * text_len }, (_, i) => st_idx + (i % text_len))
                    )

                    // Vision span: separate temporal/height/width indices, each offset past the text.
                    const offset = text_len + st_idx;
                    const grid_size = llm_grid_t * llm_grid_h * llm_grid_w;
                    const t_index = Array.from({ length: grid_size }, (_, i) => offset + Math.floor(i / (llm_grid_h * llm_grid_w)))
                    const h_index = Array.from({ length: grid_size }, (_, i) => offset + Math.floor(i / llm_grid_w) % llm_grid_h)
                    const w_index = Array.from({ length: grid_size }, (_, i) => offset + i % llm_grid_w)

                    llm_pos_ids_list.push([t_index, h_index, w_index].flat())

                    st = ed + grid_size;
                }

                // Trailing text after the last vision block.
                if (st < ids.length) {
                    const st_idx = llm_pos_ids_list.length > 0
                        ? max(llm_pos_ids_list.at(-1))[0] + 1
                        : 0;
                    const text_len = ids.length - st;

                    llm_pos_ids_list.push(
                        Array.from({ length: 3 * text_len }, (_, i) => (st_idx + (i % text_len)))
                    )
                }

                // NOTE: Each item in llm_pos_ids_list is an array of shape (3, text_len),
                // meaning to perform concatenation along dim=1, we can do the following:
                const num_items = llm_pos_ids_list.reduce((acc, x) => acc + x.length, 0);
                const llm_positions = new Array(num_items);
                let index = 0;
                for (let x = 0; x < 3; ++x) {
                    for (let y = 0; y < llm_pos_ids_list.length; ++y) {
                        const val = llm_pos_ids_list[y];
                        const text_len = val.length / 3;
                        for (let z = x * text_len; z < (x + 1) * text_len; ++z) {
                            llm_positions[index++] = val[z];
                        }
                    }
                }

                // Scatter the computed positions back into the (padded) position grid,
                // skipping masked tokens.
                let count = 0;
                const attn_mask = attention_mask_list[i];
                for (let y = 0; y < attn_mask.length; ++y) {
                    if (attn_mask[y] == 1) {
                        for (let x = 0; x < 3; ++x) {
                            position_ids_list[x][i][y] = llm_positions[x * num_items / 3 + count];
                        }
                        ++count;
                    }
                }

                const max_llm_positions = max(llm_positions)[0];
                mrope_position_deltas.push(max_llm_positions + 1 - total_input_ids[i].length);
            }

            return [
                new Tensor('int64', position_ids_list.flat(Infinity), [3, input_ids.dims[0], input_ids.dims[1]]),
                new Tensor('int64', mrope_position_deltas, [mrope_position_deltas.length, 1]),
            ];

        } else { // Text-only
            if (attention_mask) {
                // 1D positions from the cumulative sum of the mask, replicated across the 3 axes.
                const { data, dims } = cumsum_masked_fill(attention_mask);

                const position_ids = BigInt64Array.from(
                    { length: 3 * data.length },
                    (_, i) => data[i % data.length]
                );
                const mrope_position_deltas = Array.from(
                    { length: dims[0] },
                    (_, i) => max(data.subarray(dims[1] * i, dims[1] * (i + 1)))[0] + 1 + dims[1]
                );

                return [
                    new Tensor('int64', position_ids, [3, ...dims]),
                    new Tensor('int64', mrope_position_deltas, [mrope_position_deltas.length, 1]),
                ]
            } else {
                // No mask: sequential positions 0..seq_length-1 per batch item, replicated 3x.
                const [batch_size, seq_length] = input_ids.dims;
                const position_ids = BigInt64Array.from(
                    { length: 3 * batch_size * seq_length },
                    (_, i) => BigInt(Math.floor(i % seq_length / batch_size)),
                );

                return [
                    new Tensor('int64', position_ids, [3, ...input_ids.dims]),
                    zeros([batch_size, 1]),
                ]
            }
        }
    }

    /**
     * Run the `vision_encoder` session and return its `image_features` output.
     * Note the session input is named `grid_thw` (not `image_grid_thw`).
     */
    async encode_image({ pixel_values, image_grid_thw }) {
        const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values, grid_thw: image_grid_thw })).image_features;
        return features;
    }

    // Merge image features into the token sequence at image-token positions.
    _merge_input_ids_with_image_features(kwargs) {
        return default_merge_input_ids_with_image_features({
            // @ts-ignore
            image_token_id: this.config.image_token_id,
            ...kwargs
        })
    }

    prepare_inputs_for_generation(input_ids, model_inputs, generation_config) {
        // Overwritten -- in specific circumstances we don't want to forward image inputs to the model
        if (model_inputs.attention_mask && !model_inputs.position_ids) {
            // Calculate position_ids and rope_deltas
            if (!model_inputs.past_key_values) {
                // First (prefill) step: compute full 3D rope indices.
                ([model_inputs.position_ids, model_inputs.rope_deltas] = this.get_rope_index(
                    model_inputs.input_ids,
                    model_inputs.image_grid_thw,
                    model_inputs.video_grid_thw,
                    model_inputs.attention_mask,
                ));

            } else {
                // Decode steps: vision inputs were already consumed; derive the next
                // position from the cached sequence length plus the stored rope deltas.
                model_inputs.pixel_values = null;
                // model_inputs.pixel_values_videos = null;

                const delta = BigInt(Object.values(model_inputs.past_key_values)[0].dims.at(-2));
                const rope_deltas_list = model_inputs.rope_deltas.map(x => delta + x);
                model_inputs.position_ids = stack([rope_deltas_list, rope_deltas_list, rope_deltas_list], 0)
            }
        }

        return model_inputs;
    }
}
+ }
4416
+
3901
4417
 
3902
4418
  //////////////////////////////////////////////////
3903
4419
  // Phi models
@@ -3985,6 +4501,17 @@ export class ViTForImageClassification extends ViTPreTrainedModel {
3985
4501
  }
3986
4502
  //////////////////////////////////////////////////
3987
4503
 
4504
+
4505
+ //////////////////////////////////////////////////
4506
// Common base class for VitPose models.
export class VitPosePreTrainedModel extends PreTrainedModel { }

/**
 * The VitPose model with a pose estimation head on top.
 */
export class VitPoseForPoseEstimation extends VitPosePreTrainedModel { }
4512
+ //////////////////////////////////////////////////
4513
+
4514
+
3988
4515
  //////////////////////////////////////////////////
3989
4516
  export class PvtPreTrainedModel extends PreTrainedModel { }
3990
4517
  export class PvtModel extends PvtPreTrainedModel { }
@@ -5583,8 +6110,7 @@ export class ClapModel extends ClapPreTrainedModel { }
5583
6110
  * ```
5584
6111
  */
5585
6112
  export class ClapTextModelWithProjection extends ClapPreTrainedModel {
5586
-
5587
- /** @type {PreTrainedModel.from_pretrained} */
6113
+ /** @type {typeof PreTrainedModel.from_pretrained} */
5588
6114
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
5589
6115
  // Update default model file name if not provided
5590
6116
  options.model_file_name ??= 'text_model';
@@ -5619,7 +6145,7 @@ export class ClapTextModelWithProjection extends ClapPreTrainedModel {
5619
6145
  * ```
5620
6146
  */
5621
6147
  export class ClapAudioModelWithProjection extends ClapPreTrainedModel {
5622
- /** @type {PreTrainedModel.from_pretrained} */
6148
+ /** @type {typeof PreTrainedModel.from_pretrained} */
5623
6149
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
5624
6150
  // Update default model file name if not provided
5625
6151
  options.model_file_name ??= 'audio_model';
@@ -5970,6 +6496,170 @@ export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel
5970
6496
 
5971
6497
  //////////////////////////////////////////////////
5972
6498
 
6499
// Common base class for the multi-modality (Janus-style) causal LM below.
export class MultiModalityPreTrainedModel extends PreTrainedModel { }
6500
/**
 * Multi-modality causal LM that can generate either text or images.
 * The model is split across several ONNX sessions: an input embedder
 * (`prepare_inputs_embeds` for text tokens, `gen_img_embeds` for image tokens),
 * a shared decoder, and per-mode output heads (`lm_head` / `gen_head`), plus an
 * `image_decode` session that turns generated image tokens into pixels.
 */
export class MultiModalityCausalLM extends MultiModalityPreTrainedModel {
    forward_params = [
        // prepare_inputs_embeds
        'input_ids',
        'pixel_values',
        'images_seq_mask',
        'images_emb_mask',

        // language_model
        'attention_mask',
        'position_ids',
        'past_key_values',
    ];

    constructor(...args) {
        super(...args);

        // State-based approach to switch out which heads to use during generation
        this._generation_mode = 'text';
    }

    /**
     * Runs one forward step for the current generation mode: embed the inputs,
     * run the decoder, then project with the mode's head.
     * @param {any} model_inputs named model inputs (see `forward_params`).
     * @returns {Promise<any>} merged outputs of the embedding, decoder, and head sessions.
     * @throws {Error} if the head session required by the current mode is missing.
     */
    async forward(model_inputs) {
        const mode = this._generation_mode ?? 'text';

        // TODO support re-using PKVs for input_ids.dims[1] !== 1
        // if (model_inputs.past_key_values) {
        //     // && model_inputs.input_ids.dims[1] === 1
        // }

        let output_1;
        if (mode === 'text' || !model_inputs.past_key_values) {
            // Text mode (or prefill): embed regular input tokens.
            const session = this.sessions['prepare_inputs_embeds'];
            const prep_inputs = pick(model_inputs, session.inputNames);
            output_1 = await sessionRun(session, prep_inputs);
        } else {
            // Image-generation decode steps: tokens are image ids, embedded separately.
            const session = this.sessions['gen_img_embeds'];
            const prep_inputs = pick({
                image_ids: model_inputs.input_ids,
            }, session.inputNames);
            output_1 = await sessionRun(session, prep_inputs);
        }

        const input_2 = { ...model_inputs, ...output_1 }
        const output_2 = await decoderForward(this, input_2);

        // FIX: resolve the head *name* before indexing so a missing session produces a
        // meaningful error. Previously the message interpolated `head` itself, which is
        // undefined inside the `if (!head)` branch, yielding
        // `Unable to find "undefined" generation head`.
        const head_name = mode === 'text' ? 'lm_head' : 'gen_head';
        const head = this.sessions[head_name];
        if (!head) {
            throw new Error(`Unable to find "${head_name}" generation head`);
        }

        const output_3 = await sessionRun(head, pick(output_2, head.inputNames))

        return {
            ...output_1,
            ...output_2,
            ...output_3,
        };
    }

    /**
     * Text generation; switches the model into text mode before delegating.
     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
     */
    async generate(options) {
        this._generation_mode = 'text';
        return super.generate(options);
    }

    /**
     * Image generation: generates image tokens in image mode, then decodes the
     * newly generated tokens to images with the `image_decode` session.
     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
     * @returns {Promise<RawImage[]>} one decoded image per batch item.
     */
    async generate_images(options) {
        this._generation_mode = 'image';

        const start_num_tokens = (options.inputs ?? options[this.main_input_name]).dims[1];
        const all_tokens = await super.generate(options);

        // Keep only the newly generated tokens (drop the prompt).
        const generated_tokens = (/** @type {Tensor} */(all_tokens)).slice(null, [start_num_tokens, null])

        const image_decode = this.sessions['image_decode'];
        const { decoded_image } = await sessionRun(image_decode, {
            generated_tokens,
        });

        // Equivalent to `np.clip((dec + 1) / 2 * 255, 0, 255)`
        const clamped = decoded_image
            .add_(1)
            .mul_(255 / 2)
            .clamp_(0, 255)
            .to('uint8');

        // Return as a list of images
        const images = [];
        for (const tensor of clamped) {
            const img = RawImage.fromTensor(tensor);
            images.push(img);
        }
        return images;
    }
}
6603
+
6604
/**
 * Output container for MGP-STR: holds the logits produced by the three
 * classification heads (character, BPE, and word-piece).
 */
export class MgpstrModelOutput extends ModelOutput {
    constructor({ char_logits, bpe_logits, wp_logits }) {
        super();
        Object.assign(this, { char_logits, bpe_logits, wp_logits });
    }

    /** The three logit tensors, in [char, bpe, wp] order. */
    get logits() {
        return [this.char_logits, this.bpe_logits, this.wp_logits];
    }
}
6616
+
6617
// Common base class for MGP-STR models.
export class MgpstrPreTrainedModel extends PreTrainedModel { }

/**
 * MGP-STR Model transformer with three classification heads on top
 * (three A^3 modules and three linear layer on top of the transformer encoder output) for scene text recognition (STR).
 */
export class MgpstrForSceneTextRecognition extends MgpstrPreTrainedModel {
    /**
     * Runs the model and wraps the raw session outputs in an `MgpstrModelOutput`.
     * @param {any} model_inputs
     */
    async _call(model_inputs) {
        return new MgpstrModelOutput(await super._call(model_inputs));
    }
}
6631
+
6632
//////////////////////////////////////////////////
// PatchTST Transformer models (time-series)
export class PatchTSTPreTrainedModel extends PreTrainedModel { }

/**
 * The bare PatchTST Model outputting raw hidden-states without any specific head.
 */
export class PatchTSTModel extends PatchTSTPreTrainedModel { }

/**
 * The PatchTST for prediction model.
 */
export class PatchTSTForPrediction extends PatchTSTPreTrainedModel { }
//////////////////////////////////////////////////
6646
+
6647
//////////////////////////////////////////////////
// PatchTSMixer Transformer models (time-series)
export class PatchTSMixerPreTrainedModel extends PreTrainedModel { }

/**
 * The bare PatchTSMixer Model outputting raw hidden-states without any specific head.
 */
export class PatchTSMixerModel extends PatchTSMixerPreTrainedModel { }

/**
 * The PatchTSMixer for prediction model.
 */
export class PatchTSMixerForPrediction extends PatchTSMixerPreTrainedModel { }
//////////////////////////////////////////////////
6661
+
6662
+
5973
6663
  //////////////////////////////////////////////////
5974
6664
  // AutoModels, used to simplify construction of PreTrainedModels
5975
6665
  // (uses config to instantiate correct class)
@@ -6064,6 +6754,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
6064
6754
  ['clipseg', ['CLIPSegModel', CLIPSegModel]],
6065
6755
  ['chinese_clip', ['ChineseCLIPModel', ChineseCLIPModel]],
6066
6756
  ['siglip', ['SiglipModel', SiglipModel]],
6757
+ ['jina_clip', ['JinaCLIPModel', JinaCLIPModel]],
6067
6758
  ['mobilebert', ['MobileBertModel', MobileBertModel]],
6068
6759
  ['squeezebert', ['SqueezeBertModel', SqueezeBertModel]],
6069
6760
  ['wav2vec2', ['Wav2Vec2Model', Wav2Vec2Model]],
@@ -6108,6 +6799,8 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
6108
6799
  ['efficientnet', ['EfficientNetModel', EfficientNetModel]],
6109
6800
 
6110
6801
  ['decision_transformer', ['DecisionTransformerModel', DecisionTransformerModel]],
6802
+ ['patchtst', ['PatchTSTForPrediction', PatchTSTModel]],
6803
+ ['patchtsmixer', ['PatchTSMixerForPrediction', PatchTSMixerModel]],
6111
6804
 
6112
6805
  ['mobilenet_v1', ['MobileNetV1Model', MobileNetV1Model]],
6113
6806
  ['mobilenet_v2', ['MobileNetV2Model', MobileNetV2Model]],
@@ -6115,6 +6808,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
6115
6808
  ['mobilenet_v4', ['MobileNetV4Model', MobileNetV4Model]],
6116
6809
 
6117
6810
  ['maskformer', ['MaskFormerModel', MaskFormerModel]],
6811
+ ['mgp-str', ['MgpstrForSceneTextRecognition', MgpstrForSceneTextRecognition]],
6118
6812
  ]);
6119
6813
 
6120
6814
  const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
@@ -6252,6 +6946,11 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
6252
6946
  ['stablelm', ['StableLmForCausalLM', StableLmForCausalLM]],
6253
6947
  ]);
6254
6948
 
6949
// Maps `model_type` config values to [class name, class] for multi-modality causal LMs.
const MODEL_FOR_MULTIMODALITY_MAPPING_NAMES = new Map([
    ['multi_modality', ['MultiModalityCausalLM', MultiModalityCausalLM]],
]);
6952
+
6953
+
6255
6954
  const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
6256
6955
  ['bert', ['BertForMaskedLM', BertForMaskedLM]],
6257
6956
  ['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]],
@@ -6291,12 +6990,16 @@ const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
6291
6990
 
6292
6991
  const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([
6293
6992
  ['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]],
6993
+ ['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
6294
6994
  ]);
6295
6995
 
6296
6996
  const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
6297
6997
  ['llava', ['LlavaForConditionalGeneration', LlavaForConditionalGeneration]],
6998
+ ['llava_onevision', ['LlavaOnevisionForConditionalGeneration', LlavaOnevisionForConditionalGeneration]],
6298
6999
  ['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]],
6299
7000
  ['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]],
7001
+ ['qwen2-vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]],
7002
+ ['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
6300
7003
  ]);
6301
7004
 
6302
7005
  const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
@@ -6392,6 +7095,11 @@ const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
6392
7095
  ['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]],
6393
7096
  ]);
6394
7097
 
7098
// Maps `model_type` config values to [class name, class] for time-series prediction heads.
const MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = new Map([
    ['patchtst', ['PatchTSTForPrediction', PatchTSTForPrediction]],
    ['patchtsmixer', ['PatchTSMixerForPrediction', PatchTSMixerForPrediction]],
])
7102
+
6395
7103
  const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([
6396
7104
  ['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]],
6397
7105
  ])
@@ -6408,11 +7116,16 @@ const MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES = new Map([
6408
7116
  ['sapiens', ['SapiensForNormalEstimation', SapiensForNormalEstimation]],
6409
7117
  ])
6410
7118
 
7119
// Maps `model_type` config values to [class name, class] for pose-estimation heads.
const MODEL_FOR_POSE_ESTIMATION_MAPPING_NAMES = new Map([
    ['vitpose', ['VitPoseForPoseEstimation', VitPoseForPoseEstimation]],
])
7122
+
6411
7123
  // NOTE: This is custom to Transformers.js, and is necessary because certain models
6412
7124
  // (e.g., CLIP) are split into vision and text components
6413
7125
  const MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = new Map([
6414
7126
  ['clip', ['CLIPVisionModelWithProjection', CLIPVisionModelWithProjection]],
6415
7127
  ['siglip', ['SiglipVisionModel', SiglipVisionModel]],
7128
+ ['jina_clip', ['JinaCLIPVisionModel', JinaCLIPVisionModel]],
6416
7129
  ])
6417
7130
 
6418
7131
  const MODEL_CLASS_TYPE_MAPPING = [
@@ -6424,6 +7137,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
6424
7137
  [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
6425
7138
  [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
6426
7139
  [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
7140
+ [MODEL_FOR_MULTIMODALITY_MAPPING_NAMES, MODEL_TYPES.MultiModality],
6427
7141
  [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6428
7142
  [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6429
7143
  [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
@@ -6433,9 +7147,11 @@ const MODEL_CLASS_TYPE_MAPPING = [
6433
7147
  [MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6434
7148
  [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6435
7149
  [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
7150
+ [MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6436
7151
  [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6437
7152
  [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6438
7153
  [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
7154
+ [MODEL_FOR_POSE_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6439
7155
  [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6440
7156
  [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6441
7157
  [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
@@ -6466,6 +7182,7 @@ const CUSTOM_MAPPING = [
6466
7182
 
6467
7183
  ['CLIPTextModelWithProjection', CLIPTextModelWithProjection, MODEL_TYPES.EncoderOnly],
6468
7184
  ['SiglipTextModel', SiglipTextModel, MODEL_TYPES.EncoderOnly],
7185
+ ['JinaCLIPTextModel', JinaCLIPTextModel, MODEL_TYPES.EncoderOnly],
6469
7186
  ['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly],
6470
7187
  ['ClapAudioModelWithProjection', ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly],
6471
7188
  ]
@@ -6707,6 +7424,10 @@ export class AutoModelForNormalEstimation extends PretrainedMixin {
6707
7424
  static MODEL_CLASS_MAPPINGS = [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES];
6708
7425
  }
6709
7426
 
7427
// Auto class that instantiates the correct pose-estimation model from a config's `model_type`.
export class AutoModelForPoseEstimation extends PretrainedMixin {
    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_POSE_ESTIMATION_MAPPING_NAMES];
}
7430
+
6710
7431
  export class AutoModelForImageFeatureExtraction extends PretrainedMixin {
6711
7432
  static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES];
6712
7433
  }