@huggingface/transformers 4.0.0-next.6 → 4.0.0-next.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. package/README.md +16 -2
  2. package/dist/ort-wasm-simd-threaded.jsep.mjs +24 -24
  3. package/dist/transformers.js +2255 -931
  4. package/dist/transformers.min.js +19 -19
  5. package/dist/transformers.node.cjs +2300 -934
  6. package/dist/transformers.node.min.cjs +20 -20
  7. package/dist/transformers.node.min.mjs +20 -20
  8. package/dist/transformers.node.mjs +2336 -1012
  9. package/dist/transformers.web.js +2327 -1003
  10. package/dist/transformers.web.min.js +17 -17
  11. package/package.json +4 -4
  12. package/src/cache_utils.js +62 -0
  13. package/src/configs.js +45 -24
  14. package/src/env.js +8 -1
  15. package/src/image_processors_utils.js +27 -17
  16. package/src/models/chatterbox/modeling_chatterbox.js +1 -1
  17. package/src/models/chmv2/image_processing_chmv2.js +3 -0
  18. package/src/models/chmv2/modeling_chmv2.js +4 -0
  19. package/src/models/deepseek_v3/modeling_deepseek_v3.js +5 -0
  20. package/src/models/detr/image_processing_detr.js +1 -1
  21. package/src/models/eurobert/modeling_eurobert.js +41 -0
  22. package/src/models/feature_extractors.js +2 -0
  23. package/src/models/gemma3n/modeling_gemma3n.js +2 -0
  24. package/src/models/glm46v/image_processing_glm46v.js +12 -0
  25. package/src/models/glm46v/processing_glm46v.js +5 -0
  26. package/src/models/glm_moe_dsa/modeling_glm_moe_dsa.js +5 -0
  27. package/src/models/glm_ocr/modeling_glm_ocr.js +78 -0
  28. package/src/models/granite_speech/feature_extraction_granite_speech.js +58 -0
  29. package/src/models/granite_speech/modeling_granite_speech.js +5 -0
  30. package/src/models/granite_speech/processing_granite_speech.js +62 -0
  31. package/src/models/grounding_dino/image_processing_grounding_dino.js +1 -1
  32. package/src/models/idefics3/modeling_idefics3.js +5 -32
  33. package/src/models/image_processors.js +3 -0
  34. package/src/models/lfm2_vl/image_processing_lfm2_vl.js +305 -0
  35. package/src/models/lfm2_vl/modeling_lfm2_vl.js +13 -0
  36. package/src/models/lfm2_vl/processing_lfm2_vl.js +77 -0
  37. package/src/models/lighton_ocr/modeling_lighton_ocr.js +3 -0
  38. package/src/models/llava/modeling_llava.js +1 -1
  39. package/src/models/mistral3/modeling_mistral3.js +2 -2
  40. package/src/models/mistral4/modeling_mistral4.js +5 -0
  41. package/src/models/modeling_utils.js +224 -308
  42. package/src/models/models.js +14 -1
  43. package/src/models/nemotron_h/modeling_nemotron_h.js +5 -0
  44. package/src/models/paligemma/modeling_paligemma.js +2 -25
  45. package/src/models/processors.js +4 -0
  46. package/src/models/qwen2_5_vl/modeling_qwen2_5_vl.js +5 -1
  47. package/src/models/qwen2_vl/image_processing_qwen2_vl.js +1 -41
  48. package/src/models/qwen2_vl/modeling_qwen2_vl.js +194 -143
  49. package/src/models/qwen2_vl/processing_qwen2_vl.js +5 -4
  50. package/src/models/qwen3_5/modeling_qwen3_5.js +1 -0
  51. package/src/models/qwen3_5_moe/modeling_qwen3_5_moe.js +2 -1
  52. package/src/models/qwen3_vl/modeling_qwen3_vl.js +2 -1
  53. package/src/models/qwen3_vl_moe/modeling_qwen3_vl_moe.js +2 -1
  54. package/src/models/registry.js +42 -0
  55. package/src/models/sam/image_processing_sam.js +1 -1
  56. package/src/models/session.js +17 -6
  57. package/src/models/smolvlm/modeling_smolvlm.js +7 -0
  58. package/src/models/solar_open/modeling_solar_open.js +5 -0
  59. package/src/models/ultravox/modeling_ultravox.js +1 -3
  60. package/src/models/voxtral/modeling_voxtral.js +3 -0
  61. package/src/models/voxtral_realtime/feature_extraction_voxtral_realtime.js +71 -0
  62. package/src/models/voxtral_realtime/modeling_voxtral_realtime.js +239 -0
  63. package/src/models/voxtral_realtime/processing_voxtral_realtime.js +113 -0
  64. package/src/models/whisper/feature_extraction_whisper.js +2 -12
  65. package/src/pipelines.js +1 -0
  66. package/src/transformers.js +2 -0
  67. package/src/utils/audio.js +18 -2
  68. package/src/utils/cache/CrossOriginStorageCache.js +251 -0
  69. package/src/utils/cache/cross-origin-storage.d.ts +38 -0
  70. package/src/utils/cache.js +5 -0
  71. package/src/utils/hub.js +4 -1
  72. package/src/utils/lru_cache.js +67 -0
  73. package/src/utils/memoize_promise.js +45 -0
  74. package/src/utils/model_registry/get_file_metadata.js +15 -2
  75. package/src/utils/model_registry/get_model_files.js +52 -78
  76. package/src/utils/tensor.js +18 -2
  77. package/types/cache_utils.d.ts +29 -0
  78. package/types/cache_utils.d.ts.map +1 -0
  79. package/types/configs.d.ts.map +1 -1
  80. package/types/env.d.ts +8 -0
  81. package/types/env.d.ts.map +1 -1
  82. package/types/image_processors_utils.d.ts +18 -1
  83. package/types/image_processors_utils.d.ts.map +1 -1
  84. package/types/models/{ast/modeling_ast.d.ts → audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.d.ts} +1 -1
  85. package/types/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.d.ts.map +1 -0
  86. package/types/models/chmv2/image_processing_chmv2.d.ts +4 -0
  87. package/types/models/chmv2/image_processing_chmv2.d.ts.map +1 -0
  88. package/types/models/chmv2/modeling_chmv2.d.ts +6 -0
  89. package/types/models/chmv2/modeling_chmv2.d.ts.map +1 -0
  90. package/types/models/deepseek_v3/modeling_deepseek_v3.d.ts +8 -0
  91. package/types/models/deepseek_v3/modeling_deepseek_v3.d.ts.map +1 -0
  92. package/types/models/detr/image_processing_detr.d.ts +1 -1
  93. package/types/models/eurobert/modeling_eurobert.d.ts +36 -0
  94. package/types/models/eurobert/modeling_eurobert.d.ts.map +1 -0
  95. package/types/models/feature_extractors.d.ts +2 -0
  96. package/types/models/gemma3n/modeling_gemma3n.d.ts +2 -0
  97. package/types/models/gemma3n/modeling_gemma3n.d.ts.map +1 -1
  98. package/types/models/glm46v/image_processing_glm46v.d.ts +4 -0
  99. package/types/models/glm46v/image_processing_glm46v.d.ts.map +1 -0
  100. package/types/models/glm46v/processing_glm46v.d.ts +4 -0
  101. package/types/models/glm46v/processing_glm46v.d.ts.map +1 -0
  102. package/types/models/glm_moe_dsa/modeling_glm_moe_dsa.d.ts +8 -0
  103. package/types/models/glm_moe_dsa/modeling_glm_moe_dsa.d.ts.map +1 -0
  104. package/types/models/glm_ocr/modeling_glm_ocr.d.ts +26 -0
  105. package/types/models/glm_ocr/modeling_glm_ocr.d.ts.map +1 -0
  106. package/types/models/granite_speech/feature_extraction_granite_speech.d.ts +16 -0
  107. package/types/models/granite_speech/feature_extraction_granite_speech.d.ts.map +1 -0
  108. package/types/models/granite_speech/modeling_granite_speech.d.ts +4 -0
  109. package/types/models/granite_speech/modeling_granite_speech.d.ts.map +1 -0
  110. package/types/models/granite_speech/processing_granite_speech.d.ts +19 -0
  111. package/types/models/granite_speech/processing_granite_speech.d.ts.map +1 -0
  112. package/types/models/grounding_dino/image_processing_grounding_dino.d.ts +1 -1
  113. package/types/models/idefics3/modeling_idefics3.d.ts +2 -18
  114. package/types/models/idefics3/modeling_idefics3.d.ts.map +1 -1
  115. package/types/models/image_processors.d.ts +3 -0
  116. package/types/models/lfm2_vl/image_processing_lfm2_vl.d.ts +41 -0
  117. package/types/models/lfm2_vl/image_processing_lfm2_vl.d.ts.map +1 -0
  118. package/types/models/lfm2_vl/modeling_lfm2_vl.d.ts +4 -0
  119. package/types/models/lfm2_vl/modeling_lfm2_vl.d.ts.map +1 -0
  120. package/types/models/lfm2_vl/processing_lfm2_vl.d.ts +18 -0
  121. package/types/models/lfm2_vl/processing_lfm2_vl.d.ts.map +1 -0
  122. package/types/models/lighton_ocr/modeling_lighton_ocr.d.ts +4 -0
  123. package/types/models/lighton_ocr/modeling_lighton_ocr.d.ts.map +1 -0
  124. package/types/models/mistral3/modeling_mistral3.d.ts +2 -2
  125. package/types/models/mistral3/modeling_mistral3.d.ts.map +1 -1
  126. package/types/models/mistral4/modeling_mistral4.d.ts +8 -0
  127. package/types/models/mistral4/modeling_mistral4.d.ts.map +1 -0
  128. package/types/models/modeling_utils.d.ts +44 -35
  129. package/types/models/modeling_utils.d.ts.map +1 -1
  130. package/types/models/models.d.ts +14 -1
  131. package/types/models/nemotron_h/modeling_nemotron_h.d.ts +8 -0
  132. package/types/models/nemotron_h/modeling_nemotron_h.d.ts.map +1 -0
  133. package/types/models/paligemma/modeling_paligemma.d.ts +2 -8
  134. package/types/models/paligemma/modeling_paligemma.d.ts.map +1 -1
  135. package/types/models/processors.d.ts +4 -0
  136. package/types/models/qwen2_5_vl/modeling_qwen2_5_vl.d.ts +3 -0
  137. package/types/models/qwen2_5_vl/modeling_qwen2_5_vl.d.ts.map +1 -1
  138. package/types/models/qwen2_vl/image_processing_qwen2_vl.d.ts.map +1 -1
  139. package/types/models/qwen2_vl/modeling_qwen2_vl.d.ts +43 -6
  140. package/types/models/qwen2_vl/modeling_qwen2_vl.d.ts.map +1 -1
  141. package/types/models/qwen2_vl/processing_qwen2_vl.d.ts +1 -0
  142. package/types/models/qwen2_vl/processing_qwen2_vl.d.ts.map +1 -1
  143. package/types/models/qwen3_5/modeling_qwen3_5.d.ts +2 -0
  144. package/types/models/qwen3_5/modeling_qwen3_5.d.ts.map +1 -1
  145. package/types/models/qwen3_5_moe/modeling_qwen3_5_moe.d.ts +3 -0
  146. package/types/models/qwen3_5_moe/modeling_qwen3_5_moe.d.ts.map +1 -1
  147. package/types/models/qwen3_vl/modeling_qwen3_vl.d.ts +3 -0
  148. package/types/models/qwen3_vl/modeling_qwen3_vl.d.ts.map +1 -1
  149. package/types/models/qwen3_vl_moe/modeling_qwen3_vl_moe.d.ts +3 -0
  150. package/types/models/qwen3_vl_moe/modeling_qwen3_vl_moe.d.ts.map +1 -1
  151. package/types/models/registry.d.ts.map +1 -1
  152. package/types/models/sam/image_processing_sam.d.ts +1 -1
  153. package/types/models/session.d.ts +3 -2
  154. package/types/models/session.d.ts.map +1 -1
  155. package/types/models/smolvlm/modeling_smolvlm.d.ts +8 -0
  156. package/types/models/smolvlm/modeling_smolvlm.d.ts.map +1 -0
  157. package/types/models/solar_open/modeling_solar_open.d.ts +8 -0
  158. package/types/models/solar_open/modeling_solar_open.d.ts.map +1 -0
  159. package/types/models/ultravox/modeling_ultravox.d.ts +0 -2
  160. package/types/models/ultravox/modeling_ultravox.d.ts.map +1 -1
  161. package/types/models/voxtral/modeling_voxtral.d.ts +4 -0
  162. package/types/models/voxtral/modeling_voxtral.d.ts.map +1 -0
  163. package/types/models/voxtral_realtime/feature_extraction_voxtral_realtime.d.ts +28 -0
  164. package/types/models/voxtral_realtime/feature_extraction_voxtral_realtime.d.ts.map +1 -0
  165. package/types/models/voxtral_realtime/modeling_voxtral_realtime.d.ts +17 -0
  166. package/types/models/voxtral_realtime/modeling_voxtral_realtime.d.ts.map +1 -0
  167. package/types/models/voxtral_realtime/processing_voxtral_realtime.d.ts +44 -0
  168. package/types/models/voxtral_realtime/processing_voxtral_realtime.d.ts.map +1 -0
  169. package/types/models/whisper/feature_extraction_whisper.d.ts.map +1 -1
  170. package/types/pipelines.d.ts +1 -0
  171. package/types/pipelines.d.ts.map +1 -1
  172. package/types/transformers.d.ts +1 -0
  173. package/types/transformers.d.ts.map +1 -1
  174. package/types/utils/audio.d.ts +5 -2
  175. package/types/utils/audio.d.ts.map +1 -1
  176. package/types/utils/cache/CrossOriginStorageCache.d.ts +120 -0
  177. package/types/utils/cache/CrossOriginStorageCache.d.ts.map +1 -0
  178. package/types/utils/cache.d.ts.map +1 -1
  179. package/types/utils/dtypes.d.ts +1 -1
  180. package/types/utils/hub.d.ts.map +1 -1
  181. package/types/utils/image.d.ts +1 -1
  182. package/types/utils/lru_cache.d.ts +38 -0
  183. package/types/utils/lru_cache.d.ts.map +1 -0
  184. package/types/utils/memoize_promise.d.ts +14 -0
  185. package/types/utils/memoize_promise.d.ts.map +1 -0
  186. package/types/utils/model_registry/get_file_metadata.d.ts.map +1 -1
  187. package/types/utils/model_registry/get_model_files.d.ts +1 -0
  188. package/types/utils/model_registry/get_model_files.d.ts.map +1 -1
  189. package/types/utils/tensor.d.ts.map +1 -1
  190. package/src/utils/data-structures.js +0 -572
  191. package/types/models/ast/modeling_ast.d.ts.map +0 -1
  192. package/types/utils/data-structures.d.ts +0 -294
  193. package/types/utils/data-structures.d.ts.map +0 -1
  194. /package/src/models/{ast/modeling_ast.js → audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.js} +0 -0
@@ -36,26 +36,7 @@ import { LogitsSampler } from '../generation/logits_sampler.js';
36
36
  import { pick } from '../utils/core.js';
37
37
  import { ModelOutput } from './modeling_outputs.js';
38
38
  import { logger } from '../utils/logger.js';
39
-
40
- /**
41
- * Extract the past sequence length from a past_key_values object.
42
- * For standard models, all entries are attention KV caches with shape [batch, heads, seq_len, head_dim].
43
- * For hybrid models (e.g., Qwen3.5 with conv/recurrent + attention layers), the first entry
44
- * may be a conv or recurrent state whose dims don't encode a sequence length.
45
- * This function finds a `past_key_values.*` entry (standard attention cache) to determine the true past length.
46
- *
47
- * @param {Record<string, Tensor>} past_key_values
48
- * @returns {number} The past sequence length.
49
- */
50
- export function getPastLength(past_key_values) {
51
- for (const name in past_key_values) {
52
- if (name.startsWith('past_key_values.')) {
53
- return past_key_values[name].dims.at(-2);
54
- }
55
- }
56
- // Fallback for non-hybrid models (all entries are attention KV)
57
- return Object.values(past_key_values)[0].dims.at(-2);
58
- }
39
+ import { DynamicCache } from '../cache_utils.js';
59
40
 
60
41
  /**
61
42
  * Converts an array or Tensor of integers to an int64 Tensor.
@@ -118,6 +99,8 @@ export const MODEL_TYPES = {
118
99
  ImageAudioTextToText: 13,
119
100
  Supertonic: 14,
120
101
  Chatterbox: 15,
102
+ MultimodalLanguageModelOnly: 16,
103
+ VoxtralRealtime: 17,
121
104
  };
122
105
 
123
106
  const MODEL_TYPE_CONFIG = {
@@ -125,65 +108,181 @@ const MODEL_TYPE_CONFIG = {
125
108
  can_generate: true,
126
109
  forward: decoder_forward,
127
110
  prepare_inputs: decoder_prepare_inputs_for_generation,
111
+ sessions: (config, options) => ({ model: options.model_file_name ?? 'model' }),
112
+ cache_sessions: { model: true },
113
+ optional_configs: { generation_config: 'generation_config.json' },
128
114
  },
129
115
  [MODEL_TYPES.DecoderOnlyWithoutHead]: {
130
116
  can_generate: false,
131
117
  forward: decoder_forward,
132
118
  prepare_inputs: decoder_prepare_inputs_for_generation,
119
+ sessions: (config, options) => ({ model: options.model_file_name ?? 'model' }),
133
120
  },
134
121
  [MODEL_TYPES.Seq2Seq]: {
135
122
  can_generate: true,
136
123
  forward: seq2seq_forward,
137
124
  prepare_inputs: encoder_decoder_prepare_inputs_for_generation,
125
+ sessions: () => ({ model: 'encoder_model', decoder_model_merged: 'decoder_model_merged' }),
126
+ cache_sessions: { decoder_model_merged: true },
127
+ optional_configs: { generation_config: 'generation_config.json' },
138
128
  },
139
129
  [MODEL_TYPES.Vision2Seq]: {
140
130
  can_generate: true,
141
131
  forward: seq2seq_forward,
142
132
  prepare_inputs: encoder_decoder_prepare_inputs_for_generation,
133
+ sessions: () => ({ model: 'encoder_model', decoder_model_merged: 'decoder_model_merged' }),
134
+ cache_sessions: { decoder_model_merged: true },
135
+ optional_configs: { generation_config: 'generation_config.json' },
143
136
  },
144
137
  [MODEL_TYPES.Musicgen]: {
145
138
  can_generate: true,
146
139
  forward: seq2seq_forward,
140
+ sessions: () => ({
141
+ model: 'text_encoder',
142
+ decoder_model_merged: 'decoder_model_merged',
143
+ encodec_decode: 'encodec_decode',
144
+ }),
145
+ cache_sessions: { decoder_model_merged: true },
146
+ optional_configs: { generation_config: 'generation_config.json' },
147
147
  },
148
148
  [MODEL_TYPES.EncoderDecoder]: {
149
149
  can_generate: false,
150
150
  forward: seq2seq_forward,
151
+ sessions: () => ({ model: 'encoder_model', decoder_model_merged: 'decoder_model_merged' }),
152
+ cache_sessions: { decoder_model_merged: true },
153
+ },
154
+ [MODEL_TYPES.MaskGeneration]: {
155
+ sessions: () => ({ model: 'vision_encoder', prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder' }),
151
156
  },
152
157
  [MODEL_TYPES.ImageTextToText]: {
153
158
  can_generate: true,
154
159
  forward: image_text_to_text_forward,
155
160
  prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
161
+ sessions: (config) => {
162
+ const s = {
163
+ embed_tokens: 'embed_tokens',
164
+ vision_encoder: 'vision_encoder',
165
+ decoder_model_merged: 'decoder_model_merged',
166
+ };
167
+ if (config.is_encoder_decoder) s['model'] = 'encoder_model';
168
+ return s;
169
+ },
170
+ cache_sessions: { decoder_model_merged: true },
171
+ optional_configs: { generation_config: 'generation_config.json' },
156
172
  },
157
173
  [MODEL_TYPES.AudioTextToText]: {
158
174
  can_generate: true,
159
175
  forward: audio_text_to_text_forward,
160
176
  prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
177
+ sessions: () => ({
178
+ embed_tokens: 'embed_tokens',
179
+ audio_encoder: 'audio_encoder',
180
+ decoder_model_merged: 'decoder_model_merged',
181
+ }),
182
+ cache_sessions: { decoder_model_merged: true },
183
+ optional_configs: { generation_config: 'generation_config.json' },
161
184
  },
162
- [MODEL_TYPES.Phi3V]: {
185
+ [MODEL_TYPES.ImageAudioTextToText]: {
163
186
  can_generate: true,
164
187
  prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
188
+ sessions: () => ({
189
+ embed_tokens: 'embed_tokens',
190
+ audio_encoder: 'audio_encoder',
191
+ vision_encoder: 'vision_encoder',
192
+ decoder_model_merged: 'decoder_model_merged',
193
+ }),
194
+ optional_configs: { generation_config: 'generation_config.json' },
165
195
  },
166
- [MODEL_TYPES.ImageAudioTextToText]: {
196
+ [MODEL_TYPES.Phi3V]: {
167
197
  can_generate: true,
168
198
  prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
199
+ sessions: () => ({
200
+ prepare_inputs_embeds: 'prepare_inputs_embeds',
201
+ model: 'model',
202
+ vision_encoder: 'vision_encoder',
203
+ }),
204
+ cache_sessions: { model: true },
205
+ optional_configs: { generation_config: 'generation_config.json' },
169
206
  },
170
207
  [MODEL_TYPES.MultiModality]: {
171
208
  can_generate: true,
209
+ sessions: () => ({
210
+ prepare_inputs_embeds: 'prepare_inputs_embeds',
211
+ model: 'language_model',
212
+ lm_head: 'lm_head',
213
+ gen_head: 'gen_head',
214
+ gen_img_embeds: 'gen_img_embeds',
215
+ image_decode: 'image_decode',
216
+ }),
217
+ cache_sessions: { model: true },
218
+ optional_configs: { generation_config: 'generation_config.json' },
172
219
  },
173
220
  [MODEL_TYPES.AutoEncoder]: {
174
221
  can_generate: false,
175
222
  forward: auto_encoder_forward,
223
+ sessions: () => ({ encoder_model: 'encoder_model', decoder_model: 'decoder_model' }),
224
+ },
225
+ [MODEL_TYPES.Supertonic]: {
226
+ sessions: () => ({
227
+ text_encoder: 'text_encoder',
228
+ latent_denoiser: 'latent_denoiser',
229
+ voice_decoder: 'voice_decoder',
230
+ }),
176
231
  },
177
232
  [MODEL_TYPES.Chatterbox]: {
178
233
  can_generate: true,
179
234
  forward: encoder_forward,
235
+ sessions: () => ({
236
+ embed_tokens: 'embed_tokens',
237
+ speech_encoder: 'speech_encoder',
238
+ model: 'language_model',
239
+ conditional_decoder: 'conditional_decoder',
240
+ }),
241
+ cache_sessions: { model: true },
242
+ optional_configs: { generation_config: 'generation_config.json' },
243
+ },
244
+ [MODEL_TYPES.MultimodalLanguageModelOnly]: {
245
+ can_generate: true,
246
+ forward: image_text_to_text_forward,
247
+ prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
248
+ sessions: () => ({ embed_tokens: 'embed_tokens', decoder_model_merged: 'decoder_model_merged' }),
249
+ cache_sessions: { decoder_model_merged: true },
250
+ optional_configs: { generation_config: 'generation_config.json' },
251
+ },
252
+ [MODEL_TYPES.VoxtralRealtime]: {
253
+ can_generate: true,
254
+ prepare_inputs: decoder_prepare_inputs_for_generation,
255
+ sessions: () => ({
256
+ embed_tokens: 'embed_tokens',
257
+ audio_encoder: 'audio_encoder',
258
+ decoder_model_merged: 'decoder_model_merged',
259
+ }),
260
+ cache_sessions: { decoder_model_merged: true, audio_encoder: true },
261
+ optional_configs: { generation_config: 'generation_config.json' },
180
262
  },
181
263
  default: {
182
264
  can_generate: false,
183
265
  forward: encoder_forward,
266
+ sessions: (config, options) => ({ model: options.model_file_name ?? 'model' }),
184
267
  },
185
268
  };
186
269
 
270
+ /**
271
+ * Get the session configuration for a given model type.
272
+ * @param {number} modelType The model type enum value.
273
+ * @param {Object} config The model config.
274
+ * @param {Object} [options] Loading options.
275
+ * @returns {{ sessions: Record<string, string>, cache_sessions?: Record<string, true>, optional_configs?: Record<string, string> }}
276
+ */
277
+ export function getSessionsConfig(modelType, config, options = {}) {
278
+ const typeConfig = MODEL_TYPE_CONFIG[modelType] ?? MODEL_TYPE_CONFIG.default;
279
+ return {
280
+ sessions: typeConfig.sessions(config, options),
281
+ cache_sessions: typeConfig.cache_sessions,
282
+ optional_configs: typeConfig.optional_configs,
283
+ };
284
+ }
285
+
187
286
  export const MODEL_TYPE_MAPPING = new Map();
188
287
  export const MODEL_NAME_TO_CLASS_MAPPING = new Map();
189
288
  export const MODEL_CLASS_TO_NAME_MAPPING = new Map();
@@ -290,247 +389,26 @@ export class PreTrainedModel extends Callable {
290
389
 
291
390
  config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
292
391
 
293
- let info;
294
- if (modelType === MODEL_TYPES.DecoderOnly) {
295
- info = await Promise.all([
296
- constructSessions(
297
- pretrained_model_name_or_path,
298
- {
299
- model: options.model_file_name ?? 'model',
300
- },
301
- options,
302
- 'model',
303
- ),
304
- get_optional_configs(
305
- pretrained_model_name_or_path,
306
- {
307
- generation_config: 'generation_config.json',
308
- },
309
- options,
310
- ),
311
- ]);
312
- } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
313
- info = await Promise.all([
314
- constructSessions(
315
- pretrained_model_name_or_path,
316
- {
317
- model: 'encoder_model',
318
- decoder_model_merged: 'decoder_model_merged',
319
- },
320
- options,
321
- 'decoder_model_merged',
322
- ),
323
- get_optional_configs(
324
- pretrained_model_name_or_path,
325
- {
326
- generation_config: 'generation_config.json',
327
- },
328
- options,
329
- ),
330
- ]);
331
- } else if (modelType === MODEL_TYPES.MaskGeneration) {
332
- info = await Promise.all([
333
- constructSessions(
334
- pretrained_model_name_or_path,
335
- {
336
- model: 'vision_encoder',
337
- prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder',
338
- },
339
- options,
340
- ),
341
- ]);
342
- } else if (modelType === MODEL_TYPES.EncoderDecoder) {
343
- info = await Promise.all([
344
- constructSessions(
345
- pretrained_model_name_or_path,
346
- {
347
- model: 'encoder_model',
348
- decoder_model_merged: 'decoder_model_merged',
349
- },
350
- options,
351
- 'decoder_model_merged',
352
- ),
353
- ]);
354
- } else if (modelType === MODEL_TYPES.ImageTextToText) {
355
- const sessions = {
356
- embed_tokens: 'embed_tokens',
357
- vision_encoder: 'vision_encoder',
358
- decoder_model_merged: 'decoder_model_merged',
359
- };
360
- if (config.is_encoder_decoder) {
361
- sessions['model'] = 'encoder_model';
362
- }
363
- info = await Promise.all([
364
- constructSessions(pretrained_model_name_or_path, sessions, options, 'decoder_model_merged'),
365
- get_optional_configs(
366
- pretrained_model_name_or_path,
367
- {
368
- generation_config: 'generation_config.json',
369
- },
370
- options,
371
- ),
372
- ]);
373
- } else if (modelType === MODEL_TYPES.AudioTextToText) {
374
- const sessions = {
375
- embed_tokens: 'embed_tokens',
376
- audio_encoder: 'audio_encoder',
377
- decoder_model_merged: 'decoder_model_merged',
378
- };
379
- info = await Promise.all([
380
- constructSessions(pretrained_model_name_or_path, sessions, options, 'decoder_model_merged'),
381
- get_optional_configs(
382
- pretrained_model_name_or_path,
383
- {
384
- generation_config: 'generation_config.json',
385
- },
386
- options,
387
- ),
388
- ]);
389
- } else if (modelType === MODEL_TYPES.ImageAudioTextToText) {
390
- const sessions = {
391
- embed_tokens: 'embed_tokens',
392
- audio_encoder: 'audio_encoder',
393
- vision_encoder: 'vision_encoder',
394
- decoder_model_merged: 'decoder_model_merged',
395
- };
396
- info = await Promise.all([
397
- constructSessions(pretrained_model_name_or_path, sessions, options),
398
- get_optional_configs(
399
- pretrained_model_name_or_path,
400
- {
401
- generation_config: 'generation_config.json',
402
- },
403
- options,
404
- ),
405
- ]);
406
- } else if (modelType === MODEL_TYPES.Musicgen) {
407
- info = await Promise.all([
408
- constructSessions(
409
- pretrained_model_name_or_path,
410
- {
411
- model: 'text_encoder',
412
- decoder_model_merged: 'decoder_model_merged',
413
- encodec_decode: 'encodec_decode',
414
- },
415
- options,
416
- 'decoder_model_merged',
417
- ),
418
- get_optional_configs(
419
- pretrained_model_name_or_path,
420
- {
421
- generation_config: 'generation_config.json',
422
- },
423
- options,
424
- ),
425
- ]);
426
- } else if (modelType === MODEL_TYPES.MultiModality) {
427
- info = await Promise.all([
428
- constructSessions(
429
- pretrained_model_name_or_path,
430
- {
431
- prepare_inputs_embeds: 'prepare_inputs_embeds',
432
- model: 'language_model',
433
- lm_head: 'lm_head',
434
- gen_head: 'gen_head',
435
- gen_img_embeds: 'gen_img_embeds',
436
- image_decode: 'image_decode',
437
- },
438
- options,
439
- 'model',
440
- ),
441
- get_optional_configs(
442
- pretrained_model_name_or_path,
443
- {
444
- generation_config: 'generation_config.json',
445
- },
446
- options,
447
- ),
448
- ]);
449
- } else if (modelType === MODEL_TYPES.Phi3V) {
450
- info = await Promise.all([
451
- constructSessions(
452
- pretrained_model_name_or_path,
453
- {
454
- prepare_inputs_embeds: 'prepare_inputs_embeds',
455
- model: 'model',
456
- vision_encoder: 'vision_encoder',
457
- },
458
- options,
459
- 'model',
460
- ),
461
- get_optional_configs(
462
- pretrained_model_name_or_path,
463
- {
464
- generation_config: 'generation_config.json',
465
- },
466
- options,
467
- ),
468
- ]);
469
- } else if (modelType === MODEL_TYPES.Chatterbox) {
470
- info = await Promise.all([
471
- constructSessions(
472
- pretrained_model_name_or_path,
473
- {
474
- embed_tokens: 'embed_tokens',
475
- speech_encoder: 'speech_encoder',
476
- model: 'language_model',
477
- conditional_decoder: 'conditional_decoder',
478
- },
479
- options,
480
- 'model',
481
- ),
482
- get_optional_configs(
483
- pretrained_model_name_or_path,
484
- {
485
- generation_config: 'generation_config.json',
486
- },
487
- options,
488
- ),
489
- ]);
490
- } else if (modelType === MODEL_TYPES.AutoEncoder) {
491
- info = await Promise.all([
492
- constructSessions(
493
- pretrained_model_name_or_path,
494
- {
495
- encoder_model: 'encoder_model',
496
- decoder_model: 'decoder_model',
497
- },
498
- options,
499
- ),
500
- ]);
501
- } else if (modelType === MODEL_TYPES.Supertonic) {
502
- info = await Promise.all([
503
- constructSessions(
504
- pretrained_model_name_or_path,
505
- {
506
- text_encoder: 'text_encoder',
507
- latent_denoiser: 'latent_denoiser',
508
- voice_decoder: 'voice_decoder',
509
- },
510
- options,
511
- ),
512
- ]);
513
- } else {
514
- // should be MODEL_TYPES.EncoderOnly or MODEL_TYPES.DecoderOnlyWithoutHead
515
- if (modelType === undefined) {
516
- const type = modelName ?? config?.model_type;
517
- if (type !== 'custom') {
518
- logger.warn(
519
- `Model type for '${type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`,
520
- );
521
- }
392
+ const typeConfig = MODEL_TYPE_CONFIG[modelType] ?? MODEL_TYPE_CONFIG.default;
393
+
394
+ if (modelType === undefined) {
395
+ const type = modelName ?? config?.model_type;
396
+ if (type !== 'custom') {
397
+ logger.warn(
398
+ `Model type for '${type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`,
399
+ );
522
400
  }
523
- info = await Promise.all([
524
- constructSessions(
525
- pretrained_model_name_or_path,
526
- {
527
- model: options.model_file_name ?? 'model',
528
- },
529
- options,
530
- ),
531
- ]);
532
401
  }
533
402
 
403
+ const sessions = typeConfig.sessions(config, options);
404
+ const promises = [
405
+ constructSessions(pretrained_model_name_or_path, sessions, options, typeConfig.cache_sessions),
406
+ ];
407
+ if (typeConfig.optional_configs) {
408
+ promises.push(get_optional_configs(pretrained_model_name_or_path, typeConfig.optional_configs, options));
409
+ }
410
+ const info = await Promise.all(promises);
411
+
534
412
  // @ts-ignore
535
413
  return new this(config, ...info);
536
414
  }
@@ -868,7 +746,7 @@ export class PreTrainedModel extends Callable {
868
746
  * @param {Tensor} [params.inputs=null]
869
747
  * @param {number} [params.bos_token_id=null]
870
748
  * @param {Record<string, Tensor|number[]>} [params.model_kwargs]
871
- * @returns {{inputs_tensor: Tensor, model_inputs: Record<string, Tensor>, model_input_name: string}} The model-specific inputs for generation.
749
+ * @returns {{inputs_tensor: Tensor, model_inputs: Record<string, Tensor> & {past_key_values?: DynamicCache}, model_input_name: string}} The model-specific inputs for generation.
872
750
  */
873
751
  _prepare_model_inputs({ inputs, bos_token_id, model_kwargs }) {
874
752
  const model_inputs = pick(model_kwargs, this.forward_params);
@@ -1232,13 +1110,15 @@ export class PreTrainedModel extends Callable {
1232
1110
  }
1233
1111
 
1234
1112
  /**
1235
- * Returns an object containing past key values from the given decoder results object.
1113
+ * Returns a DynamicCache containing past key values from the given decoder results object.
1236
1114
  *
1237
1115
  * @param {Object} decoderResults The decoder results object.
1238
- * @param {Object} pastKeyValues The previous past key values.
1239
- * @returns {Object} An object containing past key values.
1116
+ * @param {DynamicCache} pastKeyValues The previous past key values.
1117
+ * @param {boolean} [disposeEncoderPKVs=false] Whether to dispose encoder past key values.
1118
+ * @returns {DynamicCache} A new DynamicCache containing the updated past key values.
1240
1119
  */
1241
1120
  getPastKeyValues(decoderResults, pastKeyValues, disposeEncoderPKVs = false) {
1121
+ /** @type {Record<string, Tensor>} */
1242
1122
  const pkvs = Object.create(null);
1243
1123
 
1244
1124
  for (const name in decoderResults) {
@@ -1272,7 +1152,7 @@ export class PreTrainedModel extends Callable {
1272
1152
  }
1273
1153
  }
1274
1154
  }
1275
- return pkvs;
1155
+ return new DynamicCache(pkvs);
1276
1156
  }
1277
1157
 
1278
1158
  /**
@@ -1300,8 +1180,8 @@ export class PreTrainedModel extends Callable {
1300
1180
  /**
1301
1181
  * Adds past key values to the decoder feeds object. If pastKeyValues is null, creates new tensors for past key values.
1302
1182
  *
1303
- * @param {Object} decoderFeeds The decoder feeds object to add past key values to.
1304
- * @param {Object} pastKeyValues An object containing past key values.
1183
+ * @param {Record<string, any>} decoderFeeds The decoder feeds object to add past key values to.
1184
+ * @param {DynamicCache|null} pastKeyValues The cache containing past key values.
1305
1185
  */
1306
1186
  addPastKeyValues(decoderFeeds, pastKeyValues) {
1307
1187
  if (pastKeyValues) {
@@ -1320,19 +1200,32 @@ export class PreTrainedModel extends Callable {
1320
1200
  }
1321
1201
  }
1322
1202
 
1323
- async encode_image({ pixel_values }) {
1324
- // image_inputs === { pixel_values }
1325
- return (await sessionRun(this.sessions['vision_encoder'], { pixel_values })).image_features;
1203
+ /**
1204
+ * Selects the inputs accepted by the named encoder session (vision, text or audio), runs that session, and returns one of its outputs.
1205
+ * @param {string} sessionName Key into `this.sessions` (e.g. 'vision_encoder', 'embed_tokens', 'audio_encoder').
1206
+ * @param {Record<string, Tensor>} inputs Candidate inputs; only the keys listed in the session's `inputNames` are forwarded.
1207
+ * @param {string} outputName Name of the session output to return.
1208
+ * @returns {Promise<any>} The session output named `outputName` (undefined if the session produced no such output; no check is made).
+ * @throws {Error} If the model has no session named `sessionName`.
+ * @private
1209
+ */
1210
+ async _encode_input(sessionName, inputs, outputName) {
1211
+ if (!Object.hasOwn(this.sessions, sessionName)) {
1212
+ throw new Error(`Model does not have a ${sessionName} session.`);
1213
+ }
1214
+ const session = this.sessions[sessionName];
1215
+ // Forward only the inputs this session declares; any extra kwargs are dropped by `pick`.
+ const output = await sessionRun(session, pick(inputs, session.inputNames));
1216
+ return output[outputName];
1217
+ }
1218
+
1219
+ /** Runs the `vision_encoder` session on the accepted subset of `inputs` and returns its `image_features` output. */
+ async encode_image(inputs) {
1220
+ return this._encode_input('vision_encoder', inputs, 'image_features');
1326
1221
  }
1327
1222
 
1328
- async encode_text({ input_ids }) {
1329
- // text_inputs === { input_ids, attention_mask }
1330
- return (await sessionRun(this.sessions['embed_tokens'], { input_ids })).inputs_embeds;
1223
+ /** Runs the `embed_tokens` session on the accepted subset of `inputs` and returns its `inputs_embeds` output. */
+ async encode_text(inputs) {
1224
+ return this._encode_input('embed_tokens', inputs, 'inputs_embeds');
1331
1225
  }
1332
1226
 
1333
- async encode_audio({ audio_values }) {
1334
- // audio_inputs === { audio_values }
1335
- return (await sessionRun(this.sessions['audio_encoder'], { audio_values })).audio_features;
1227
+ /** Runs the `audio_encoder` session on the accepted subset of `inputs` and returns its `audio_features` output. */
+ async encode_audio(inputs) {
1228
+ return this._encode_input('audio_encoder', inputs, 'audio_features');
1336
1229
  }
1337
1230
  }
1338
1231
 
@@ -1431,6 +1324,15 @@ export async function decoder_forward(self, model_inputs, is_encoder_decoder = f
1431
1324
  new_model_inputs.position_ids = create_position_ids(new_model_inputs, past_key_values, start_index);
1432
1325
  }
1433
1326
 
1327
+ if (session.inputNames.includes('num_logits_to_keep') && !new_model_inputs.num_logits_to_keep) {
1328
+ // `num_logits_to_keep` specifies the number of prompt logits to calculate during generation.
1329
+ // If unset (or 0), all logits will be calculated. If an integer value, only last `num_logits_to_keep`
1330
+ // logits will be calculated. During generation, the default is 1 because only the logits of the last
1331
+ // prompt token are needed for generation. For long sequences, the logits for the entire sequence may
1332
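+ // use a lot of memory, so setting `num_logits_to_keep=1` reduces the memory footprint significantly.
+ // When not already set, default to 0 here (compute all logits); `decoder_prepare_inputs_for_generation` sets it to 1 for generation.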
1333
+ new_model_inputs.num_logits_to_keep = new Tensor('int64', [0n], []);
1334
+ }
1335
+
1434
1336
  // Unpack the `past_key_values` object into model inputs
1435
1337
  self.addPastKeyValues(new_model_inputs, past_key_values);
1436
1338
 
@@ -1445,13 +1347,13 @@ export async function decoder_forward(self, model_inputs, is_encoder_decoder = f
1445
1347
  * @param {Object} params Additional parameters.
1446
1348
  * @param {Function} [params.encode_function] The function to encode the modality values.
1447
1349
  * @param {Function} [params.merge_function] The function to merge the modality features with the input embeddings.
1448
- * @param {string} [params.modality_input_name] The modality input name.
1350
+ * @param {string[]} [params.modality_input_names] The names of the modality inputs to look up in `kwargs` (e.g. `['pixel_values']`).
1449
1351
  * @param {string} [params.modality_output_name] The modality output name.
1450
1352
  * @param {Tensor} [params.input_ids=null]
1451
1353
  * @param {Tensor} [params.attention_mask=null]
1452
1354
  * @param {Tensor} [params.position_ids=null]
1453
1355
  * @param {Tensor} [params.inputs_embeds=null]
1454
- * @param {Record<string, Tensor>} [params.past_key_values=null]
1356
+ * @param {DynamicCache} [params.past_key_values=null]
1455
1357
  * @param {Object} [params.generation_config=null]
1456
1358
  * @param {Object} [params.logits_processor=null]
1457
1359
  * @returns {Promise<Tensor>} The model's output tensor
@@ -1463,7 +1365,7 @@ export async function generic_text_to_text_forward(
1463
1365
  // Generic parameters:
1464
1366
  encode_function,
1465
1367
  merge_function,
1466
- modality_input_name,
1368
+ modality_input_names,
1467
1369
  modality_output_name,
1468
1370
 
1469
1371
  // Produced by the tokenizer/processor:
@@ -1483,37 +1385,39 @@ export async function generic_text_to_text_forward(
1483
1385
  ...kwargs
1484
1386
  },
1485
1387
  ) {
1486
- const modality_values = kwargs[modality_input_name];
1487
1388
  if (!inputs_embeds) {
1488
1389
  // 1. Extract the text embeddings.
1489
1390
  inputs_embeds = await self.encode_text({ input_ids, ...kwargs });
1490
1391
 
1491
1392
  // 2. Possibly, merge text and modality values
1492
- if (modality_values && input_ids.dims[1] !== 1) {
1493
- const modality_features = await encode_function({
1494
- // Pass the modality values under its expected key.
1495
- // The caller knows whether this is audio or image.
1496
- [modality_input_name]: modality_values,
1497
- ...kwargs,
1498
- });
1499
- ({ inputs_embeds, attention_mask } = merge_function({
1500
- [modality_output_name]: modality_features,
1501
- inputs_embeds,
1502
- input_ids,
1503
- attention_mask,
1504
- }));
1505
- } else if (past_key_values && modality_values && input_ids.dims[1] === 1) {
1506
- // This branch handles the cache case.
1507
- const target_length = input_ids.dims[1]; // always 1
1508
- const past_length = getPastLength(past_key_values);
1509
-
1510
- attention_mask = cat(
1511
- [
1512
- ones([input_ids.dims[0], past_length]),
1513
- attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
1514
- ],
1515
- 1,
1516
- );
1393
+ const modality_values = pick(kwargs, modality_input_names);
1394
+ if (Object.keys(modality_values).length > 0) {
1395
+ if (input_ids.dims[1] !== 1) {
1396
+ const modality_features = await encode_function({
1397
+ // Pass the modality values under their expected keys.
1398
+ // The caller knows whether this is audio or image.
1399
+ ...modality_values,
1400
+ ...kwargs,
1401
+ });
1402
+ ({ inputs_embeds, attention_mask } = merge_function({
1403
+ [modality_output_name]: modality_features,
1404
+ inputs_embeds,
1405
+ input_ids,
1406
+ attention_mask,
1407
+ }));
1408
+ } else if (past_key_values && input_ids.dims[1] === 1) {
1409
+ // This branch handles the cache case.
1410
+ const target_length = input_ids.dims[1]; // always 1
1411
+ const past_length = past_key_values.get_seq_length();
1412
+
1413
+ attention_mask = cat(
1414
+ [
1415
+ ones([input_ids.dims[0], past_length]),
1416
+ attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
1417
+ ],
1418
+ 1,
1419
+ );
1420
+ }
1517
1421
  }
1518
1422
  }
1519
1423
 
@@ -1522,14 +1426,19 @@ export async function generic_text_to_text_forward(
1522
1426
  // Handle special case for qwen vl models
1523
1427
  [
1524
1428
  'qwen2_vl',
1429
+ 'qwen2_vl_text',
1525
1430
  'qwen2_5_vl',
1526
1431
  'qwen2_5_vl_text',
1527
1432
  'qwen3_vl',
1528
1433
  'qwen3_vl_text',
1434
+ 'qwen3_vl_moe',
1435
+ 'qwen3_vl_moe_text',
1529
1436
  'qwen3_5',
1530
1437
  'qwen3_5_text',
1531
1438
  'qwen3_5_moe',
1532
1439
  'qwen3_5_moe_text',
1440
+ 'glm_ocr',
1441
+ 'glm_ocr_text',
1533
1442
  ].includes(self.config.model_type)
1534
1443
  ) {
1535
1444
  // @ts-ignore
@@ -1564,7 +1473,7 @@ export async function generic_text_to_text_forward(
1564
1473
  export async function audio_text_to_text_forward(self, params) {
1565
1474
  return await generic_text_to_text_forward(self, {
1566
1475
  ...params,
1567
- modality_input_name: 'audio_values',
1476
+ modality_input_names: ['audio_values', 'input_features'],
1568
1477
  modality_output_name: 'audio_features',
1569
1478
  encode_function: self.encode_audio.bind(self),
1570
1479
  merge_function: self._merge_input_ids_with_audio_features.bind(self),
@@ -1581,7 +1490,7 @@ export async function audio_text_to_text_forward(self, params) {
1581
1490
  export async function image_text_to_text_forward(self, params) {
1582
1491
  return await generic_text_to_text_forward(self, {
1583
1492
  ...params,
1584
- modality_input_name: 'pixel_values',
1493
+ modality_input_names: ['pixel_values'],
1585
1494
  modality_output_name: 'image_features',
1586
1495
  encode_function: self.encode_image.bind(self),
1587
1496
  merge_function: self._merge_input_ids_with_image_features.bind(self),
@@ -1644,7 +1553,14 @@ export function create_position_ids(model_inputs, past_key_values = null, start_
1644
1553
  }
1645
1554
 
1646
1555
  export function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
1647
- const past_length = model_inputs.past_key_values ? getPastLength(model_inputs.past_key_values) : 0;
1556
+ const past_length = model_inputs.past_key_values ? model_inputs.past_key_values.get_seq_length() : 0;
1557
+
1558
+ // During generation, only the last token's logits are needed. Setting num_logits_to_keep=1
1559
+ // avoids computing logits for the entire sequence, significantly reducing memory usage.
1560
+ const session = self.sessions['decoder_model_merged'] ?? self.sessions['model'];
1561
+ if (session?.inputNames.includes('num_logits_to_keep') && !model_inputs.num_logits_to_keep) {
1562
+ model_inputs.num_logits_to_keep = new Tensor('int64', [1n], []);
1563
+ }
1648
1564
 
1649
1565
  if (!model_inputs.attention_mask) {
1650
1566
  // If the attention mask is not provided, we attempt to infer based on provided inputs