@huggingface/transformers 3.3.3 → 3.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +13 -3
- package/dist/ort-wasm-simd-threaded.jsep.mjs +124 -115
- package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
- package/dist/transformers.js +2778 -1592
- package/dist/transformers.js.map +1 -1
- package/dist/transformers.min.js +1 -1
- package/dist/transformers.min.js.map +1 -1
- package/dist/{transformers.cjs → transformers.node.cjs} +1699 -2530
- package/dist/transformers.node.cjs.map +1 -0
- package/dist/transformers.node.min.cjs +2 -0
- package/dist/transformers.node.min.cjs.map +1 -0
- package/dist/transformers.node.min.mjs +2 -0
- package/dist/transformers.node.min.mjs.map +1 -0
- package/dist/{transformers.mjs → transformers.node.mjs} +1738 -2510
- package/dist/transformers.node.mjs.map +1 -0
- package/dist/transformers.web.js +35876 -0
- package/dist/transformers.web.js.map +1 -0
- package/dist/transformers.web.min.js +2 -0
- package/dist/transformers.web.min.js.map +1 -0
- package/package.json +6 -6
- package/src/backends/onnx.js +14 -15
- package/src/configs.js +6 -1
- package/src/env.js +1 -1
- package/src/generation/streamers.js +4 -3
- package/src/models/dac/feature_extraction_dac.js +3 -0
- package/src/models/encodec/feature_extraction_encodec.js +32 -0
- package/src/models/feature_extractors.js +3 -0
- package/src/models/idefics3/image_processing_idefics3.js +1 -1
- package/src/models/image_processors.js +1 -0
- package/src/models/processors.js +2 -0
- package/src/models/smolvlm/image_processing_smolvlm.js +2 -0
- package/src/models/smolvlm/processing_smolvlm.js +2 -0
- package/src/models/snac/feature_extraction_snac.js +3 -0
- package/src/models/ultravox/processing_ultravox.js +54 -0
- package/src/models/whisper/common_whisper.js +7 -1
- package/src/models/whisper/feature_extraction_whisper.js +18 -10
- package/src/models.js +546 -78
- package/src/pipelines.js +246 -137
- package/src/tokenizers.js +42 -28
- package/src/transformers.js +1 -0
- package/src/utils/audio.js +2 -0
- package/src/utils/hub.js +140 -80
- package/src/utils/image.js +9 -1
- package/src/utils/maths.js +1 -1
- package/src/utils/tensor.js +12 -5
- package/src/utils/video.js +128 -0
- package/types/backends/onnx.d.ts +2 -2
- package/types/backends/onnx.d.ts.map +1 -1
- package/types/configs.d.ts +1 -1
- package/types/configs.d.ts.map +1 -1
- package/types/generation/streamers.d.ts.map +1 -1
- package/types/models/dac/feature_extraction_dac.d.ts +4 -0
- package/types/models/dac/feature_extraction_dac.d.ts.map +1 -0
- package/types/models/encodec/feature_extraction_encodec.d.ts +13 -0
- package/types/models/encodec/feature_extraction_encodec.d.ts.map +1 -0
- package/types/models/feature_extractors.d.ts +3 -0
- package/types/models/florence2/processing_florence2.d.ts +1 -1
- package/types/models/florence2/processing_florence2.d.ts.map +1 -1
- package/types/models/image_processors.d.ts +1 -0
- package/types/models/processors.d.ts +2 -0
- package/types/models/smolvlm/image_processing_smolvlm.d.ts +2 -0
- package/types/models/smolvlm/image_processing_smolvlm.d.ts.map +1 -0
- package/types/models/smolvlm/processing_smolvlm.d.ts +2 -0
- package/types/models/smolvlm/processing_smolvlm.d.ts.map +1 -0
- package/types/models/snac/feature_extraction_snac.d.ts +4 -0
- package/types/models/snac/feature_extraction_snac.d.ts.map +1 -0
- package/types/models/ultravox/processing_ultravox.d.ts +16 -0
- package/types/models/ultravox/processing_ultravox.d.ts.map +1 -0
- package/types/models/whisper/common_whisper.d.ts.map +1 -1
- package/types/models/whisper/feature_extraction_whisper.d.ts +3 -1
- package/types/models/whisper/feature_extraction_whisper.d.ts.map +1 -1
- package/types/models.d.ts +180 -4
- package/types/models.d.ts.map +1 -1
- package/types/pipelines.d.ts +51 -5
- package/types/pipelines.d.ts.map +1 -1
- package/types/tokenizers.d.ts.map +1 -1
- package/types/transformers.d.ts +1 -0
- package/types/tsconfig.tsbuildinfo +1 -1
- package/types/utils/audio.d.ts.map +1 -1
- package/types/utils/hub.d.ts +19 -7
- package/types/utils/hub.d.ts.map +1 -1
- package/types/utils/image.d.ts +2 -2
- package/types/utils/image.d.ts.map +1 -1
- package/types/utils/maths.d.ts +2 -2
- package/types/utils/maths.d.ts.map +1 -1
- package/types/utils/tensor.d.ts +17 -18
- package/types/utils/tensor.d.ts.map +1 -1
- package/types/utils/video.d.ts +37 -0
- package/types/utils/video.d.ts.map +1 -0
- package/dist/transformers.cjs.map +0 -1
- package/dist/transformers.min.cjs +0 -2
- package/dist/transformers.min.cjs.map +0 -1
- package/dist/transformers.min.mjs +0 -2
- package/dist/transformers.min.mjs.map +0 -1
- package/dist/transformers.mjs.map +0 -1
package/src/models.js
CHANGED
|
@@ -68,6 +68,7 @@ import {
|
|
|
68
68
|
import {
|
|
69
69
|
getModelFile,
|
|
70
70
|
getModelJSON,
|
|
71
|
+
MAX_EXTERNAL_DATA_CHUNKS,
|
|
71
72
|
} from './utils/hub.js';
|
|
72
73
|
|
|
73
74
|
import {
|
|
@@ -108,6 +109,7 @@ import {
|
|
|
108
109
|
stack,
|
|
109
110
|
std_mean,
|
|
110
111
|
Tensor,
|
|
112
|
+
DataTypeMap,
|
|
111
113
|
} from './utils/tensor.js';
|
|
112
114
|
import { RawImage } from './utils/image.js';
|
|
113
115
|
|
|
@@ -132,6 +134,8 @@ const MODEL_TYPES = {
|
|
|
132
134
|
Musicgen: 7,
|
|
133
135
|
MultiModality: 8,
|
|
134
136
|
Phi3V: 9,
|
|
137
|
+
AudioTextToText: 10,
|
|
138
|
+
AutoEncoder: 11,
|
|
135
139
|
}
|
|
136
140
|
//////////////////////////////////////////////////
|
|
137
141
|
|
|
@@ -150,7 +154,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
|
|
|
150
154
|
* @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
|
|
151
155
|
* @param {string} fileName The name of the model file.
|
|
152
156
|
* @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
|
|
153
|
-
* @returns {Promise<{
|
|
157
|
+
* @returns {Promise<{buffer_or_path: Uint8Array|string, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
|
|
154
158
|
* @private
|
|
155
159
|
*/
|
|
156
160
|
async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
@@ -225,7 +229,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
225
229
|
|
|
226
230
|
// Construct the model file name
|
|
227
231
|
const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
|
|
228
|
-
const
|
|
232
|
+
const baseName = `${fileName}${suffix}.onnx`;
|
|
233
|
+
const modelFileName = `${options.subfolder ?? ''}/${baseName}`;
|
|
229
234
|
|
|
230
235
|
const session_options = { ...options.session_options };
|
|
231
236
|
|
|
@@ -243,29 +248,38 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
243
248
|
);
|
|
244
249
|
}
|
|
245
250
|
|
|
246
|
-
const
|
|
251
|
+
const bufferOrPathPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options, apis.IS_NODE_ENV);
|
|
247
252
|
|
|
248
253
|
// handle onnx external data files
|
|
249
254
|
const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
|
|
250
|
-
/** @type {Promise<{path: string, data: Uint8Array}>[]} */
|
|
255
|
+
/** @type {Promise<string|{path: string, data: Uint8Array}>[]} */
|
|
251
256
|
let externalDataPromises = [];
|
|
252
|
-
if (use_external_data_format
|
|
253
|
-
|
|
254
|
-
(
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
use_external_data_format
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
257
|
+
if (use_external_data_format) {
|
|
258
|
+
let external_data_format;
|
|
259
|
+
if (typeof use_external_data_format === 'object') {
|
|
260
|
+
if (use_external_data_format.hasOwnProperty(baseName)) {
|
|
261
|
+
external_data_format = use_external_data_format[baseName];
|
|
262
|
+
} else if (use_external_data_format.hasOwnProperty(fileName)) {
|
|
263
|
+
external_data_format = use_external_data_format[fileName];
|
|
264
|
+
} else {
|
|
265
|
+
external_data_format = false;
|
|
266
|
+
}
|
|
267
|
+
} else {
|
|
268
|
+
external_data_format = use_external_data_format;
|
|
269
|
+
}
|
|
270
|
+
|
|
271
|
+
const num_chunks = +external_data_format; // (false=0, true=1, number remains the same)
|
|
272
|
+
if (num_chunks > MAX_EXTERNAL_DATA_CHUNKS) {
|
|
273
|
+
throw new Error(`The number of external data chunks (${num_chunks}) exceeds the maximum allowed value (${MAX_EXTERNAL_DATA_CHUNKS}).`);
|
|
274
|
+
}
|
|
275
|
+
for (let i = 0; i < num_chunks; ++i) {
|
|
276
|
+
const path = `${baseName}_data${i === 0 ? '' : '_' + i}`;
|
|
277
|
+
const fullPath = `${options.subfolder ?? ''}/${path}`;
|
|
278
|
+
externalDataPromises.push(new Promise(async (resolve, reject) => {
|
|
279
|
+
const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options, apis.IS_NODE_ENV);
|
|
280
|
+
resolve(data instanceof Uint8Array ? { path, data } : path);
|
|
281
|
+
}));
|
|
262
282
|
}
|
|
263
|
-
const path = `${fileName}${suffix}.onnx_data`;
|
|
264
|
-
const fullPath = `${options.subfolder ?? ''}/${path}`;
|
|
265
|
-
externalDataPromises.push(new Promise(async (resolve, reject) => {
|
|
266
|
-
const data = await getModelFile(pretrained_model_name_or_path, fullPath, true, options);
|
|
267
|
-
resolve({ path, data })
|
|
268
|
-
}));
|
|
269
283
|
|
|
270
284
|
} else if (session_options.externalData !== undefined) {
|
|
271
285
|
externalDataPromises = session_options.externalData.map(async (ext) => {
|
|
@@ -282,7 +296,10 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
282
296
|
}
|
|
283
297
|
|
|
284
298
|
if (externalDataPromises.length > 0) {
|
|
285
|
-
|
|
299
|
+
const externalData = await Promise.all(externalDataPromises);
|
|
300
|
+
if (!apis.IS_NODE_ENV) {
|
|
301
|
+
session_options.externalData = externalData;
|
|
302
|
+
}
|
|
286
303
|
}
|
|
287
304
|
|
|
288
305
|
if (selectedDevice === 'webgpu') {
|
|
@@ -300,9 +317,9 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
300
317
|
}
|
|
301
318
|
}
|
|
302
319
|
|
|
303
|
-
const
|
|
320
|
+
const buffer_or_path = await bufferOrPathPromise;
|
|
304
321
|
|
|
305
|
-
return {
|
|
322
|
+
return { buffer_or_path, session_options, session_config };
|
|
306
323
|
}
|
|
307
324
|
|
|
308
325
|
/**
|
|
@@ -317,8 +334,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
|
|
|
317
334
|
async function constructSessions(pretrained_model_name_or_path, names, options) {
|
|
318
335
|
return Object.fromEntries(await Promise.all(
|
|
319
336
|
Object.keys(names).map(async (name) => {
|
|
320
|
-
const {
|
|
321
|
-
const session = await createInferenceSession(
|
|
337
|
+
const { buffer_or_path, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
|
|
338
|
+
const session = await createInferenceSession(buffer_or_path, session_options, session_config);
|
|
322
339
|
return [name, session];
|
|
323
340
|
})
|
|
324
341
|
));
|
|
@@ -548,10 +565,16 @@ async function encoderForward(self, model_inputs) {
|
|
|
548
565
|
const dims = encoderFeeds.pixel_values.dims;
|
|
549
566
|
encoderFeeds.pixel_mask = ones([dims[0], dims[2], dims[3]]);
|
|
550
567
|
}
|
|
551
|
-
|
|
568
|
+
|
|
552
569
|
return await sessionRun(session, encoderFeeds);
|
|
553
570
|
}
|
|
554
571
|
|
|
572
|
+
async function autoEncoderForward(self, model_inputs) {
|
|
573
|
+
const encoded = await self.encode(model_inputs);
|
|
574
|
+
const decoded = await self.decode(encoded);
|
|
575
|
+
return decoded;
|
|
576
|
+
}
|
|
577
|
+
|
|
555
578
|
/**
|
|
556
579
|
* Forward pass of a decoder model.
|
|
557
580
|
* @param {Object} self The decoder model.
|
|
@@ -571,8 +594,8 @@ async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
|
|
|
571
594
|
new_model_inputs.use_cache_branch = boolTensor(!!past_key_values);
|
|
572
595
|
}
|
|
573
596
|
if (session.inputNames.includes('position_ids') && new_model_inputs.attention_mask && !new_model_inputs.position_ids) {
|
|
574
|
-
// NOTE: Handle a special case for paligemma models, where positions are 1-indexed
|
|
575
|
-
const start_index = self.config.model_type
|
|
597
|
+
// NOTE: Handle a special case for paligemma/gemma3 models, where positions are 1-indexed
|
|
598
|
+
const start_index = ['paligemma', 'gemma3_text', 'gemma3'].includes(self.config.model_type) ? 1 : 0;
|
|
576
599
|
new_model_inputs.position_ids = createPositionIds(new_model_inputs, past_key_values, start_index);
|
|
577
600
|
}
|
|
578
601
|
|
|
@@ -586,58 +609,98 @@ async function decoderForward(self, model_inputs, is_encoder_decoder = false) {
|
|
|
586
609
|
|
|
587
610
|
|
|
588
611
|
|
|
589
|
-
function
|
|
590
|
-
|
|
612
|
+
function default_merge_input_ids_with_features({
|
|
613
|
+
modality_token_id,
|
|
591
614
|
inputs_embeds,
|
|
592
|
-
|
|
615
|
+
modality_features,
|
|
593
616
|
input_ids,
|
|
594
617
|
attention_mask,
|
|
595
618
|
}) {
|
|
596
|
-
const
|
|
619
|
+
const token_positions = input_ids.tolist().map(ids =>
|
|
597
620
|
ids.reduce((acc, x, idx) => {
|
|
598
|
-
if (x ==
|
|
621
|
+
if (x == modality_token_id) acc.push(idx);
|
|
599
622
|
return acc;
|
|
600
623
|
}, [])
|
|
601
624
|
);
|
|
602
|
-
const
|
|
603
|
-
const
|
|
604
|
-
if (
|
|
605
|
-
throw new Error(`
|
|
625
|
+
const n_tokens = token_positions.reduce((acc, x) => acc + x.length, 0);
|
|
626
|
+
const n_features = modality_features.dims[0];
|
|
627
|
+
if (n_tokens !== n_features) {
|
|
628
|
+
throw new Error(`Number of tokens and features do not match: tokens: ${n_tokens}, features ${n_features}`);
|
|
606
629
|
}
|
|
607
630
|
|
|
608
631
|
// Equivalent to performing a masked_scatter
|
|
609
632
|
let img = 0;
|
|
610
|
-
for (let i = 0; i <
|
|
611
|
-
const tokens =
|
|
633
|
+
for (let i = 0; i < token_positions.length; ++i) {
|
|
634
|
+
const tokens = token_positions[i];
|
|
612
635
|
const embeds = inputs_embeds[i];
|
|
613
636
|
for (let j = 0; j < tokens.length; ++j) {
|
|
614
|
-
embeds[tokens[j]].data.set(
|
|
637
|
+
embeds[tokens[j]].data.set(modality_features[img++].data)
|
|
615
638
|
}
|
|
616
639
|
}
|
|
617
640
|
return { inputs_embeds, attention_mask }
|
|
618
641
|
}
|
|
619
642
|
|
|
620
643
|
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
644
|
+
function default_merge_input_ids_with_image_features({
|
|
645
|
+
image_token_id,
|
|
646
|
+
inputs_embeds,
|
|
647
|
+
image_features,
|
|
648
|
+
input_ids,
|
|
649
|
+
attention_mask,
|
|
650
|
+
}) {
|
|
651
|
+
return default_merge_input_ids_with_features({
|
|
652
|
+
modality_token_id: image_token_id,
|
|
653
|
+
inputs_embeds,
|
|
654
|
+
modality_features: image_features,
|
|
655
|
+
input_ids,
|
|
656
|
+
attention_mask,
|
|
657
|
+
})
|
|
658
|
+
}
|
|
659
|
+
|
|
660
|
+
function default_merge_input_ids_with_audio_features({
|
|
661
|
+
audio_token_id,
|
|
662
|
+
inputs_embeds,
|
|
663
|
+
audio_features,
|
|
664
|
+
input_ids,
|
|
665
|
+
attention_mask,
|
|
666
|
+
}) {
|
|
667
|
+
return default_merge_input_ids_with_features({
|
|
668
|
+
modality_token_id: audio_token_id,
|
|
669
|
+
inputs_embeds,
|
|
670
|
+
modality_features: audio_features,
|
|
671
|
+
input_ids,
|
|
672
|
+
attention_mask,
|
|
673
|
+
})
|
|
674
|
+
}
|
|
675
|
+
|
|
676
|
+
/**
|
|
677
|
+
* Abstract forward pass function for image-text-to-text or audio-text-to-text models.
|
|
678
|
+
* @param {Object} self The model object.
|
|
679
|
+
* @param {Object} params Additional parameters.
|
|
680
|
+
* @param {Function} [params.encode_function] The function to encode the modality values.
|
|
681
|
+
* @param {Function} [params.merge_function] The function to merge the modality features with the input embeddings.
|
|
682
|
+
* @param {string} [params.modality_input_name] The modality input name.
|
|
683
|
+
* @param {string} [params.modality_output_name] The modality output name.
|
|
684
|
+
* @param {Tensor} [params.input_ids=null]
|
|
685
|
+
* @param {Tensor} [params.attention_mask=null]
|
|
686
|
+
* @param {Tensor} [params.position_ids=null]
|
|
687
|
+
* @param {Tensor} [params.inputs_embeds=null]
|
|
688
|
+
* @param {Tensor} [params.past_key_values=null]
|
|
689
|
+
* @param {Object} [params.generation_config=null]
|
|
690
|
+
* @param {Object} [params.logits_processor=null]
|
|
633
691
|
* @returns {Promise<Tensor>} The model's output tensor
|
|
634
692
|
* @private
|
|
635
693
|
*/
|
|
636
|
-
async function
|
|
694
|
+
async function genericTextToTextForward(self, {
|
|
695
|
+
// Generic parameters:
|
|
696
|
+
encode_function,
|
|
697
|
+
merge_function,
|
|
698
|
+
modality_input_name,
|
|
699
|
+
modality_output_name,
|
|
700
|
+
|
|
637
701
|
// Produced by the tokenizer/processor:
|
|
638
702
|
input_ids = null,
|
|
639
703
|
attention_mask = null,
|
|
640
|
-
pixel_values = null,
|
|
641
704
|
|
|
642
705
|
// Used during generation:
|
|
643
706
|
position_ids = null,
|
|
@@ -648,27 +711,31 @@ async function imageTextToTextForward(self, {
|
|
|
648
711
|
generation_config = null,
|
|
649
712
|
logits_processor = null,
|
|
650
713
|
|
|
651
|
-
//
|
|
714
|
+
// Additional parameters
|
|
652
715
|
...kwargs
|
|
653
716
|
}) {
|
|
654
|
-
|
|
717
|
+
const modality_values = kwargs[modality_input_name];
|
|
655
718
|
if (!inputs_embeds) {
|
|
656
|
-
// 1. Extract the
|
|
719
|
+
// 1. Extract the text embeddings.
|
|
657
720
|
inputs_embeds = await self.encode_text({ input_ids, ...kwargs });
|
|
658
721
|
|
|
659
|
-
// 2. Possibly, merge text and
|
|
660
|
-
if (
|
|
661
|
-
const
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
722
|
+
// 2. Possibly, merge text and modality values
|
|
723
|
+
if (modality_values && input_ids.dims[1] !== 1) {
|
|
724
|
+
const modality_features = await encode_function({
|
|
725
|
+
// Pass the modality values under its expected key.
|
|
726
|
+
// The caller knows whether this is audio or image.
|
|
727
|
+
[modality_input_name]: modality_values,
|
|
728
|
+
...kwargs
|
|
729
|
+
});
|
|
730
|
+
({ inputs_embeds, attention_mask } = merge_function({
|
|
731
|
+
[modality_output_name]: modality_features,
|
|
665
732
|
inputs_embeds,
|
|
666
733
|
input_ids,
|
|
667
734
|
attention_mask,
|
|
668
735
|
}));
|
|
669
736
|
|
|
670
|
-
} else if (past_key_values &&
|
|
671
|
-
// This
|
|
737
|
+
} else if (past_key_values && modality_values && input_ids.dims[1] === 1) {
|
|
738
|
+
// This branch handles the cache case.
|
|
672
739
|
const target_length = input_ids.dims[1]; // always 1
|
|
673
740
|
const past_length = Object.values(past_key_values)[0].dims.at(-2);
|
|
674
741
|
|
|
@@ -689,6 +756,7 @@ async function imageTextToTextForward(self, {
|
|
|
689
756
|
}
|
|
690
757
|
}
|
|
691
758
|
|
|
759
|
+
// 3. Call the decoder forward using the updated inputs.
|
|
692
760
|
const outputs = await decoderForward(self, {
|
|
693
761
|
inputs_embeds,
|
|
694
762
|
past_key_values,
|
|
@@ -700,6 +768,40 @@ async function imageTextToTextForward(self, {
|
|
|
700
768
|
return outputs;
|
|
701
769
|
}
|
|
702
770
|
|
|
771
|
+
/**
|
|
772
|
+
* Forward pass of an audio-text-to-text model.
|
|
773
|
+
* @param {Object} self The audio-text-to-text model.
|
|
774
|
+
* @param {Object} params The inputs for the audio-text-to-text forward pass.
|
|
775
|
+
* @returns {Promise<Tensor>} The model's output tensor.
|
|
776
|
+
* @private
|
|
777
|
+
*/
|
|
778
|
+
async function audioTextToTextForward(self, params) {
|
|
779
|
+
return await genericTextToTextForward(self, {
|
|
780
|
+
...params,
|
|
781
|
+
modality_input_name: 'audio_values',
|
|
782
|
+
modality_output_name: 'audio_features',
|
|
783
|
+
encode_function: self.encode_audio.bind(self),
|
|
784
|
+
merge_function: self._merge_input_ids_with_audio_features.bind(self),
|
|
785
|
+
});
|
|
786
|
+
}
|
|
787
|
+
|
|
788
|
+
/**
|
|
789
|
+
* Forward pass of an image-text-to-text model.
|
|
790
|
+
* @param {Object} self The image-text-to-text model.
|
|
791
|
+
* @param {Object} params The inputs for the image-text-to-text forward pass.
|
|
792
|
+
* @returns {Promise<Tensor>} The model's output tensor.
|
|
793
|
+
* @private
|
|
794
|
+
*/
|
|
795
|
+
async function imageTextToTextForward(self, params) {
|
|
796
|
+
return await genericTextToTextForward(self, {
|
|
797
|
+
...params,
|
|
798
|
+
modality_input_name: 'pixel_values',
|
|
799
|
+
modality_output_name: 'image_features',
|
|
800
|
+
encode_function: self.encode_image.bind(self),
|
|
801
|
+
merge_function: self._merge_input_ids_with_image_features.bind(self),
|
|
802
|
+
});
|
|
803
|
+
}
|
|
804
|
+
|
|
703
805
|
/**
|
|
704
806
|
* Helper function to perform the following:
|
|
705
807
|
* ```python
|
|
@@ -813,7 +915,7 @@ function encoder_decoder_prepare_inputs_for_generation(self, input_ids, model_in
|
|
|
813
915
|
};
|
|
814
916
|
}
|
|
815
917
|
|
|
816
|
-
function
|
|
918
|
+
function multimodal_text_to_text_prepare_inputs_for_generation(self, ...args) {
|
|
817
919
|
if (self.config.is_encoder_decoder) {
|
|
818
920
|
return encoder_decoder_prepare_inputs_for_generation(self, ...args);
|
|
819
921
|
} else {
|
|
@@ -917,18 +1019,24 @@ export class PreTrainedModel extends Callable {
|
|
|
917
1019
|
case MODEL_TYPES.ImageTextToText:
|
|
918
1020
|
this.can_generate = true;
|
|
919
1021
|
this._forward = imageTextToTextForward;
|
|
920
|
-
this._prepare_inputs_for_generation =
|
|
1022
|
+
this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
|
|
1023
|
+
break;
|
|
1024
|
+
case MODEL_TYPES.AudioTextToText:
|
|
1025
|
+
this.can_generate = true;
|
|
1026
|
+
this._forward = audioTextToTextForward;
|
|
1027
|
+
this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
|
|
921
1028
|
break;
|
|
922
1029
|
case MODEL_TYPES.Phi3V:
|
|
923
1030
|
this.can_generate = true;
|
|
924
|
-
this._prepare_inputs_for_generation =
|
|
1031
|
+
this._prepare_inputs_for_generation = multimodal_text_to_text_prepare_inputs_for_generation;
|
|
925
1032
|
break;
|
|
926
|
-
|
|
927
1033
|
case MODEL_TYPES.MultiModality:
|
|
928
1034
|
this.can_generate = true;
|
|
929
1035
|
this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation;
|
|
930
1036
|
break;
|
|
931
|
-
|
|
1037
|
+
case MODEL_TYPES.AutoEncoder:
|
|
1038
|
+
this._forward = autoEncoderForward;
|
|
1039
|
+
break;
|
|
932
1040
|
default:
|
|
933
1041
|
// should be MODEL_TYPES.EncoderOnly
|
|
934
1042
|
this._forward = encoderForward;
|
|
@@ -1060,6 +1168,19 @@ export class PreTrainedModel extends Callable {
|
|
|
1060
1168
|
}, options),
|
|
1061
1169
|
]);
|
|
1062
1170
|
|
|
1171
|
+
} else if (modelType === MODEL_TYPES.AudioTextToText) {
|
|
1172
|
+
const sessions = {
|
|
1173
|
+
embed_tokens: 'embed_tokens',
|
|
1174
|
+
audio_encoder: 'audio_encoder',
|
|
1175
|
+
decoder_model_merged: 'decoder_model_merged',
|
|
1176
|
+
}
|
|
1177
|
+
info = await Promise.all([
|
|
1178
|
+
constructSessions(pretrained_model_name_or_path, sessions, options),
|
|
1179
|
+
getOptionalConfigs(pretrained_model_name_or_path, {
|
|
1180
|
+
generation_config: 'generation_config.json',
|
|
1181
|
+
}, options),
|
|
1182
|
+
]);
|
|
1183
|
+
|
|
1063
1184
|
} else if (modelType === MODEL_TYPES.Musicgen) {
|
|
1064
1185
|
info = await Promise.all([
|
|
1065
1186
|
constructSessions(pretrained_model_name_or_path, {
|
|
@@ -1098,7 +1219,13 @@ export class PreTrainedModel extends Callable {
|
|
|
1098
1219
|
generation_config: 'generation_config.json',
|
|
1099
1220
|
}, options),
|
|
1100
1221
|
]);
|
|
1101
|
-
|
|
1222
|
+
} else if (modelType === MODEL_TYPES.AutoEncoder) {
|
|
1223
|
+
info = await Promise.all([
|
|
1224
|
+
constructSessions(pretrained_model_name_or_path, {
|
|
1225
|
+
encoder_model: 'encoder_model',
|
|
1226
|
+
decoder_model: 'decoder_model',
|
|
1227
|
+
}, options),
|
|
1228
|
+
]);
|
|
1102
1229
|
} else { // should be MODEL_TYPES.EncoderOnly
|
|
1103
1230
|
if (modelType !== MODEL_TYPES.EncoderOnly) {
|
|
1104
1231
|
const type = modelName ?? config?.model_type;
|
|
@@ -1847,7 +1974,7 @@ export class PreTrainedModel extends Callable {
|
|
|
1847
1974
|
} else {
|
|
1848
1975
|
const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
|
|
1849
1976
|
const dtype = session?.config?.kv_cache_dtype ?? 'float32';
|
|
1850
|
-
const empty = (dtype === 'float16') ? new
|
|
1977
|
+
const empty = (dtype === 'float16') ? new DataTypeMap.float16() : [];
|
|
1851
1978
|
|
|
1852
1979
|
const batch_size = (decoderFeeds[this.main_input_name] ?? decoderFeeds.attention_mask)?.dims?.[0] ?? 1;
|
|
1853
1980
|
const shapes = getKeyValueShapes(this.config, { batch_size });
|
|
@@ -1877,6 +2004,11 @@ export class PreTrainedModel extends Callable {
|
|
|
1877
2004
|
// text_inputs === { input_ids, attention_mask }
|
|
1878
2005
|
return (await sessionRun(this.sessions['embed_tokens'], { input_ids })).inputs_embeds;
|
|
1879
2006
|
}
|
|
2007
|
+
|
|
2008
|
+
async encode_audio({ audio_values }) {
|
|
2009
|
+
// audio_inputs === { audio_values }
|
|
2010
|
+
return (await sessionRun(this.sessions['audio_encoder'], { audio_values })).audio_features;
|
|
2011
|
+
}
|
|
1880
2012
|
}
|
|
1881
2013
|
|
|
1882
2014
|
//////////////////////////////////////////////////
|
|
@@ -3420,6 +3552,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
|
|
|
3420
3552
|
}
|
|
3421
3553
|
//////////////////////////////////////////////////
|
|
3422
3554
|
|
|
3555
|
+
export class LiteWhisperForConditionalGeneration extends WhisperForConditionalGeneration { }
|
|
3423
3556
|
|
|
3424
3557
|
//////////////////////////////////////////////////
|
|
3425
3558
|
// Moonshine models
|
|
@@ -3691,7 +3824,7 @@ export class Idefics3PreTrainedModel extends PreTrainedModel {
|
|
|
3691
3824
|
}
|
|
3692
3825
|
|
|
3693
3826
|
/**
|
|
3694
|
-
* The
|
|
3827
|
+
* The Idefics3 model which consists of a vision backbone and a language model.
|
|
3695
3828
|
*/
|
|
3696
3829
|
export class Idefics3ForConditionalGeneration extends Idefics3PreTrainedModel {
|
|
3697
3830
|
|
|
@@ -3714,6 +3847,13 @@ export class Idefics3ForConditionalGeneration extends Idefics3PreTrainedModel {
|
|
|
3714
3847
|
}
|
|
3715
3848
|
//////////////////////////////////////////////////
|
|
3716
3849
|
|
|
3850
|
+
/**
|
|
3851
|
+
* The SmolVLM Model with a language modeling head.
|
|
3852
|
+
* It is made up a SigLIP vision encoder, with a language modeling head on top.
|
|
3853
|
+
*/
|
|
3854
|
+
export class SmolVLMForConditionalGeneration extends Idefics3ForConditionalGeneration { }
|
|
3855
|
+
|
|
3856
|
+
//////////////////////////////////////////////////
|
|
3717
3857
|
export class Phi3VPreTrainedModel extends PreTrainedModel {
|
|
3718
3858
|
forward_params = [
|
|
3719
3859
|
'input_ids',
|
|
@@ -4380,6 +4520,23 @@ export class Gemma2Model extends Gemma2PreTrainedModel { }
|
|
|
4380
4520
|
export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { }
|
|
4381
4521
|
//////////////////////////////////////////////////
|
|
4382
4522
|
|
|
4523
|
+
|
|
4524
|
+
//////////////////////////////////////////////////
|
|
4525
|
+
// Gemma3 models
|
|
4526
|
+
|
|
4527
|
+
/**
|
|
4528
|
+
* The bare Gemma3 Model outputting raw hidden-states without any specific head on top.
|
|
4529
|
+
*/
|
|
4530
|
+
export class Gemma3PreTrainedModel extends PreTrainedModel { }
|
|
4531
|
+
/**
|
|
4532
|
+
* The bare Gemma3 Model outputting raw hidden-states without any specific head on top.
|
|
4533
|
+
*/
|
|
4534
|
+
export class Gemma3Model extends Gemma3PreTrainedModel { }
|
|
4535
|
+
|
|
4536
|
+
export class Gemma3ForCausalLM extends Gemma3PreTrainedModel { }
|
|
4537
|
+
//////////////////////////////////////////////////
|
|
4538
|
+
|
|
4539
|
+
|
|
4383
4540
|
//////////////////////////////////////////////////
|
|
4384
4541
|
export class OpenELMPreTrainedModel extends PreTrainedModel { }
|
|
4385
4542
|
export class OpenELMModel extends OpenELMPreTrainedModel { }
|
|
@@ -5112,6 +5269,7 @@ export class SwinForImageClassification extends SwinPreTrainedModel {
|
|
|
5112
5269
|
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
5113
5270
|
}
|
|
5114
5271
|
}
|
|
5272
|
+
export class SwinForSemanticSegmentation extends SwinPreTrainedModel { }
|
|
5115
5273
|
//////////////////////////////////////////////////
|
|
5116
5274
|
|
|
5117
5275
|
//////////////////////////////////////////////////
|
|
@@ -5231,6 +5389,16 @@ export class DepthProPreTrainedModel extends PreTrainedModel { }
|
|
|
5231
5389
|
export class DepthProForDepthEstimation extends DepthProPreTrainedModel { }
|
|
5232
5390
|
//////////////////////////////////////////////////
|
|
5233
5391
|
|
|
5392
|
+
//////////////////////////////////////////////////
|
|
5393
|
+
export class Metric3DPreTrainedModel extends PreTrainedModel { }
|
|
5394
|
+
export class Metric3DForDepthEstimation extends Metric3DPreTrainedModel { }
|
|
5395
|
+
//////////////////////////////////////////////////
|
|
5396
|
+
|
|
5397
|
+
//////////////////////////////////////////////////
|
|
5398
|
+
export class Metric3Dv2PreTrainedModel extends PreTrainedModel { }
|
|
5399
|
+
export class Metric3Dv2ForDepthEstimation extends Metric3Dv2PreTrainedModel { }
|
|
5400
|
+
//////////////////////////////////////////////////
|
|
5401
|
+
|
|
5234
5402
|
//////////////////////////////////////////////////
|
|
5235
5403
|
export class MaskFormerPreTrainedModel extends PreTrainedModel { }
|
|
5236
5404
|
export class MaskFormerModel extends MaskFormerPreTrainedModel { }
|
|
@@ -6714,6 +6882,8 @@ export class MobileNetV1ForImageClassification extends MobileNetV1PreTrainedMode
|
|
|
6714
6882
|
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
6715
6883
|
}
|
|
6716
6884
|
}
|
|
6885
|
+
|
|
6886
|
+
export class MobileNetV1ForSemanticSegmentation extends MobileNetV1PreTrainedModel { }
|
|
6717
6887
|
//////////////////////////////////////////////////
|
|
6718
6888
|
|
|
6719
6889
|
//////////////////////////////////////////////////
|
|
@@ -6737,6 +6907,7 @@ export class MobileNetV2ForImageClassification extends MobileNetV2PreTrainedMode
|
|
|
6737
6907
|
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
6738
6908
|
}
|
|
6739
6909
|
}
|
|
6910
|
+
export class MobileNetV2ForSemanticSegmentation extends MobileNetV2PreTrainedModel { }
|
|
6740
6911
|
//////////////////////////////////////////////////
|
|
6741
6912
|
|
|
6742
6913
|
//////////////////////////////////////////////////
|
|
@@ -6760,6 +6931,7 @@ export class MobileNetV3ForImageClassification extends MobileNetV3PreTrainedMode
|
|
|
6760
6931
|
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
6761
6932
|
}
|
|
6762
6933
|
}
|
|
6934
|
+
export class MobileNetV3ForSemanticSegmentation extends MobileNetV3PreTrainedModel { }
|
|
6763
6935
|
//////////////////////////////////////////////////
|
|
6764
6936
|
|
|
6765
6937
|
//////////////////////////////////////////////////
|
|
@@ -6783,6 +6955,7 @@ export class MobileNetV4ForImageClassification extends MobileNetV4PreTrainedMode
|
|
|
6783
6955
|
return new SequenceClassifierOutput(await super._call(model_inputs));
|
|
6784
6956
|
}
|
|
6785
6957
|
}
|
|
6958
|
+
export class MobileNetV4ForSemanticSegmentation extends MobileNetV4PreTrainedModel { }
|
|
6786
6959
|
//////////////////////////////////////////////////
|
|
6787
6960
|
|
|
6788
6961
|
//////////////////////////////////////////////////
|
|
@@ -6963,6 +7136,237 @@ export class PatchTSMixerModel extends PatchTSMixerPreTrainedModel { }
|
|
|
6963
7136
|
/** PatchTSMixer model with a prediction head; empty subclass of PatchTSMixerPreTrainedModel. */
export class PatchTSMixerForPrediction extends PatchTSMixerPreTrainedModel { }
|
|
6964
7137
|
//////////////////////////////////////////////////
|
|
6965
7138
|
|
|
7139
|
+
//////////////////////////////////////////////////
|
|
7140
|
+
/**
 * Base class for Ultravox models, declaring the inputs forwarded to the ONNX session(s).
 */
export class UltravoxPreTrainedModel extends PreTrainedModel {
    // Inputs consumed by the forward pass; `audio_values` carries the raw audio features.
    forward_params = [
        'input_ids',
        'attention_mask',
        'position_ids',
        'audio_values',
        'past_key_values',
    ];
}
|
|
7149
|
+
|
|
7150
|
+
export class UltravoxModel extends UltravoxPreTrainedModel {

    /**
     * Merge audio features into the text token embedding sequence.
     * Flattens `kwargs.audio_features` to `(-1, audio_hidden_size)` (last dim preserved)
     * before delegating to `default_merge_input_ids_with_audio_features`.
     * @param {Object} kwargs Forwarded merge arguments; must include `audio_features` (a Tensor).
     */
    _merge_input_ids_with_audio_features(kwargs) {
        const audio_hidden_size = kwargs.audio_features.dims.at(-1);
        const reshaped_audio_features = kwargs.audio_features.view(-1, audio_hidden_size);

        return default_merge_input_ids_with_audio_features({
            // @ts-ignore
            // NOTE(review): `ignore_index` is used as the audio placeholder token id here —
            // confirm this matches the Ultravox config convention.
            audio_token_id: this.config.ignore_index,
            ...kwargs,
            audio_features: reshaped_audio_features,
        })
    }
}
|
|
7164
|
+
//////////////////////////////////////////////////
|
|
7165
|
+
|
|
7166
|
+
//////////////////////////////////////////////////
|
|
7167
|
+
// Mimi models
|
|
7168
|
+
/**
 * Base class for Mimi neural audio codec models; forwards only raw waveform values.
 */
export class MimiPreTrainedModel extends PreTrainedModel {
    main_input_name = 'input_values';
    forward_params = ['input_values'];
}
|
|
7172
|
+
|
|
7173
|
+
/**
 * Output of a Mimi encode pass, wrapping the discrete audio codes.
 */
export class MimiEncoderOutput extends ModelOutput {
    /**
     * @param {Object} output The output of the model.
     * @param {Tensor} output.audio_codes Discrete code embeddings, of shape `(batch_size, num_quantizers, codes_length)`.
     */
    constructor({ audio_codes }) {
        super();
        this.audio_codes = audio_codes;
    }
}
|
|
7183
|
+
|
|
7184
|
+
/**
 * Output of a Mimi decode pass, wrapping the reconstructed waveform.
 */
export class MimiDecoderOutput extends ModelOutput {
    /**
     * @param {Object} output The output of the model.
     * @param {Tensor} output.audio_values Decoded audio values, of shape `(batch_size, num_channels, sequence_length)`.
     */
    constructor({ audio_values }) {
        super();
        this.audio_values = audio_values;
    }
}
|
|
7194
|
+
|
|
7195
|
+
/**
 * The Mimi neural audio codec model.
 */
export class MimiModel extends MimiPreTrainedModel {
    /**
     * Encodes the input audio waveform into discrete codes.
     * @param {Object} inputs Model inputs
     * @param {Tensor} [inputs.input_values] Float values of the input audio waveform, of shape `(batch_size, channels, sequence_length)`).
     * @returns {Promise<MimiEncoderOutput>} The output tensor of shape `(batch_size, num_codebooks, sequence_length)`.
     */
    async encode(inputs) {
        const outputs = await sessionRun(this.sessions['encoder_model'], inputs);
        return new MimiEncoderOutput(outputs);
    }

    /**
     * Decodes the given frames into an output audio waveform.
     * @param {MimiEncoderOutput} inputs The encoded audio codes.
     * @returns {Promise<MimiDecoderOutput>} The output tensor of shape `(batch_size, num_channels, sequence_length)`.
     */
    async decode(inputs) {
        const outputs = await sessionRun(this.sessions['decoder_model'], inputs);
        return new MimiDecoderOutput(outputs);
    }
}
|
|
7218
|
+
|
|
7219
|
+
export class MimiEncoderModel extends MimiPreTrainedModel {
    /** @type {typeof PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Default to the encoder-specific model file unless the caller overrides it.
        const model_file_name = options.model_file_name ?? 'encoder_model';
        return super.from_pretrained(pretrained_model_name_or_path, { ...options, model_file_name });
    }
}
|
|
7229
|
+
export class MimiDecoderModel extends MimiPreTrainedModel {
    /** @type {typeof PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Default to the decoder-specific model file unless the caller overrides it.
        const model_file_name = options.model_file_name ?? 'decoder_model';
        return super.from_pretrained(pretrained_model_name_or_path, { ...options, model_file_name });
    }
}
|
|
7239
|
+
//////////////////////////////////////////////////
|
|
7240
|
+
|
|
7241
|
+
|
|
7242
|
+
//////////////////////////////////////////////////
|
|
7243
|
+
// Dac models
|
|
7244
|
+
/**
 * Base class for DAC (Descript Audio Codec) models; forwards only raw waveform values.
 */
export class DacPreTrainedModel extends PreTrainedModel {
    main_input_name = 'input_values';
    forward_params = ['input_values'];
}
|
|
7248
|
+
|
|
7249
|
+
/**
 * Output of a DAC encode pass, wrapping the discrete audio codes.
 */
export class DacEncoderOutput extends ModelOutput {
    /**
     * @param {Object} output The output of the model.
     * @param {Tensor} output.audio_codes Discrete code embeddings, of shape `(batch_size, num_quantizers, codes_length)`.
     */
    constructor({ audio_codes }) {
        super();
        this.audio_codes = audio_codes;
    }
}
|
|
7259
|
+
|
|
7260
|
+
/**
 * Output of a DAC decode pass, wrapping the reconstructed waveform.
 */
export class DacDecoderOutput extends ModelOutput {
    /**
     * @param {Object} output The output of the model.
     * @param {Tensor} output.audio_values Decoded audio values, of shape `(batch_size, num_channels, sequence_length)`.
     */
    constructor({ audio_values }) {
        super();
        this.audio_values = audio_values;
    }
}
|
|
7270
|
+
|
|
7271
|
+
/**
 * The DAC (Descript Audio Codec) model.
 */
export class DacModel extends DacPreTrainedModel {
    /**
     * Encodes the input audio waveform into discrete codes.
     * @param {Object} inputs Model inputs
     * @param {Tensor} [inputs.input_values] Float values of the input audio waveform, of shape `(batch_size, channels, sequence_length)`).
     * @returns {Promise<DacEncoderOutput>} The output tensor of shape `(batch_size, num_codebooks, sequence_length)`.
     */
    async encode(inputs) {
        const result = await sessionRun(this.sessions['encoder_model'], inputs);
        return new DacEncoderOutput(result);
    }

    /**
     * Decodes the given frames into an output audio waveform.
     * @param {DacEncoderOutput} inputs The encoded audio codes.
     * @returns {Promise<DacDecoderOutput>} The output tensor of shape `(batch_size, num_channels, sequence_length)`.
     */
    async decode(inputs) {
        const result = await sessionRun(this.sessions['decoder_model'], inputs);
        return new DacDecoderOutput(result);
    }
}
|
|
7294
|
+
|
|
7295
|
+
export class DacEncoderModel extends DacPreTrainedModel {
    /** @type {typeof PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Default to the encoder-specific model file unless the caller overrides it.
        const model_file_name = options.model_file_name ?? 'encoder_model';
        return super.from_pretrained(pretrained_model_name_or_path, { ...options, model_file_name });
    }
}
|
|
7305
|
+
export class DacDecoderModel extends DacPreTrainedModel {
    /** @type {typeof PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Default to the decoder-specific model file unless the caller overrides it.
        const model_file_name = options.model_file_name ?? 'decoder_model';
        return super.from_pretrained(pretrained_model_name_or_path, { ...options, model_file_name });
    }
}
|
|
7315
|
+
//////////////////////////////////////////////////
|
|
7316
|
+
|
|
7317
|
+
|
|
7318
|
+
//////////////////////////////////////////////////
|
|
7319
|
+
// Snac models
|
|
7320
|
+
/**
 * Base class for SNAC (Multi-Scale Neural Audio Codec) models; forwards only raw waveform values.
 */
export class SnacPreTrainedModel extends PreTrainedModel {
    main_input_name = 'input_values';
    forward_params = ['input_values'];
}
|
|
7324
|
+
|
|
7325
|
+
/**
 * The SNAC (Multi-Scale Neural Audio Codec) model.
 */
export class SnacModel extends SnacPreTrainedModel {
    /**
     * Encodes the input audio waveform into discrete codes.
     * @param {Object} inputs Model inputs
     * @param {Tensor} [inputs.input_values] Float values of the input audio waveform, of shape `(batch_size, channels, sequence_length)`).
     * @returns {Promise<Record<string, Tensor>>} The output tensors of shape `(batch_size, num_codebooks, sequence_length)`.
     */
    async encode(inputs) {
        // Unlike Mimi/DAC, SNAC returns the raw session outputs without a wrapper class.
        return sessionRun(this.sessions['encoder_model'], inputs);
    }

    /**
     * Decodes the given frames into an output audio waveform.
     * @param {Record<string, Tensor>} inputs The encoded audio codes.
     * @returns {Promise<{audio_values: Tensor}>} The output tensor of shape `(batch_size, num_channels, sequence_length)`.
     */
    async decode(inputs) {
        return sessionRun(this.sessions['decoder_model'], inputs);
    }
}
|
|
7348
|
+
|
|
7349
|
+
export class SnacEncoderModel extends SnacPreTrainedModel {
    /** @type {typeof PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Default to the encoder-specific model file unless the caller overrides it.
        const model_file_name = options.model_file_name ?? 'encoder_model';
        return super.from_pretrained(pretrained_model_name_or_path, { ...options, model_file_name });
    }
}
|
|
7359
|
+
export class SnacDecoderModel extends SnacPreTrainedModel {
    /** @type {typeof PreTrainedModel.from_pretrained} */
    static async from_pretrained(pretrained_model_name_or_path, options = {}) {
        // Default to the decoder-specific model file unless the caller overrides it.
        const model_file_name = options.model_file_name ?? 'decoder_model';
        return super.from_pretrained(pretrained_model_name_or_path, { ...options, model_file_name });
    }
}
|
|
7369
|
+
//////////////////////////////////////////////////
|
|
6966
7370
|
|
|
6967
7371
|
//////////////////////////////////////////////////
|
|
6968
7372
|
// AutoModels, used to simplify construction of PreTrainedModels
|
|
@@ -7019,20 +7423,29 @@ export class PretrainedMixin {
|
|
|
7019
7423
|
if (!this.MODEL_CLASS_MAPPINGS) {
|
|
7020
7424
|
throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
|
|
7021
7425
|
}
|
|
7022
|
-
|
|
7426
|
+
const model_type = options.config.model_type;
|
|
7023
7427
|
for (const MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
|
|
7024
|
-
|
|
7428
|
+
let modelInfo = MODEL_CLASS_MAPPING.get(model_type);
|
|
7025
7429
|
if (!modelInfo) {
|
|
7026
|
-
|
|
7430
|
+
// As a fallback, we check if model_type is specified as the exact class
|
|
7431
|
+
for (const cls of MODEL_CLASS_MAPPING.values()) {
|
|
7432
|
+
if (cls[0] === model_type) {
|
|
7433
|
+
modelInfo = cls;
|
|
7434
|
+
break;
|
|
7435
|
+
}
|
|
7436
|
+
}
|
|
7437
|
+
if (!modelInfo) continue; // Item not found in this mapping
|
|
7027
7438
|
}
|
|
7028
7439
|
return await modelInfo[1].from_pretrained(pretrained_model_name_or_path, options);
|
|
7029
7440
|
}
|
|
7030
7441
|
|
|
7031
7442
|
if (this.BASE_IF_FAIL) {
|
|
7032
|
-
|
|
7443
|
+
if (!(CUSTOM_ARCHITECTURES.has(model_type))) {
|
|
7444
|
+
console.warn(`Unknown model class "${model_type}", attempting to construct from base class.`);
|
|
7445
|
+
}
|
|
7033
7446
|
return await PreTrainedModel.from_pretrained(pretrained_model_name_or_path, options);
|
|
7034
7447
|
} else {
|
|
7035
|
-
throw Error(`Unsupported model type: ${
|
|
7448
|
+
throw Error(`Unsupported model type: ${model_type}`)
|
|
7036
7449
|
}
|
|
7037
7450
|
}
|
|
7038
7451
|
}
|
|
@@ -7133,6 +7546,11 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
|
|
|
7133
7546
|
['blenderbot-small', ['BlenderbotSmallModel', BlenderbotSmallModel]],
|
|
7134
7547
|
]);
|
|
7135
7548
|
|
|
7549
|
+
// model_type → [class name, class] mapping for audio auto-encoder (codec) models,
// registered under MODEL_TYPES.AutoEncoder in MODEL_CLASS_TYPE_MAPPING.
const MODEL_MAPPING_NAMES_AUTO_ENCODER = new Map([
    ['mimi', ['MimiModel', MimiModel]],
    ['dac', ['DacModel', DacModel]],
    ['snac', ['SnacModel', SnacModel]],
]);
|
|
7136
7554
|
|
|
7137
7555
|
const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
|
|
7138
7556
|
['bloom', ['BloomModel', BloomModel]],
|
|
@@ -7152,6 +7570,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
|
|
|
7152
7570
|
['cohere', ['CohereModel', CohereModel]],
|
|
7153
7571
|
['gemma', ['GemmaModel', GemmaModel]],
|
|
7154
7572
|
['gemma2', ['Gemma2Model', Gemma2Model]],
|
|
7573
|
+
['gemma3_text', ['Gemma3Model', Gemma3Model]],
|
|
7155
7574
|
['helium', ['HeliumModel', HeliumModel]],
|
|
7156
7575
|
['glm', ['GlmModel', GlmModel]],
|
|
7157
7576
|
['openelm', ['OpenELMModel', OpenELMModel]],
|
|
@@ -7169,6 +7588,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
|
|
|
7169
7588
|
const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = new Map([
|
|
7170
7589
|
['speecht5', ['SpeechT5ForSpeechToText', SpeechT5ForSpeechToText]],
|
|
7171
7590
|
['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]],
|
|
7591
|
+
['lite-whisper', ['LiteWhisperForConditionalGeneration', LiteWhisperForConditionalGeneration]],
|
|
7172
7592
|
['moonshine', ['MoonshineForConditionalGeneration', MoonshineForConditionalGeneration]],
|
|
7173
7593
|
]);
|
|
7174
7594
|
|
|
@@ -7250,6 +7670,7 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
|
|
|
7250
7670
|
['cohere', ['CohereForCausalLM', CohereForCausalLM]],
|
|
7251
7671
|
['gemma', ['GemmaForCausalLM', GemmaForCausalLM]],
|
|
7252
7672
|
['gemma2', ['Gemma2ForCausalLM', Gemma2ForCausalLM]],
|
|
7673
|
+
['gemma3_text', ['Gemma3ForCausalLM', Gemma3ForCausalLM]],
|
|
7253
7674
|
['helium', ['HeliumForCausalLM', HeliumForCausalLM]],
|
|
7254
7675
|
['glm', ['GlmForCausalLM', GlmForCausalLM]],
|
|
7255
7676
|
['openelm', ['OpenELMForCausalLM', OpenELMForCausalLM]],
|
|
@@ -7315,6 +7736,7 @@ const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
|
|
|
7315
7736
|
const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([
|
|
7316
7737
|
['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]],
|
|
7317
7738
|
['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
|
|
7739
|
+
['smolvlm', ['SmolVLMForConditionalGeneration', SmolVLMForConditionalGeneration]],
|
|
7318
7740
|
]);
|
|
7319
7741
|
|
|
7320
7742
|
const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
|
|
@@ -7324,9 +7746,15 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
|
|
|
7324
7746
|
['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]],
|
|
7325
7747
|
['qwen2-vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]],
|
|
7326
7748
|
['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
|
|
7749
|
+
['smolvlm', ['SmolVLMForConditionalGeneration', SmolVLMForConditionalGeneration]],
|
|
7327
7750
|
['paligemma', ['PaliGemmaForConditionalGeneration', PaliGemmaForConditionalGeneration]],
|
|
7328
7751
|
]);
|
|
7329
7752
|
|
|
7753
|
+
// model_type → [class name, class] mapping for audio-text-to-text models,
// used by AutoModelForAudioTextToText.
const MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
    ['ultravox', ['UltravoxModel', UltravoxModel]],
]);
|
|
7756
|
+
|
|
7757
|
+
|
|
7330
7758
|
const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
|
|
7331
7759
|
['vision-encoder-decoder', ['VisionEncoderDecoderModel', VisionEncoderDecoderModel]],
|
|
7332
7760
|
]);
|
|
@@ -7378,6 +7806,12 @@ const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
|
7378
7806
|
const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
7379
7807
|
['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
|
|
7380
7808
|
['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],
|
|
7809
|
+
|
|
7810
|
+
['swin', ['SwinForSemanticSegmentation', SwinForSemanticSegmentation]],
|
|
7811
|
+
['mobilenet_v1', ['MobileNetV1ForSemanticSegmentation', MobileNetV1ForSemanticSegmentation]],
|
|
7812
|
+
['mobilenet_v2', ['MobileNetV2ForSemanticSegmentation', MobileNetV2ForSemanticSegmentation]],
|
|
7813
|
+
['mobilenet_v3', ['MobileNetV3ForSemanticSegmentation', MobileNetV3ForSemanticSegmentation]],
|
|
7814
|
+
['mobilenet_v4', ['MobileNetV4ForSemanticSegmentation', MobileNetV4ForSemanticSegmentation]],
|
|
7381
7815
|
]);
|
|
7382
7816
|
|
|
7383
7817
|
const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
|
|
@@ -7438,6 +7872,8 @@ const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([
|
|
|
7438
7872
|
['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]],
|
|
7439
7873
|
['sapiens', ['SapiensForDepthEstimation', SapiensForDepthEstimation]],
|
|
7440
7874
|
['depth_pro', ['DepthProForDepthEstimation', DepthProForDepthEstimation]],
|
|
7875
|
+
['metric3d', ['Metric3DForDepthEstimation', Metric3DForDepthEstimation]],
|
|
7876
|
+
['metric3dv2', ['Metric3Dv2ForDepthEstimation', Metric3Dv2ForDepthEstimation]],
|
|
7441
7877
|
])
|
|
7442
7878
|
|
|
7443
7879
|
const MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES = new Map([
|
|
@@ -7457,9 +7893,12 @@ const MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES = new Map([
|
|
|
7457
7893
|
])
|
|
7458
7894
|
|
|
7459
7895
|
const MODEL_CLASS_TYPE_MAPPING = [
|
|
7896
|
+
// MODEL_MAPPING_NAMES:
|
|
7460
7897
|
[MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_TYPES.EncoderOnly],
|
|
7461
7898
|
[MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_TYPES.EncoderDecoder],
|
|
7462
7899
|
[MODEL_MAPPING_NAMES_DECODER_ONLY, MODEL_TYPES.DecoderOnly],
|
|
7900
|
+
[MODEL_MAPPING_NAMES_AUTO_ENCODER, MODEL_TYPES.AutoEncoder],
|
|
7901
|
+
|
|
7463
7902
|
[MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
7464
7903
|
[MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
7465
7904
|
[MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
|
|
@@ -7470,6 +7909,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
|
|
|
7470
7909
|
[MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
7471
7910
|
[MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
|
|
7472
7911
|
[MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.ImageTextToText],
|
|
7912
|
+
[MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.AudioTextToText],
|
|
7473
7913
|
[MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
7474
7914
|
[MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
7475
7915
|
[MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
|
|
@@ -7514,6 +7954,13 @@ const CUSTOM_MAPPING = [
|
|
|
7514
7954
|
['JinaCLIPTextModel', JinaCLIPTextModel, MODEL_TYPES.EncoderOnly],
|
|
7515
7955
|
['ClapTextModelWithProjection', ClapTextModelWithProjection, MODEL_TYPES.EncoderOnly],
|
|
7516
7956
|
['ClapAudioModelWithProjection', ClapAudioModelWithProjection, MODEL_TYPES.EncoderOnly],
|
|
7957
|
+
|
|
7958
|
+
['DacEncoderModel', DacEncoderModel, MODEL_TYPES.EncoderOnly],
|
|
7959
|
+
['DacDecoderModel', DacDecoderModel, MODEL_TYPES.EncoderOnly],
|
|
7960
|
+
['MimiEncoderModel', MimiEncoderModel, MODEL_TYPES.EncoderOnly],
|
|
7961
|
+
['MimiDecoderModel', MimiDecoderModel, MODEL_TYPES.EncoderOnly],
|
|
7962
|
+
['SnacEncoderModel', SnacEncoderModel, MODEL_TYPES.EncoderOnly],
|
|
7963
|
+
['SnacDecoderModel', SnacDecoderModel, MODEL_TYPES.EncoderOnly],
|
|
7517
7964
|
]
|
|
7518
7965
|
for (const [name, model, type] of CUSTOM_MAPPING) {
|
|
7519
7966
|
MODEL_TYPE_MAPPING.set(name, type);
|
|
@@ -7521,6 +7968,19 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
|
|
|
7521
7968
|
MODEL_NAME_TO_CLASS_MAPPING.set(name, model);
|
|
7522
7969
|
}
|
|
7523
7970
|
|
|
7971
|
+
// Architectures without an official transformers `model_type`; each is registered
// into the given task mapping and served by the base PreTrainedModel class.
const CUSTOM_ARCHITECTURES = new Map([
    ['modnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
    ['birefnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
    ['isnet', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
    ['ben', MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES],
]);
for (const [name, mapping] of CUSTOM_ARCHITECTURES.entries()) {
    mapping.set(name, ['PreTrainedModel', PreTrainedModel])
    MODEL_TYPE_MAPPING.set(name, MODEL_TYPES.EncoderOnly);
    // NOTE(review): the same `PreTrainedModel` key is re-set on every iteration,
    // so this reverse mapping ends up pointing at the last entry only — confirm intended.
    MODEL_CLASS_TO_NAME_MAPPING.set(PreTrainedModel, name);
    MODEL_NAME_TO_CLASS_MAPPING.set(name, PreTrainedModel);
}
|
|
7983
|
+
|
|
7524
7984
|
|
|
7525
7985
|
/**
|
|
7526
7986
|
* Helper class which is used to instantiate pretrained models with the `from_pretrained` function.
|
|
@@ -7761,6 +8221,14 @@ export class AutoModelForImageFeatureExtraction extends PretrainedMixin {
|
|
|
7761
8221
|
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES];
|
|
7762
8222
|
}
|
|
7763
8223
|
|
|
8224
|
+
/** Auto class resolving image-text-to-text models via MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES. */
export class AutoModelForImageTextToText extends PretrainedMixin {
    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES];
}
|
|
8227
|
+
|
|
8228
|
+
/** Auto class resolving audio-text-to-text models via MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES. */
export class AutoModelForAudioTextToText extends PretrainedMixin {
    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES];
}
|
|
8231
|
+
|
|
7764
8232
|
//////////////////////////////////////////////////
|
|
7765
8233
|
|
|
7766
8234
|
//////////////////////////////////////////////////
|