@huggingface/transformers 4.0.0-next.6 → 4.0.0-next.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +16 -2
- package/dist/ort-wasm-simd-threaded.jsep.mjs +24 -24
- package/dist/transformers.js +2255 -931
- package/dist/transformers.min.js +19 -19
- package/dist/transformers.node.cjs +2300 -934
- package/dist/transformers.node.min.cjs +20 -20
- package/dist/transformers.node.min.mjs +20 -20
- package/dist/transformers.node.mjs +2336 -1012
- package/dist/transformers.web.js +2327 -1003
- package/dist/transformers.web.min.js +17 -17
- package/package.json +4 -4
- package/src/cache_utils.js +62 -0
- package/src/configs.js +45 -24
- package/src/env.js +8 -1
- package/src/image_processors_utils.js +27 -17
- package/src/models/chatterbox/modeling_chatterbox.js +1 -1
- package/src/models/chmv2/image_processing_chmv2.js +3 -0
- package/src/models/chmv2/modeling_chmv2.js +4 -0
- package/src/models/deepseek_v3/modeling_deepseek_v3.js +5 -0
- package/src/models/detr/image_processing_detr.js +1 -1
- package/src/models/eurobert/modeling_eurobert.js +41 -0
- package/src/models/feature_extractors.js +2 -0
- package/src/models/gemma3n/modeling_gemma3n.js +2 -0
- package/src/models/glm46v/image_processing_glm46v.js +12 -0
- package/src/models/glm46v/processing_glm46v.js +5 -0
- package/src/models/glm_moe_dsa/modeling_glm_moe_dsa.js +5 -0
- package/src/models/glm_ocr/modeling_glm_ocr.js +78 -0
- package/src/models/granite_speech/feature_extraction_granite_speech.js +58 -0
- package/src/models/granite_speech/modeling_granite_speech.js +5 -0
- package/src/models/granite_speech/processing_granite_speech.js +62 -0
- package/src/models/grounding_dino/image_processing_grounding_dino.js +1 -1
- package/src/models/idefics3/modeling_idefics3.js +5 -32
- package/src/models/image_processors.js +3 -0
- package/src/models/lfm2_vl/image_processing_lfm2_vl.js +305 -0
- package/src/models/lfm2_vl/modeling_lfm2_vl.js +13 -0
- package/src/models/lfm2_vl/processing_lfm2_vl.js +77 -0
- package/src/models/lighton_ocr/modeling_lighton_ocr.js +3 -0
- package/src/models/llava/modeling_llava.js +1 -1
- package/src/models/mistral3/modeling_mistral3.js +2 -2
- package/src/models/mistral4/modeling_mistral4.js +5 -0
- package/src/models/modeling_utils.js +224 -308
- package/src/models/models.js +14 -1
- package/src/models/nemotron_h/modeling_nemotron_h.js +5 -0
- package/src/models/paligemma/modeling_paligemma.js +2 -25
- package/src/models/processors.js +4 -0
- package/src/models/qwen2_5_vl/modeling_qwen2_5_vl.js +5 -1
- package/src/models/qwen2_vl/image_processing_qwen2_vl.js +1 -41
- package/src/models/qwen2_vl/modeling_qwen2_vl.js +194 -143
- package/src/models/qwen2_vl/processing_qwen2_vl.js +5 -4
- package/src/models/qwen3_5/modeling_qwen3_5.js +1 -0
- package/src/models/qwen3_5_moe/modeling_qwen3_5_moe.js +2 -1
- package/src/models/qwen3_vl/modeling_qwen3_vl.js +2 -1
- package/src/models/qwen3_vl_moe/modeling_qwen3_vl_moe.js +2 -1
- package/src/models/registry.js +42 -0
- package/src/models/sam/image_processing_sam.js +1 -1
- package/src/models/session.js +17 -6
- package/src/models/smolvlm/modeling_smolvlm.js +7 -0
- package/src/models/solar_open/modeling_solar_open.js +5 -0
- package/src/models/ultravox/modeling_ultravox.js +1 -3
- package/src/models/voxtral/modeling_voxtral.js +3 -0
- package/src/models/voxtral_realtime/feature_extraction_voxtral_realtime.js +71 -0
- package/src/models/voxtral_realtime/modeling_voxtral_realtime.js +239 -0
- package/src/models/voxtral_realtime/processing_voxtral_realtime.js +113 -0
- package/src/models/whisper/feature_extraction_whisper.js +2 -12
- package/src/pipelines.js +1 -0
- package/src/transformers.js +2 -0
- package/src/utils/audio.js +18 -2
- package/src/utils/cache/CrossOriginStorageCache.js +251 -0
- package/src/utils/cache/cross-origin-storage.d.ts +38 -0
- package/src/utils/cache.js +5 -0
- package/src/utils/hub.js +4 -1
- package/src/utils/lru_cache.js +67 -0
- package/src/utils/memoize_promise.js +45 -0
- package/src/utils/model_registry/get_file_metadata.js +15 -2
- package/src/utils/model_registry/get_model_files.js +52 -78
- package/src/utils/tensor.js +18 -2
- package/types/cache_utils.d.ts +29 -0
- package/types/cache_utils.d.ts.map +1 -0
- package/types/configs.d.ts.map +1 -1
- package/types/env.d.ts +8 -0
- package/types/env.d.ts.map +1 -1
- package/types/image_processors_utils.d.ts +18 -1
- package/types/image_processors_utils.d.ts.map +1 -1
- package/types/models/{ast/modeling_ast.d.ts → audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.d.ts} +1 -1
- package/types/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.d.ts.map +1 -0
- package/types/models/chmv2/image_processing_chmv2.d.ts +4 -0
- package/types/models/chmv2/image_processing_chmv2.d.ts.map +1 -0
- package/types/models/chmv2/modeling_chmv2.d.ts +6 -0
- package/types/models/chmv2/modeling_chmv2.d.ts.map +1 -0
- package/types/models/deepseek_v3/modeling_deepseek_v3.d.ts +8 -0
- package/types/models/deepseek_v3/modeling_deepseek_v3.d.ts.map +1 -0
- package/types/models/detr/image_processing_detr.d.ts +1 -1
- package/types/models/eurobert/modeling_eurobert.d.ts +36 -0
- package/types/models/eurobert/modeling_eurobert.d.ts.map +1 -0
- package/types/models/feature_extractors.d.ts +2 -0
- package/types/models/gemma3n/modeling_gemma3n.d.ts +2 -0
- package/types/models/gemma3n/modeling_gemma3n.d.ts.map +1 -1
- package/types/models/glm46v/image_processing_glm46v.d.ts +4 -0
- package/types/models/glm46v/image_processing_glm46v.d.ts.map +1 -0
- package/types/models/glm46v/processing_glm46v.d.ts +4 -0
- package/types/models/glm46v/processing_glm46v.d.ts.map +1 -0
- package/types/models/glm_moe_dsa/modeling_glm_moe_dsa.d.ts +8 -0
- package/types/models/glm_moe_dsa/modeling_glm_moe_dsa.d.ts.map +1 -0
- package/types/models/glm_ocr/modeling_glm_ocr.d.ts +26 -0
- package/types/models/glm_ocr/modeling_glm_ocr.d.ts.map +1 -0
- package/types/models/granite_speech/feature_extraction_granite_speech.d.ts +16 -0
- package/types/models/granite_speech/feature_extraction_granite_speech.d.ts.map +1 -0
- package/types/models/granite_speech/modeling_granite_speech.d.ts +4 -0
- package/types/models/granite_speech/modeling_granite_speech.d.ts.map +1 -0
- package/types/models/granite_speech/processing_granite_speech.d.ts +19 -0
- package/types/models/granite_speech/processing_granite_speech.d.ts.map +1 -0
- package/types/models/grounding_dino/image_processing_grounding_dino.d.ts +1 -1
- package/types/models/idefics3/modeling_idefics3.d.ts +2 -18
- package/types/models/idefics3/modeling_idefics3.d.ts.map +1 -1
- package/types/models/image_processors.d.ts +3 -0
- package/types/models/lfm2_vl/image_processing_lfm2_vl.d.ts +41 -0
- package/types/models/lfm2_vl/image_processing_lfm2_vl.d.ts.map +1 -0
- package/types/models/lfm2_vl/modeling_lfm2_vl.d.ts +4 -0
- package/types/models/lfm2_vl/modeling_lfm2_vl.d.ts.map +1 -0
- package/types/models/lfm2_vl/processing_lfm2_vl.d.ts +18 -0
- package/types/models/lfm2_vl/processing_lfm2_vl.d.ts.map +1 -0
- package/types/models/lighton_ocr/modeling_lighton_ocr.d.ts +4 -0
- package/types/models/lighton_ocr/modeling_lighton_ocr.d.ts.map +1 -0
- package/types/models/mistral3/modeling_mistral3.d.ts +2 -2
- package/types/models/mistral3/modeling_mistral3.d.ts.map +1 -1
- package/types/models/mistral4/modeling_mistral4.d.ts +8 -0
- package/types/models/mistral4/modeling_mistral4.d.ts.map +1 -0
- package/types/models/modeling_utils.d.ts +44 -35
- package/types/models/modeling_utils.d.ts.map +1 -1
- package/types/models/models.d.ts +14 -1
- package/types/models/nemotron_h/modeling_nemotron_h.d.ts +8 -0
- package/types/models/nemotron_h/modeling_nemotron_h.d.ts.map +1 -0
- package/types/models/paligemma/modeling_paligemma.d.ts +2 -8
- package/types/models/paligemma/modeling_paligemma.d.ts.map +1 -1
- package/types/models/processors.d.ts +4 -0
- package/types/models/qwen2_5_vl/modeling_qwen2_5_vl.d.ts +3 -0
- package/types/models/qwen2_5_vl/modeling_qwen2_5_vl.d.ts.map +1 -1
- package/types/models/qwen2_vl/image_processing_qwen2_vl.d.ts.map +1 -1
- package/types/models/qwen2_vl/modeling_qwen2_vl.d.ts +43 -6
- package/types/models/qwen2_vl/modeling_qwen2_vl.d.ts.map +1 -1
- package/types/models/qwen2_vl/processing_qwen2_vl.d.ts +1 -0
- package/types/models/qwen2_vl/processing_qwen2_vl.d.ts.map +1 -1
- package/types/models/qwen3_5/modeling_qwen3_5.d.ts +2 -0
- package/types/models/qwen3_5/modeling_qwen3_5.d.ts.map +1 -1
- package/types/models/qwen3_5_moe/modeling_qwen3_5_moe.d.ts +3 -0
- package/types/models/qwen3_5_moe/modeling_qwen3_5_moe.d.ts.map +1 -1
- package/types/models/qwen3_vl/modeling_qwen3_vl.d.ts +3 -0
- package/types/models/qwen3_vl/modeling_qwen3_vl.d.ts.map +1 -1
- package/types/models/qwen3_vl_moe/modeling_qwen3_vl_moe.d.ts +3 -0
- package/types/models/qwen3_vl_moe/modeling_qwen3_vl_moe.d.ts.map +1 -1
- package/types/models/registry.d.ts.map +1 -1
- package/types/models/sam/image_processing_sam.d.ts +1 -1
- package/types/models/session.d.ts +3 -2
- package/types/models/session.d.ts.map +1 -1
- package/types/models/smolvlm/modeling_smolvlm.d.ts +8 -0
- package/types/models/smolvlm/modeling_smolvlm.d.ts.map +1 -0
- package/types/models/solar_open/modeling_solar_open.d.ts +8 -0
- package/types/models/solar_open/modeling_solar_open.d.ts.map +1 -0
- package/types/models/ultravox/modeling_ultravox.d.ts +0 -2
- package/types/models/ultravox/modeling_ultravox.d.ts.map +1 -1
- package/types/models/voxtral/modeling_voxtral.d.ts +4 -0
- package/types/models/voxtral/modeling_voxtral.d.ts.map +1 -0
- package/types/models/voxtral_realtime/feature_extraction_voxtral_realtime.d.ts +28 -0
- package/types/models/voxtral_realtime/feature_extraction_voxtral_realtime.d.ts.map +1 -0
- package/types/models/voxtral_realtime/modeling_voxtral_realtime.d.ts +17 -0
- package/types/models/voxtral_realtime/modeling_voxtral_realtime.d.ts.map +1 -0
- package/types/models/voxtral_realtime/processing_voxtral_realtime.d.ts +44 -0
- package/types/models/voxtral_realtime/processing_voxtral_realtime.d.ts.map +1 -0
- package/types/models/whisper/feature_extraction_whisper.d.ts.map +1 -1
- package/types/pipelines.d.ts +1 -0
- package/types/pipelines.d.ts.map +1 -1
- package/types/transformers.d.ts +1 -0
- package/types/transformers.d.ts.map +1 -1
- package/types/utils/audio.d.ts +5 -2
- package/types/utils/audio.d.ts.map +1 -1
- package/types/utils/cache/CrossOriginStorageCache.d.ts +120 -0
- package/types/utils/cache/CrossOriginStorageCache.d.ts.map +1 -0
- package/types/utils/cache.d.ts.map +1 -1
- package/types/utils/dtypes.d.ts +1 -1
- package/types/utils/hub.d.ts.map +1 -1
- package/types/utils/image.d.ts +1 -1
- package/types/utils/lru_cache.d.ts +38 -0
- package/types/utils/lru_cache.d.ts.map +1 -0
- package/types/utils/memoize_promise.d.ts +14 -0
- package/types/utils/memoize_promise.d.ts.map +1 -0
- package/types/utils/model_registry/get_file_metadata.d.ts.map +1 -1
- package/types/utils/model_registry/get_model_files.d.ts +1 -0
- package/types/utils/model_registry/get_model_files.d.ts.map +1 -1
- package/types/utils/tensor.d.ts.map +1 -1
- package/src/utils/data-structures.js +0 -572
- package/types/models/ast/modeling_ast.d.ts.map +0 -1
- package/types/utils/data-structures.d.ts +0 -294
- package/types/utils/data-structures.d.ts.map +0 -1
- /package/src/models/{ast/modeling_ast.js → audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.js} +0 -0
|
@@ -36,26 +36,7 @@ import { LogitsSampler } from '../generation/logits_sampler.js';
|
|
|
36
36
|
import { pick } from '../utils/core.js';
|
|
37
37
|
import { ModelOutput } from './modeling_outputs.js';
|
|
38
38
|
import { logger } from '../utils/logger.js';
|
|
39
|
-
|
|
40
|
-
/**
|
|
41
|
-
* Extract the past sequence length from a past_key_values object.
|
|
42
|
-
* For standard models, all entries are attention KV caches with shape [batch, heads, seq_len, head_dim].
|
|
43
|
-
* For hybrid models (e.g., Qwen3.5 with conv/recurrent + attention layers), the first entry
|
|
44
|
-
* may be a conv or recurrent state whose dims don't encode a sequence length.
|
|
45
|
-
* This function finds a `past_key_values.*` entry (standard attention cache) to determine the true past length.
|
|
46
|
-
*
|
|
47
|
-
* @param {Record<string, Tensor>} past_key_values
|
|
48
|
-
* @returns {number} The past sequence length.
|
|
49
|
-
*/
|
|
50
|
-
export function getPastLength(past_key_values) {
|
|
51
|
-
for (const name in past_key_values) {
|
|
52
|
-
if (name.startsWith('past_key_values.')) {
|
|
53
|
-
return past_key_values[name].dims.at(-2);
|
|
54
|
-
}
|
|
55
|
-
}
|
|
56
|
-
// Fallback for non-hybrid models (all entries are attention KV)
|
|
57
|
-
return Object.values(past_key_values)[0].dims.at(-2);
|
|
58
|
-
}
|
|
39
|
+
import { DynamicCache } from '../cache_utils.js';
|
|
59
40
|
|
|
60
41
|
/**
|
|
61
42
|
* Converts an array or Tensor of integers to an int64 Tensor.
|
|
@@ -118,6 +99,8 @@ export const MODEL_TYPES = {
|
|
|
118
99
|
ImageAudioTextToText: 13,
|
|
119
100
|
Supertonic: 14,
|
|
120
101
|
Chatterbox: 15,
|
|
102
|
+
MultimodalLanguageModelOnly: 16,
|
|
103
|
+
VoxtralRealtime: 17,
|
|
121
104
|
};
|
|
122
105
|
|
|
123
106
|
const MODEL_TYPE_CONFIG = {
|
|
@@ -125,65 +108,181 @@ const MODEL_TYPE_CONFIG = {
|
|
|
125
108
|
can_generate: true,
|
|
126
109
|
forward: decoder_forward,
|
|
127
110
|
prepare_inputs: decoder_prepare_inputs_for_generation,
|
|
111
|
+
sessions: (config, options) => ({ model: options.model_file_name ?? 'model' }),
|
|
112
|
+
cache_sessions: { model: true },
|
|
113
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
128
114
|
},
|
|
129
115
|
[MODEL_TYPES.DecoderOnlyWithoutHead]: {
|
|
130
116
|
can_generate: false,
|
|
131
117
|
forward: decoder_forward,
|
|
132
118
|
prepare_inputs: decoder_prepare_inputs_for_generation,
|
|
119
|
+
sessions: (config, options) => ({ model: options.model_file_name ?? 'model' }),
|
|
133
120
|
},
|
|
134
121
|
[MODEL_TYPES.Seq2Seq]: {
|
|
135
122
|
can_generate: true,
|
|
136
123
|
forward: seq2seq_forward,
|
|
137
124
|
prepare_inputs: encoder_decoder_prepare_inputs_for_generation,
|
|
125
|
+
sessions: () => ({ model: 'encoder_model', decoder_model_merged: 'decoder_model_merged' }),
|
|
126
|
+
cache_sessions: { decoder_model_merged: true },
|
|
127
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
138
128
|
},
|
|
139
129
|
[MODEL_TYPES.Vision2Seq]: {
|
|
140
130
|
can_generate: true,
|
|
141
131
|
forward: seq2seq_forward,
|
|
142
132
|
prepare_inputs: encoder_decoder_prepare_inputs_for_generation,
|
|
133
|
+
sessions: () => ({ model: 'encoder_model', decoder_model_merged: 'decoder_model_merged' }),
|
|
134
|
+
cache_sessions: { decoder_model_merged: true },
|
|
135
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
143
136
|
},
|
|
144
137
|
[MODEL_TYPES.Musicgen]: {
|
|
145
138
|
can_generate: true,
|
|
146
139
|
forward: seq2seq_forward,
|
|
140
|
+
sessions: () => ({
|
|
141
|
+
model: 'text_encoder',
|
|
142
|
+
decoder_model_merged: 'decoder_model_merged',
|
|
143
|
+
encodec_decode: 'encodec_decode',
|
|
144
|
+
}),
|
|
145
|
+
cache_sessions: { decoder_model_merged: true },
|
|
146
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
147
147
|
},
|
|
148
148
|
[MODEL_TYPES.EncoderDecoder]: {
|
|
149
149
|
can_generate: false,
|
|
150
150
|
forward: seq2seq_forward,
|
|
151
|
+
sessions: () => ({ model: 'encoder_model', decoder_model_merged: 'decoder_model_merged' }),
|
|
152
|
+
cache_sessions: { decoder_model_merged: true },
|
|
153
|
+
},
|
|
154
|
+
[MODEL_TYPES.MaskGeneration]: {
|
|
155
|
+
sessions: () => ({ model: 'vision_encoder', prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder' }),
|
|
151
156
|
},
|
|
152
157
|
[MODEL_TYPES.ImageTextToText]: {
|
|
153
158
|
can_generate: true,
|
|
154
159
|
forward: image_text_to_text_forward,
|
|
155
160
|
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
|
|
161
|
+
sessions: (config) => {
|
|
162
|
+
const s = {
|
|
163
|
+
embed_tokens: 'embed_tokens',
|
|
164
|
+
vision_encoder: 'vision_encoder',
|
|
165
|
+
decoder_model_merged: 'decoder_model_merged',
|
|
166
|
+
};
|
|
167
|
+
if (config.is_encoder_decoder) s['model'] = 'encoder_model';
|
|
168
|
+
return s;
|
|
169
|
+
},
|
|
170
|
+
cache_sessions: { decoder_model_merged: true },
|
|
171
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
156
172
|
},
|
|
157
173
|
[MODEL_TYPES.AudioTextToText]: {
|
|
158
174
|
can_generate: true,
|
|
159
175
|
forward: audio_text_to_text_forward,
|
|
160
176
|
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
|
|
177
|
+
sessions: () => ({
|
|
178
|
+
embed_tokens: 'embed_tokens',
|
|
179
|
+
audio_encoder: 'audio_encoder',
|
|
180
|
+
decoder_model_merged: 'decoder_model_merged',
|
|
181
|
+
}),
|
|
182
|
+
cache_sessions: { decoder_model_merged: true },
|
|
183
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
161
184
|
},
|
|
162
|
-
[MODEL_TYPES.
|
|
185
|
+
[MODEL_TYPES.ImageAudioTextToText]: {
|
|
163
186
|
can_generate: true,
|
|
164
187
|
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
|
|
188
|
+
sessions: () => ({
|
|
189
|
+
embed_tokens: 'embed_tokens',
|
|
190
|
+
audio_encoder: 'audio_encoder',
|
|
191
|
+
vision_encoder: 'vision_encoder',
|
|
192
|
+
decoder_model_merged: 'decoder_model_merged',
|
|
193
|
+
}),
|
|
194
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
165
195
|
},
|
|
166
|
-
[MODEL_TYPES.
|
|
196
|
+
[MODEL_TYPES.Phi3V]: {
|
|
167
197
|
can_generate: true,
|
|
168
198
|
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
|
|
199
|
+
sessions: () => ({
|
|
200
|
+
prepare_inputs_embeds: 'prepare_inputs_embeds',
|
|
201
|
+
model: 'model',
|
|
202
|
+
vision_encoder: 'vision_encoder',
|
|
203
|
+
}),
|
|
204
|
+
cache_sessions: { model: true },
|
|
205
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
169
206
|
},
|
|
170
207
|
[MODEL_TYPES.MultiModality]: {
|
|
171
208
|
can_generate: true,
|
|
209
|
+
sessions: () => ({
|
|
210
|
+
prepare_inputs_embeds: 'prepare_inputs_embeds',
|
|
211
|
+
model: 'language_model',
|
|
212
|
+
lm_head: 'lm_head',
|
|
213
|
+
gen_head: 'gen_head',
|
|
214
|
+
gen_img_embeds: 'gen_img_embeds',
|
|
215
|
+
image_decode: 'image_decode',
|
|
216
|
+
}),
|
|
217
|
+
cache_sessions: { model: true },
|
|
218
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
172
219
|
},
|
|
173
220
|
[MODEL_TYPES.AutoEncoder]: {
|
|
174
221
|
can_generate: false,
|
|
175
222
|
forward: auto_encoder_forward,
|
|
223
|
+
sessions: () => ({ encoder_model: 'encoder_model', decoder_model: 'decoder_model' }),
|
|
224
|
+
},
|
|
225
|
+
[MODEL_TYPES.Supertonic]: {
|
|
226
|
+
sessions: () => ({
|
|
227
|
+
text_encoder: 'text_encoder',
|
|
228
|
+
latent_denoiser: 'latent_denoiser',
|
|
229
|
+
voice_decoder: 'voice_decoder',
|
|
230
|
+
}),
|
|
176
231
|
},
|
|
177
232
|
[MODEL_TYPES.Chatterbox]: {
|
|
178
233
|
can_generate: true,
|
|
179
234
|
forward: encoder_forward,
|
|
235
|
+
sessions: () => ({
|
|
236
|
+
embed_tokens: 'embed_tokens',
|
|
237
|
+
speech_encoder: 'speech_encoder',
|
|
238
|
+
model: 'language_model',
|
|
239
|
+
conditional_decoder: 'conditional_decoder',
|
|
240
|
+
}),
|
|
241
|
+
cache_sessions: { model: true },
|
|
242
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
243
|
+
},
|
|
244
|
+
[MODEL_TYPES.MultimodalLanguageModelOnly]: {
|
|
245
|
+
can_generate: true,
|
|
246
|
+
forward: image_text_to_text_forward,
|
|
247
|
+
prepare_inputs: multimodal_text_to_text_prepare_inputs_for_generation,
|
|
248
|
+
sessions: () => ({ embed_tokens: 'embed_tokens', decoder_model_merged: 'decoder_model_merged' }),
|
|
249
|
+
cache_sessions: { decoder_model_merged: true },
|
|
250
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
251
|
+
},
|
|
252
|
+
[MODEL_TYPES.VoxtralRealtime]: {
|
|
253
|
+
can_generate: true,
|
|
254
|
+
prepare_inputs: decoder_prepare_inputs_for_generation,
|
|
255
|
+
sessions: () => ({
|
|
256
|
+
embed_tokens: 'embed_tokens',
|
|
257
|
+
audio_encoder: 'audio_encoder',
|
|
258
|
+
decoder_model_merged: 'decoder_model_merged',
|
|
259
|
+
}),
|
|
260
|
+
cache_sessions: { decoder_model_merged: true, audio_encoder: true },
|
|
261
|
+
optional_configs: { generation_config: 'generation_config.json' },
|
|
180
262
|
},
|
|
181
263
|
default: {
|
|
182
264
|
can_generate: false,
|
|
183
265
|
forward: encoder_forward,
|
|
266
|
+
sessions: (config, options) => ({ model: options.model_file_name ?? 'model' }),
|
|
184
267
|
},
|
|
185
268
|
};
|
|
186
269
|
|
|
270
|
+
/**
|
|
271
|
+
* Get the session configuration for a given model type.
|
|
272
|
+
* @param {number} modelType The model type enum value.
|
|
273
|
+
* @param {Object} config The model config.
|
|
274
|
+
* @param {Object} [options] Loading options.
|
|
275
|
+
* @returns {{ sessions: Record<string, string>, cache_sessions?: Record<string, true>, optional_configs?: Record<string, string> }}
|
|
276
|
+
*/
|
|
277
|
+
export function getSessionsConfig(modelType, config, options = {}) {
|
|
278
|
+
const typeConfig = MODEL_TYPE_CONFIG[modelType] ?? MODEL_TYPE_CONFIG.default;
|
|
279
|
+
return {
|
|
280
|
+
sessions: typeConfig.sessions(config, options),
|
|
281
|
+
cache_sessions: typeConfig.cache_sessions,
|
|
282
|
+
optional_configs: typeConfig.optional_configs,
|
|
283
|
+
};
|
|
284
|
+
}
|
|
285
|
+
|
|
187
286
|
export const MODEL_TYPE_MAPPING = new Map();
|
|
188
287
|
export const MODEL_NAME_TO_CLASS_MAPPING = new Map();
|
|
189
288
|
export const MODEL_CLASS_TO_NAME_MAPPING = new Map();
|
|
@@ -290,247 +389,26 @@ export class PreTrainedModel extends Callable {
|
|
|
290
389
|
|
|
291
390
|
config = options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);
|
|
292
391
|
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
options,
|
|
302
|
-
'model',
|
|
303
|
-
),
|
|
304
|
-
get_optional_configs(
|
|
305
|
-
pretrained_model_name_or_path,
|
|
306
|
-
{
|
|
307
|
-
generation_config: 'generation_config.json',
|
|
308
|
-
},
|
|
309
|
-
options,
|
|
310
|
-
),
|
|
311
|
-
]);
|
|
312
|
-
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
|
|
313
|
-
info = await Promise.all([
|
|
314
|
-
constructSessions(
|
|
315
|
-
pretrained_model_name_or_path,
|
|
316
|
-
{
|
|
317
|
-
model: 'encoder_model',
|
|
318
|
-
decoder_model_merged: 'decoder_model_merged',
|
|
319
|
-
},
|
|
320
|
-
options,
|
|
321
|
-
'decoder_model_merged',
|
|
322
|
-
),
|
|
323
|
-
get_optional_configs(
|
|
324
|
-
pretrained_model_name_or_path,
|
|
325
|
-
{
|
|
326
|
-
generation_config: 'generation_config.json',
|
|
327
|
-
},
|
|
328
|
-
options,
|
|
329
|
-
),
|
|
330
|
-
]);
|
|
331
|
-
} else if (modelType === MODEL_TYPES.MaskGeneration) {
|
|
332
|
-
info = await Promise.all([
|
|
333
|
-
constructSessions(
|
|
334
|
-
pretrained_model_name_or_path,
|
|
335
|
-
{
|
|
336
|
-
model: 'vision_encoder',
|
|
337
|
-
prompt_encoder_mask_decoder: 'prompt_encoder_mask_decoder',
|
|
338
|
-
},
|
|
339
|
-
options,
|
|
340
|
-
),
|
|
341
|
-
]);
|
|
342
|
-
} else if (modelType === MODEL_TYPES.EncoderDecoder) {
|
|
343
|
-
info = await Promise.all([
|
|
344
|
-
constructSessions(
|
|
345
|
-
pretrained_model_name_or_path,
|
|
346
|
-
{
|
|
347
|
-
model: 'encoder_model',
|
|
348
|
-
decoder_model_merged: 'decoder_model_merged',
|
|
349
|
-
},
|
|
350
|
-
options,
|
|
351
|
-
'decoder_model_merged',
|
|
352
|
-
),
|
|
353
|
-
]);
|
|
354
|
-
} else if (modelType === MODEL_TYPES.ImageTextToText) {
|
|
355
|
-
const sessions = {
|
|
356
|
-
embed_tokens: 'embed_tokens',
|
|
357
|
-
vision_encoder: 'vision_encoder',
|
|
358
|
-
decoder_model_merged: 'decoder_model_merged',
|
|
359
|
-
};
|
|
360
|
-
if (config.is_encoder_decoder) {
|
|
361
|
-
sessions['model'] = 'encoder_model';
|
|
362
|
-
}
|
|
363
|
-
info = await Promise.all([
|
|
364
|
-
constructSessions(pretrained_model_name_or_path, sessions, options, 'decoder_model_merged'),
|
|
365
|
-
get_optional_configs(
|
|
366
|
-
pretrained_model_name_or_path,
|
|
367
|
-
{
|
|
368
|
-
generation_config: 'generation_config.json',
|
|
369
|
-
},
|
|
370
|
-
options,
|
|
371
|
-
),
|
|
372
|
-
]);
|
|
373
|
-
} else if (modelType === MODEL_TYPES.AudioTextToText) {
|
|
374
|
-
const sessions = {
|
|
375
|
-
embed_tokens: 'embed_tokens',
|
|
376
|
-
audio_encoder: 'audio_encoder',
|
|
377
|
-
decoder_model_merged: 'decoder_model_merged',
|
|
378
|
-
};
|
|
379
|
-
info = await Promise.all([
|
|
380
|
-
constructSessions(pretrained_model_name_or_path, sessions, options, 'decoder_model_merged'),
|
|
381
|
-
get_optional_configs(
|
|
382
|
-
pretrained_model_name_or_path,
|
|
383
|
-
{
|
|
384
|
-
generation_config: 'generation_config.json',
|
|
385
|
-
},
|
|
386
|
-
options,
|
|
387
|
-
),
|
|
388
|
-
]);
|
|
389
|
-
} else if (modelType === MODEL_TYPES.ImageAudioTextToText) {
|
|
390
|
-
const sessions = {
|
|
391
|
-
embed_tokens: 'embed_tokens',
|
|
392
|
-
audio_encoder: 'audio_encoder',
|
|
393
|
-
vision_encoder: 'vision_encoder',
|
|
394
|
-
decoder_model_merged: 'decoder_model_merged',
|
|
395
|
-
};
|
|
396
|
-
info = await Promise.all([
|
|
397
|
-
constructSessions(pretrained_model_name_or_path, sessions, options),
|
|
398
|
-
get_optional_configs(
|
|
399
|
-
pretrained_model_name_or_path,
|
|
400
|
-
{
|
|
401
|
-
generation_config: 'generation_config.json',
|
|
402
|
-
},
|
|
403
|
-
options,
|
|
404
|
-
),
|
|
405
|
-
]);
|
|
406
|
-
} else if (modelType === MODEL_TYPES.Musicgen) {
|
|
407
|
-
info = await Promise.all([
|
|
408
|
-
constructSessions(
|
|
409
|
-
pretrained_model_name_or_path,
|
|
410
|
-
{
|
|
411
|
-
model: 'text_encoder',
|
|
412
|
-
decoder_model_merged: 'decoder_model_merged',
|
|
413
|
-
encodec_decode: 'encodec_decode',
|
|
414
|
-
},
|
|
415
|
-
options,
|
|
416
|
-
'decoder_model_merged',
|
|
417
|
-
),
|
|
418
|
-
get_optional_configs(
|
|
419
|
-
pretrained_model_name_or_path,
|
|
420
|
-
{
|
|
421
|
-
generation_config: 'generation_config.json',
|
|
422
|
-
},
|
|
423
|
-
options,
|
|
424
|
-
),
|
|
425
|
-
]);
|
|
426
|
-
} else if (modelType === MODEL_TYPES.MultiModality) {
|
|
427
|
-
info = await Promise.all([
|
|
428
|
-
constructSessions(
|
|
429
|
-
pretrained_model_name_or_path,
|
|
430
|
-
{
|
|
431
|
-
prepare_inputs_embeds: 'prepare_inputs_embeds',
|
|
432
|
-
model: 'language_model',
|
|
433
|
-
lm_head: 'lm_head',
|
|
434
|
-
gen_head: 'gen_head',
|
|
435
|
-
gen_img_embeds: 'gen_img_embeds',
|
|
436
|
-
image_decode: 'image_decode',
|
|
437
|
-
},
|
|
438
|
-
options,
|
|
439
|
-
'model',
|
|
440
|
-
),
|
|
441
|
-
get_optional_configs(
|
|
442
|
-
pretrained_model_name_or_path,
|
|
443
|
-
{
|
|
444
|
-
generation_config: 'generation_config.json',
|
|
445
|
-
},
|
|
446
|
-
options,
|
|
447
|
-
),
|
|
448
|
-
]);
|
|
449
|
-
} else if (modelType === MODEL_TYPES.Phi3V) {
|
|
450
|
-
info = await Promise.all([
|
|
451
|
-
constructSessions(
|
|
452
|
-
pretrained_model_name_or_path,
|
|
453
|
-
{
|
|
454
|
-
prepare_inputs_embeds: 'prepare_inputs_embeds',
|
|
455
|
-
model: 'model',
|
|
456
|
-
vision_encoder: 'vision_encoder',
|
|
457
|
-
},
|
|
458
|
-
options,
|
|
459
|
-
'model',
|
|
460
|
-
),
|
|
461
|
-
get_optional_configs(
|
|
462
|
-
pretrained_model_name_or_path,
|
|
463
|
-
{
|
|
464
|
-
generation_config: 'generation_config.json',
|
|
465
|
-
},
|
|
466
|
-
options,
|
|
467
|
-
),
|
|
468
|
-
]);
|
|
469
|
-
} else if (modelType === MODEL_TYPES.Chatterbox) {
|
|
470
|
-
info = await Promise.all([
|
|
471
|
-
constructSessions(
|
|
472
|
-
pretrained_model_name_or_path,
|
|
473
|
-
{
|
|
474
|
-
embed_tokens: 'embed_tokens',
|
|
475
|
-
speech_encoder: 'speech_encoder',
|
|
476
|
-
model: 'language_model',
|
|
477
|
-
conditional_decoder: 'conditional_decoder',
|
|
478
|
-
},
|
|
479
|
-
options,
|
|
480
|
-
'model',
|
|
481
|
-
),
|
|
482
|
-
get_optional_configs(
|
|
483
|
-
pretrained_model_name_or_path,
|
|
484
|
-
{
|
|
485
|
-
generation_config: 'generation_config.json',
|
|
486
|
-
},
|
|
487
|
-
options,
|
|
488
|
-
),
|
|
489
|
-
]);
|
|
490
|
-
} else if (modelType === MODEL_TYPES.AutoEncoder) {
|
|
491
|
-
info = await Promise.all([
|
|
492
|
-
constructSessions(
|
|
493
|
-
pretrained_model_name_or_path,
|
|
494
|
-
{
|
|
495
|
-
encoder_model: 'encoder_model',
|
|
496
|
-
decoder_model: 'decoder_model',
|
|
497
|
-
},
|
|
498
|
-
options,
|
|
499
|
-
),
|
|
500
|
-
]);
|
|
501
|
-
} else if (modelType === MODEL_TYPES.Supertonic) {
|
|
502
|
-
info = await Promise.all([
|
|
503
|
-
constructSessions(
|
|
504
|
-
pretrained_model_name_or_path,
|
|
505
|
-
{
|
|
506
|
-
text_encoder: 'text_encoder',
|
|
507
|
-
latent_denoiser: 'latent_denoiser',
|
|
508
|
-
voice_decoder: 'voice_decoder',
|
|
509
|
-
},
|
|
510
|
-
options,
|
|
511
|
-
),
|
|
512
|
-
]);
|
|
513
|
-
} else {
|
|
514
|
-
// should be MODEL_TYPES.EncoderOnly or MODEL_TYPES.DecoderOnlyWithoutHead
|
|
515
|
-
if (modelType === undefined) {
|
|
516
|
-
const type = modelName ?? config?.model_type;
|
|
517
|
-
if (type !== 'custom') {
|
|
518
|
-
logger.warn(
|
|
519
|
-
`Model type for '${type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`,
|
|
520
|
-
);
|
|
521
|
-
}
|
|
392
|
+
const typeConfig = MODEL_TYPE_CONFIG[modelType] ?? MODEL_TYPE_CONFIG.default;
|
|
393
|
+
|
|
394
|
+
if (modelType === undefined) {
|
|
395
|
+
const type = modelName ?? config?.model_type;
|
|
396
|
+
if (type !== 'custom') {
|
|
397
|
+
logger.warn(
|
|
398
|
+
`Model type for '${type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`,
|
|
399
|
+
);
|
|
522
400
|
}
|
|
523
|
-
info = await Promise.all([
|
|
524
|
-
constructSessions(
|
|
525
|
-
pretrained_model_name_or_path,
|
|
526
|
-
{
|
|
527
|
-
model: options.model_file_name ?? 'model',
|
|
528
|
-
},
|
|
529
|
-
options,
|
|
530
|
-
),
|
|
531
|
-
]);
|
|
532
401
|
}
|
|
533
402
|
|
|
403
|
+
const sessions = typeConfig.sessions(config, options);
|
|
404
|
+
const promises = [
|
|
405
|
+
constructSessions(pretrained_model_name_or_path, sessions, options, typeConfig.cache_sessions),
|
|
406
|
+
];
|
|
407
|
+
if (typeConfig.optional_configs) {
|
|
408
|
+
promises.push(get_optional_configs(pretrained_model_name_or_path, typeConfig.optional_configs, options));
|
|
409
|
+
}
|
|
410
|
+
const info = await Promise.all(promises);
|
|
411
|
+
|
|
534
412
|
// @ts-ignore
|
|
535
413
|
return new this(config, ...info);
|
|
536
414
|
}
|
|
@@ -868,7 +746,7 @@ export class PreTrainedModel extends Callable {
|
|
|
868
746
|
* @param {Tensor} [params.inputs=null]
|
|
869
747
|
* @param {number} [params.bos_token_id=null]
|
|
870
748
|
* @param {Record<string, Tensor|number[]>} [params.model_kwargs]
|
|
871
|
-
* @returns {{inputs_tensor: Tensor, model_inputs: Record<string, Tensor
|
|
749
|
+
* @returns {{inputs_tensor: Tensor, model_inputs: Record<string, Tensor> & {past_key_values?: DynamicCache}, model_input_name: string}} The model-specific inputs for generation.
|
|
872
750
|
*/
|
|
873
751
|
_prepare_model_inputs({ inputs, bos_token_id, model_kwargs }) {
|
|
874
752
|
const model_inputs = pick(model_kwargs, this.forward_params);
|
|
@@ -1232,13 +1110,15 @@ export class PreTrainedModel extends Callable {
|
|
|
1232
1110
|
}
|
|
1233
1111
|
|
|
1234
1112
|
/**
|
|
1235
|
-
* Returns
|
|
1113
|
+
* Returns a DynamicCache containing past key values from the given decoder results object.
|
|
1236
1114
|
*
|
|
1237
1115
|
* @param {Object} decoderResults The decoder results object.
|
|
1238
|
-
* @param {
|
|
1239
|
-
* @
|
|
1116
|
+
* @param {DynamicCache} pastKeyValues The previous past key values.
|
|
1117
|
+
* @param {boolean} [disposeEncoderPKVs=false] Whether to dispose encoder past key values.
|
|
1118
|
+
* @returns {DynamicCache} A new DynamicCache containing the updated past key values.
|
|
1240
1119
|
*/
|
|
1241
1120
|
getPastKeyValues(decoderResults, pastKeyValues, disposeEncoderPKVs = false) {
|
|
1121
|
+
/** @type {Record<string, Tensor>} */
|
|
1242
1122
|
const pkvs = Object.create(null);
|
|
1243
1123
|
|
|
1244
1124
|
for (const name in decoderResults) {
|
|
@@ -1272,7 +1152,7 @@ export class PreTrainedModel extends Callable {
|
|
|
1272
1152
|
}
|
|
1273
1153
|
}
|
|
1274
1154
|
}
|
|
1275
|
-
return pkvs;
|
|
1155
|
+
return new DynamicCache(pkvs);
|
|
1276
1156
|
}
|
|
1277
1157
|
|
|
1278
1158
|
/**
|
|
@@ -1300,8 +1180,8 @@ export class PreTrainedModel extends Callable {
|
|
|
1300
1180
|
/**
|
|
1301
1181
|
* Adds past key values to the decoder feeds object. If pastKeyValues is null, creates new tensors for past key values.
|
|
1302
1182
|
*
|
|
1303
|
-
* @param {
|
|
1304
|
-
* @param {
|
|
1183
|
+
* @param {Record<string, any>} decoderFeeds The decoder feeds object to add past key values to.
|
|
1184
|
+
* @param {DynamicCache|null} pastKeyValues The cache containing past key values.
|
|
1305
1185
|
*/
|
|
1306
1186
|
addPastKeyValues(decoderFeeds, pastKeyValues) {
|
|
1307
1187
|
if (pastKeyValues) {
|
|
@@ -1320,19 +1200,32 @@ export class PreTrainedModel extends Callable {
|
|
|
1320
1200
|
}
|
|
1321
1201
|
}
|
|
1322
1202
|
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1203
|
+
/**
|
|
1204
|
+
* Helper function to select valid inputs and run through the appropriate encoder (vision, text, audio) based on the input type.
|
|
1205
|
+
* @param {string} sessionName
|
|
1206
|
+
* @param {Record<string, Tensor>} inputs
|
|
1207
|
+
* @param {string} outputName
|
|
1208
|
+
* @private
|
|
1209
|
+
*/
|
|
1210
|
+
async _encode_input(sessionName, inputs, outputName) {
|
|
1211
|
+
if (!Object.hasOwn(this.sessions, sessionName)) {
|
|
1212
|
+
throw new Error(`Model does not have a ${sessionName} session.`);
|
|
1213
|
+
}
|
|
1214
|
+
const session = this.sessions[sessionName];
|
|
1215
|
+
const output = await sessionRun(session, pick(inputs, session.inputNames));
|
|
1216
|
+
return output[outputName];
|
|
1217
|
+
}
|
|
1218
|
+
|
|
1219
|
+
async encode_image(inputs) {
|
|
1220
|
+
return this._encode_input('vision_encoder', inputs, 'image_features');
|
|
1326
1221
|
}
|
|
1327
1222
|
|
|
1328
|
-
async encode_text(
|
|
1329
|
-
|
|
1330
|
-
return (await sessionRun(this.sessions['embed_tokens'], { input_ids })).inputs_embeds;
|
|
1223
|
+
async encode_text(inputs) {
|
|
1224
|
+
return this._encode_input('embed_tokens', inputs, 'inputs_embeds');
|
|
1331
1225
|
}
|
|
1332
1226
|
|
|
1333
|
-
async encode_audio(
|
|
1334
|
-
|
|
1335
|
-
return (await sessionRun(this.sessions['audio_encoder'], { audio_values })).audio_features;
|
|
1227
|
+
async encode_audio(inputs) {
|
|
1228
|
+
return this._encode_input('audio_encoder', inputs, 'audio_features');
|
|
1336
1229
|
}
|
|
1337
1230
|
}
|
|
1338
1231
|
|
|
@@ -1431,6 +1324,15 @@ export async function decoder_forward(self, model_inputs, is_encoder_decoder = f
|
|
|
1431
1324
|
new_model_inputs.position_ids = create_position_ids(new_model_inputs, past_key_values, start_index);
|
|
1432
1325
|
}
|
|
1433
1326
|
|
|
1327
|
+
if (session.inputNames.includes('num_logits_to_keep') && !new_model_inputs.num_logits_to_keep) {
|
|
1328
|
+
// `num_logits_to_keep` specifies the number of prompt logits to calculate during generation.
|
|
1329
|
+
// If unset (or 0), all logits will be calculated. If an integer value, only last `num_logits_to_keep`
|
|
1330
|
+
// logits will be calculated. During generation, the default is 1 because only the logits of the last
|
|
1331
|
+
// prompt token are needed for generation. For long sequences, the logits for the entire sequence may
|
|
1332
|
+
// use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint significantly.
|
|
1333
|
+
new_model_inputs.num_logits_to_keep = new Tensor('int64', [0n], []);
|
|
1334
|
+
}
|
|
1335
|
+
|
|
1434
1336
|
// Unpack the `past_key_values` object into model inputs
|
|
1435
1337
|
self.addPastKeyValues(new_model_inputs, past_key_values);
|
|
1436
1338
|
|
|
@@ -1445,13 +1347,13 @@ export async function decoder_forward(self, model_inputs, is_encoder_decoder = f
|
|
|
1445
1347
|
* @param {Object} params Additional parameters.
|
|
1446
1348
|
* @param {Function} [params.encode_function] The function to encode the modality values.
|
|
1447
1349
|
* @param {Function} [params.merge_function] The function to merge the modality features with the input embeddings.
|
|
1448
|
-
* @param {string} [params.
|
|
1350
|
+
* @param {string[]} [params.modality_input_names] The modality input name.
|
|
1449
1351
|
* @param {string} [params.modality_output_name] The modality output name.
|
|
1450
1352
|
* @param {Tensor} [params.input_ids=null]
|
|
1451
1353
|
* @param {Tensor} [params.attention_mask=null]
|
|
1452
1354
|
* @param {Tensor} [params.position_ids=null]
|
|
1453
1355
|
* @param {Tensor} [params.inputs_embeds=null]
|
|
1454
|
-
* @param {
|
|
1356
|
+
* @param {DynamicCache} [params.past_key_values=null]
|
|
1455
1357
|
* @param {Object} [params.generation_config=null]
|
|
1456
1358
|
* @param {Object} [params.logits_processor=null]
|
|
1457
1359
|
* @returns {Promise<Tensor>} The model's output tensor
|
|
@@ -1463,7 +1365,7 @@ export async function generic_text_to_text_forward(
|
|
|
1463
1365
|
// Generic parameters:
|
|
1464
1366
|
encode_function,
|
|
1465
1367
|
merge_function,
|
|
1466
|
-
|
|
1368
|
+
modality_input_names,
|
|
1467
1369
|
modality_output_name,
|
|
1468
1370
|
|
|
1469
1371
|
// Produced by the tokenizer/processor:
|
|
@@ -1483,37 +1385,39 @@ export async function generic_text_to_text_forward(
|
|
|
1483
1385
|
...kwargs
|
|
1484
1386
|
},
|
|
1485
1387
|
) {
|
|
1486
|
-
const modality_values = kwargs[modality_input_name];
|
|
1487
1388
|
if (!inputs_embeds) {
|
|
1488
1389
|
// 1. Extract the text embeddings.
|
|
1489
1390
|
inputs_embeds = await self.encode_text({ input_ids, ...kwargs });
|
|
1490
1391
|
|
|
1491
1392
|
// 2. Possibly, merge text and modality values
|
|
1492
|
-
|
|
1493
|
-
|
|
1494
|
-
|
|
1495
|
-
|
|
1496
|
-
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
inputs_embeds,
|
|
1502
|
-
|
|
1503
|
-
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1515
|
-
|
|
1516
|
-
|
|
1393
|
+
const modality_values = pick(kwargs, modality_input_names);
|
|
1394
|
+
if (Object.keys(modality_values).length > 0) {
|
|
1395
|
+
if (input_ids.dims[1] !== 1) {
|
|
1396
|
+
const modality_features = await encode_function({
|
|
1397
|
+
// Pass the modality values under its expected key.
|
|
1398
|
+
// The caller knows whether this is audio or image.
|
|
1399
|
+
...modality_values,
|
|
1400
|
+
...kwargs,
|
|
1401
|
+
});
|
|
1402
|
+
({ inputs_embeds, attention_mask } = merge_function({
|
|
1403
|
+
[modality_output_name]: modality_features,
|
|
1404
|
+
inputs_embeds,
|
|
1405
|
+
input_ids,
|
|
1406
|
+
attention_mask,
|
|
1407
|
+
}));
|
|
1408
|
+
} else if (past_key_values && input_ids.dims[1] === 1) {
|
|
1409
|
+
// This branch handles the cache case.
|
|
1410
|
+
const target_length = input_ids.dims[1]; // always 1
|
|
1411
|
+
const past_length = past_key_values.get_seq_length();
|
|
1412
|
+
|
|
1413
|
+
attention_mask = cat(
|
|
1414
|
+
[
|
|
1415
|
+
ones([input_ids.dims[0], past_length]),
|
|
1416
|
+
attention_mask.slice(null, [attention_mask.dims[1] - target_length, attention_mask.dims[1]]),
|
|
1417
|
+
],
|
|
1418
|
+
1,
|
|
1419
|
+
);
|
|
1420
|
+
}
|
|
1517
1421
|
}
|
|
1518
1422
|
}
|
|
1519
1423
|
|
|
@@ -1522,14 +1426,19 @@ export async function generic_text_to_text_forward(
|
|
|
1522
1426
|
// Handle special case for qwen vl models
|
|
1523
1427
|
[
|
|
1524
1428
|
'qwen2_vl',
|
|
1429
|
+
'qwen2_vl_text',
|
|
1525
1430
|
'qwen2_5_vl',
|
|
1526
1431
|
'qwen2_5_vl_text',
|
|
1527
1432
|
'qwen3_vl',
|
|
1528
1433
|
'qwen3_vl_text',
|
|
1434
|
+
'qwen3_vl_moe',
|
|
1435
|
+
'qwen3_vl_moe_text',
|
|
1529
1436
|
'qwen3_5',
|
|
1530
1437
|
'qwen3_5_text',
|
|
1531
1438
|
'qwen3_5_moe',
|
|
1532
1439
|
'qwen3_5_moe_text',
|
|
1440
|
+
'glm_ocr',
|
|
1441
|
+
'glm_ocr_text',
|
|
1533
1442
|
].includes(self.config.model_type)
|
|
1534
1443
|
) {
|
|
1535
1444
|
// @ts-ignore
|
|
@@ -1564,7 +1473,7 @@ export async function generic_text_to_text_forward(
|
|
|
1564
1473
|
export async function audio_text_to_text_forward(self, params) {
|
|
1565
1474
|
return await generic_text_to_text_forward(self, {
|
|
1566
1475
|
...params,
|
|
1567
|
-
|
|
1476
|
+
modality_input_names: ['audio_values', 'input_features'],
|
|
1568
1477
|
modality_output_name: 'audio_features',
|
|
1569
1478
|
encode_function: self.encode_audio.bind(self),
|
|
1570
1479
|
merge_function: self._merge_input_ids_with_audio_features.bind(self),
|
|
@@ -1581,7 +1490,7 @@ export async function audio_text_to_text_forward(self, params) {
|
|
|
1581
1490
|
export async function image_text_to_text_forward(self, params) {
|
|
1582
1491
|
return await generic_text_to_text_forward(self, {
|
|
1583
1492
|
...params,
|
|
1584
|
-
|
|
1493
|
+
modality_input_names: ['pixel_values'],
|
|
1585
1494
|
modality_output_name: 'image_features',
|
|
1586
1495
|
encode_function: self.encode_image.bind(self),
|
|
1587
1496
|
merge_function: self._merge_input_ids_with_image_features.bind(self),
|
|
@@ -1644,7 +1553,14 @@ export function create_position_ids(model_inputs, past_key_values = null, start_
|
|
|
1644
1553
|
}
|
|
1645
1554
|
|
|
1646
1555
|
export function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
|
|
1647
|
-
const past_length = model_inputs.past_key_values ?
|
|
1556
|
+
const past_length = model_inputs.past_key_values ? model_inputs.past_key_values.get_seq_length() : 0;
|
|
1557
|
+
|
|
1558
|
+
// During generation, only the last token's logits are needed. Setting num_logits_to_keep=1
|
|
1559
|
+
// avoids computing logits for the entire sequence, significantly reducing memory usage.
|
|
1560
|
+
const session = self.sessions['decoder_model_merged'] ?? self.sessions['model'];
|
|
1561
|
+
if (session?.inputNames.includes('num_logits_to_keep') && !model_inputs.num_logits_to_keep) {
|
|
1562
|
+
model_inputs.num_logits_to_keep = new Tensor('int64', [1n], []);
|
|
1563
|
+
}
|
|
1648
1564
|
|
|
1649
1565
|
if (!model_inputs.attention_mask) {
|
|
1650
1566
|
// If the attention mask is not provided, we attempt to infer based on provided inputs
|