@huggingface/transformers 3.0.0-alpha.9 → 3.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. package/README.md +82 -50
  2. package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
  3. package/dist/transformers.cjs +2550 -2552
  4. package/dist/transformers.cjs.map +1 -1
  5. package/dist/transformers.js +3639 -3567
  6. package/dist/transformers.js.map +1 -1
  7. package/dist/transformers.min.cjs +25 -25
  8. package/dist/transformers.min.cjs.map +1 -1
  9. package/dist/transformers.min.js +41 -42
  10. package/dist/transformers.min.js.map +1 -1
  11. package/dist/transformers.min.mjs +56 -57
  12. package/dist/transformers.min.mjs.map +1 -1
  13. package/dist/transformers.mjs +2586 -2564
  14. package/dist/transformers.mjs.map +1 -1
  15. package/package.json +14 -13
  16. package/src/backends/onnx.js +24 -19
  17. package/src/configs.js +19 -4
  18. package/src/env.js +5 -9
  19. package/src/generation/logits_process.js +40 -37
  20. package/src/models.js +356 -539
  21. package/src/ops/registry.js +14 -3
  22. package/src/pipelines.js +5 -5
  23. package/src/processors.js +392 -351
  24. package/src/tokenizers.js +140 -175
  25. package/src/utils/constants.js +1 -1
  26. package/src/utils/core.js +12 -0
  27. package/src/utils/data-structures.js +13 -11
  28. package/src/utils/hub.js +1 -1
  29. package/src/utils/maths.js +14 -5
  30. package/src/utils/tensor.js +60 -13
  31. package/types/backends/onnx.d.ts +5 -2
  32. package/types/backends/onnx.d.ts.map +1 -1
  33. package/types/configs.d.ts +29 -3
  34. package/types/configs.d.ts.map +1 -1
  35. package/types/env.d.ts +4 -2
  36. package/types/env.d.ts.map +1 -1
  37. package/types/generation/logits_process.d.ts.map +1 -1
  38. package/types/models.d.ts +116 -289
  39. package/types/models.d.ts.map +1 -1
  40. package/types/ops/registry.d.ts +6 -6
  41. package/types/ops/registry.d.ts.map +1 -1
  42. package/types/pipelines.d.ts +1 -2
  43. package/types/pipelines.d.ts.map +1 -1
  44. package/types/processors.d.ts +58 -51
  45. package/types/processors.d.ts.map +1 -1
  46. package/types/tokenizers.d.ts +23 -32
  47. package/types/tokenizers.d.ts.map +1 -1
  48. package/types/utils/constants.d.ts +1 -1
  49. package/types/utils/constants.d.ts.map +1 -1
  50. package/types/utils/core.d.ts +7 -0
  51. package/types/utils/core.d.ts.map +1 -1
  52. package/types/utils/data-structures.d.ts +6 -6
  53. package/types/utils/data-structures.d.ts.map +1 -1
  54. package/types/utils/hub.d.ts +1 -1
  55. package/types/utils/hub.d.ts.map +1 -1
  56. package/types/utils/maths.d.ts +2 -2
  57. package/types/utils/maths.d.ts.map +1 -1
  58. package/types/utils/tensor.d.ts +27 -1
  59. package/types/utils/tensor.d.ts.map +1 -1
package/src/models.js CHANGED
@@ -71,6 +71,10 @@ import {
71
71
  getModelJSON,
72
72
  } from './utils/hub.js';
73
73
 
74
+ import {
75
+ GITHUB_ISSUE_URL,
76
+ } from './utils/constants.js';
77
+
74
78
  import {
75
79
  LogitsProcessorList,
76
80
  ForcedBOSTokenLogitsProcessor,
@@ -142,11 +146,12 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
142
146
  * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
143
147
  * @param {string} fileName The name of the model file.
144
148
  * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
145
- * @returns {Promise<{buffer: Uint8Array, session_options: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
149
+ * @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
146
150
  * @private
147
151
  */
148
152
  async function getSession(pretrained_model_name_or_path, fileName, options) {
149
- let device = options.device;
153
+ const custom_config = options.config?.['transformers.js_config'] ?? {};
154
+ let device = options.device ?? custom_config.device;
150
155
  if (device && typeof device !== 'string') {
151
156
  if (device.hasOwnProperty(fileName)) {
152
157
  device = device[fileName];
@@ -164,7 +169,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
164
169
 
165
170
  // If options.dtype is specified, we use it to choose the suffix for the model file.
166
171
  // Otherwise, we use the default dtype for the device.
167
- let dtype = options.dtype;
172
+ let dtype = options.dtype ?? custom_config.dtype;
168
173
  if (typeof dtype !== 'string') {
169
174
  if (dtype && dtype.hasOwnProperty(fileName)) {
170
175
  dtype = dtype[fileName];
@@ -182,27 +187,54 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
182
187
  throw new Error(`The device (${selectedDevice}) does not support fp16.`);
183
188
  }
184
189
 
190
+ // Only valid for models with a decoder
191
+ const kv_cache_dtype = custom_config.kv_cache_dtype
192
+ ? (typeof custom_config.kv_cache_dtype === 'string'
193
+ ? custom_config.kv_cache_dtype
194
+ : custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
195
+ : undefined;
196
+
197
+ if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
198
+ throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
199
+ }
200
+
201
+ const session_config = {
202
+ dtype: selectedDtype,
203
+ kv_cache_dtype,
204
+ }
205
+
185
206
  // Construct the model file name
186
207
  const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
187
208
  const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
188
209
 
189
- const session_options = { ...options.session_options } ?? {};
210
+ const session_options = { ...options.session_options };
190
211
 
191
212
  // Overwrite `executionProviders` if not specified
192
213
  session_options.executionProviders ??= executionProviders;
193
214
 
215
+ // Overwrite `freeDimensionOverrides` if specified in config and not set in session options
216
+ const free_dimension_overrides = custom_config.free_dimension_overrides;
217
+ if (free_dimension_overrides) {
218
+ session_options.freeDimensionOverrides ??= free_dimension_overrides;
219
+ } else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) {
220
+ console.warn(
221
+ 'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". ' +
222
+ 'When `free_dimension_overrides` is not set, you may experience significant performance degradation.'
223
+ );
224
+ }
194
225
 
195
226
  const bufferPromise = getModelFile(pretrained_model_name_or_path, modelFileName, true, options);
196
227
 
197
228
  // handle onnx external data files
229
+ const use_external_data_format = options.use_external_data_format ?? custom_config.use_external_data_format;
198
230
  /** @type {Promise<{path: string, data: Uint8Array}>[]} */
199
231
  let externalDataPromises = [];
200
- if (options.use_external_data_format && (
201
- options.use_external_data_format === true ||
232
+ if (use_external_data_format && (
233
+ use_external_data_format === true ||
202
234
  (
203
- typeof options.use_external_data_format === 'object' &&
204
- options.use_external_data_format.hasOwnProperty(fileName) &&
205
- options.use_external_data_format[fileName] === true
235
+ typeof use_external_data_format === 'object' &&
236
+ use_external_data_format.hasOwnProperty(fileName) &&
237
+ use_external_data_format[fileName] === true
206
238
  )
207
239
  )) {
208
240
  if (apis.IS_NODE_ENV) {
@@ -236,6 +268,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
236
268
  });
237
269
  if (Object.keys(shapes).length > 0 && !isONNXProxy()) {
238
270
  // Only set preferredOutputLocation if shapes are present and we aren't proxying ONNX
271
+ /** @type {Record<string, import('onnxruntime-common').Tensor.DataLocation>} */
239
272
  const preferredOutputLocation = {};
240
273
  for (const key in shapes) {
241
274
  preferredOutputLocation[key] = 'gpu-buffer';
@@ -245,7 +278,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
245
278
  }
246
279
 
247
280
  const buffer = await bufferPromise;
248
- return { buffer, session_options };
281
+
282
+ return { buffer, session_options, session_config };
249
283
  }
250
284
 
251
285
  /**
@@ -260,13 +294,30 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
260
294
  async function constructSessions(pretrained_model_name_or_path, names, options) {
261
295
  return Object.fromEntries(await Promise.all(
262
296
  Object.keys(names).map(async (name) => {
263
- const { buffer, session_options } = await getSession(pretrained_model_name_or_path, names[name], options);
264
- const session = await createInferenceSession(buffer, session_options);
297
+ const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
298
+ const session = await createInferenceSession(buffer, session_options, session_config);
265
299
  return [name, session];
266
300
  })
267
301
  ));
268
302
  }
269
303
 
304
+ /**
305
+ * Helper function to load multiple optional configuration files
306
+ * @param {string} pretrained_model_name_or_path The path to the directory containing the config file.
307
+ * @param {Record<string, string>} names The names of the config files to load.
308
+ * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the configs.
309
+ * @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of configuration objects.
310
+ * @private
311
+ */
312
+ async function getOptionalConfigs(pretrained_model_name_or_path, names, options) {
313
+ return Object.fromEntries(await Promise.all(
314
+ Object.keys(names).map(async (name) => {
315
+ const config = await getModelJSON(pretrained_model_name_or_path, names[name], false, options);
316
+ return [name, config];
317
+ })
318
+ ));
319
+ }
320
+
270
321
  /**
271
322
  * Validate model inputs
272
323
  * @param {Object} session The InferenceSession object that will be run.
@@ -360,7 +411,7 @@ function replaceTensors(obj) {
360
411
 
361
412
  /**
362
413
  * Converts an array or Tensor of integers to an int64 Tensor.
363
- * @param {Array|Tensor} items The input integers to be converted.
414
+ * @param {any[]|Tensor} items The input integers to be converted.
364
415
  * @returns {Tensor} The int64 Tensor with the converted values.
365
416
  * @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
366
417
  * @private
@@ -393,37 +444,6 @@ function toI64Tensor(items) {
393
444
  }
394
445
  }
395
446
 
396
- /**
397
- * Prepares an attention mask for a sequence of tokens based on configuration options.
398
- * @param {Object} self The calling object instance.
399
- * @param {Tensor} tokens The input tokens.
400
- * @returns {Tensor} The attention mask tensor.
401
- * @private
402
- */
403
- function prepareAttentionMask(self, tokens) {
404
-
405
- // Prepare attention mask
406
- let pad_token_id = self.config.pad_token_id ?? null;
407
- let eos_token_id = self.config.eos_token_id ?? null;
408
- if (isIntegralNumber(eos_token_id)) {
409
- eos_token_id = [eos_token_id];
410
- }
411
-
412
- let is_pad_token_in_inputs = tokens.indexOf(pad_token_id) !== -1;
413
- let is_pad_token_not_equal_to_eos_token_id = (eos_token_id === null) || !eos_token_id.includes(pad_token_id)
414
-
415
- if (is_pad_token_in_inputs && is_pad_token_not_equal_to_eos_token_id) {
416
- let data = BigInt64Array.from(
417
- // Note: != so that int matches bigint
418
- // @ts-ignore
419
- tokens.data.map(x => x != pad_token_id)
420
- )
421
- return new Tensor('int64', data, tokens.dims)
422
- } else {
423
- return ones_like(tokens);
424
- }
425
- }
426
-
427
447
  /**
428
448
  * Creates a boolean tensor with a single value.
429
449
  * @param {boolean} value The value of the tensor.
@@ -694,8 +714,8 @@ function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
694
714
  } else {
695
715
  return decoder_prepare_inputs_for_generation(self, ...args);
696
716
  }
697
-
698
717
  }
718
+
699
719
  //////////////////////////////////////////////////
700
720
 
701
721
  //////////////////////////////////////////////////
@@ -709,12 +729,14 @@ export class PreTrainedModel extends Callable {
709
729
  * Creates a new instance of the `PreTrainedModel` class.
710
730
  * @param {import('./configs.js').PretrainedConfig} config The model configuration.
711
731
  * @param {Record<string, any>} sessions The inference sessions for the model.
732
+ * @param {Record<string, Object>} configs Additional configuration files (e.g., generation_config.json).
712
733
  */
713
- constructor(config, sessions) {
734
+ constructor(config, sessions, configs) {
714
735
  super();
715
736
 
716
737
  this.config = config;
717
738
  this.sessions = sessions;
739
+ this.configs = configs;
718
740
 
719
741
  const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
720
742
  const modelType = MODEL_TYPE_MAPPING.get(modelName);
@@ -830,7 +852,9 @@ export class PreTrainedModel extends Callable {
830
852
  constructSessions(pretrained_model_name_or_path, {
831
853
  model: options.model_file_name ?? 'model',
832
854
  }, options),
833
- getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
855
+ getOptionalConfigs(pretrained_model_name_or_path, {
856
+ generation_config: 'generation_config.json',
857
+ }, options),
834
858
  ]);
835
859
 
836
860
  } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
@@ -839,7 +863,9 @@ export class PreTrainedModel extends Callable {
839
863
  model: 'encoder_model',
840
864
  decoder_model_merged: 'decoder_model_merged',
841
865
  }, options),
842
- getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
866
+ getOptionalConfigs(pretrained_model_name_or_path, {
867
+ generation_config: 'generation_config.json',
868
+ }, options),
843
869
  ]);
844
870
 
845
871
  } else if (modelType === MODEL_TYPES.MaskGeneration) {
@@ -869,7 +895,9 @@ export class PreTrainedModel extends Callable {
869
895
  }
870
896
  info = await Promise.all([
871
897
  constructSessions(pretrained_model_name_or_path, sessions, options),
872
- getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
898
+ getOptionalConfigs(pretrained_model_name_or_path, {
899
+ generation_config: 'generation_config.json',
900
+ }, options),
873
901
  ]);
874
902
 
875
903
  } else if (modelType === MODEL_TYPES.Musicgen) {
@@ -879,12 +907,14 @@ export class PreTrainedModel extends Callable {
879
907
  decoder_model_merged: 'decoder_model_merged',
880
908
  encodec_decode: 'encodec_decode',
881
909
  }, options),
882
- getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
910
+ getOptionalConfigs(pretrained_model_name_or_path, {
911
+ generation_config: 'generation_config.json',
912
+ }, options),
883
913
  ]);
884
914
 
885
915
  } else { // should be MODEL_TYPES.EncoderOnly
886
916
  if (modelType !== MODEL_TYPES.EncoderOnly) {
887
- console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at https://github.com/xenova/transformers.js/issues/new/choose.`)
917
+ console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
888
918
  }
889
919
  info = await Promise.all([
890
920
  constructSessions(pretrained_model_name_or_path, {
@@ -917,6 +947,14 @@ export class PreTrainedModel extends Callable {
917
947
  return await this._forward(this, model_inputs);
918
948
  }
919
949
 
950
+ /**
951
+ * Get the model's generation config, if it exists.
952
+ * @returns {GenerationConfig|null} The model's generation config if it exists, otherwise `null`.
953
+ */
954
+ get generation_config() {
955
+ return this.configs?.generation_config ?? null;
956
+ }
957
+
920
958
  /**
921
959
  * This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
922
960
  * instances used for multinomial sampling.
@@ -1096,9 +1134,7 @@ export class PreTrainedModel extends Callable {
1096
1134
  const gen_config = new cls(config);
1097
1135
 
1098
1136
  // Apply model's generation config, if it exists
1099
- if ('generation_config' in this) {
1100
- Object.assign(gen_config, this.generation_config);
1101
- }
1137
+ Object.assign(gen_config, this.generation_config ?? {});
1102
1138
 
1103
1139
  // Next, use any generation config specified by the user
1104
1140
  // when calling `generate`
@@ -1298,35 +1334,37 @@ export class PreTrainedModel extends Callable {
1298
1334
  let { decoder_input_ids, ...model_inputs } = model_kwargs;
1299
1335
 
1300
1336
  // Prepare input ids if the user has not defined `decoder_input_ids` manually.
1301
- if (!decoder_input_ids) {
1302
- decoder_start_token_id ??= bos_token_id;
1303
-
1304
- if (this.config.model_type === 'musicgen') {
1305
- // Custom logic (TODO: move to Musicgen class)
1306
- decoder_input_ids = Array.from({
1307
- length: batch_size * this.config.decoder.num_codebooks
1308
- }, () => [decoder_start_token_id]);
1309
-
1310
- } else if (Array.isArray(decoder_start_token_id)) {
1311
- if (decoder_start_token_id.length !== batch_size) {
1312
- throw new Error(
1313
- `\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}`
1314
- )
1337
+ if (!(decoder_input_ids instanceof Tensor)) {
1338
+ if (!decoder_input_ids) {
1339
+ decoder_start_token_id ??= bos_token_id;
1340
+
1341
+ if (this.config.model_type === 'musicgen') {
1342
+ // Custom logic (TODO: move to Musicgen class)
1343
+ decoder_input_ids = Array.from({
1344
+ length: batch_size * this.config.decoder.num_codebooks
1345
+ }, () => [decoder_start_token_id]);
1346
+
1347
+ } else if (Array.isArray(decoder_start_token_id)) {
1348
+ if (decoder_start_token_id.length !== batch_size) {
1349
+ throw new Error(
1350
+ `\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}`
1351
+ )
1352
+ }
1353
+ decoder_input_ids = decoder_start_token_id;
1354
+ } else {
1355
+ decoder_input_ids = Array.from({
1356
+ length: batch_size,
1357
+ }, () => [decoder_start_token_id]);
1315
1358
  }
1316
- decoder_input_ids = decoder_start_token_id;
1317
- } else {
1359
+ } else if (!Array.isArray(decoder_input_ids[0])) {
1360
+ // Correct batch size
1318
1361
  decoder_input_ids = Array.from({
1319
1362
  length: batch_size,
1320
- }, () => [decoder_start_token_id]);
1363
+ }, () => decoder_input_ids);
1321
1364
  }
1322
- } else if (!Array.isArray(decoder_input_ids[0])) {
1323
- // Correct batch size
1324
- decoder_input_ids = Array.from({
1325
- length: batch_size,
1326
- }, () => decoder_input_ids);
1365
+ decoder_input_ids = toI64Tensor(decoder_input_ids);
1327
1366
  }
1328
1367
 
1329
- decoder_input_ids = toI64Tensor(decoder_input_ids);
1330
1368
  model_kwargs['decoder_attention_mask'] = ones_like(decoder_input_ids);
1331
1369
 
1332
1370
  return { input_ids: decoder_input_ids, model_inputs };
@@ -1458,13 +1496,12 @@ export class PreTrainedModel extends Callable {
1458
1496
  // - GenerationMode.BEAM_SEARCH
1459
1497
  // - GenerationMode.BEAM_SAMPLE
1460
1498
  ////////////////////////////////////////////////////
1461
- let past_key_values = null;
1499
+ let outputs;
1462
1500
  let attentions = {};
1463
1501
  while (true) {
1464
1502
  // prepare model inputs
1465
1503
  model_inputs = this.prepare_inputs_for_generation(all_input_ids, model_inputs, generation_config);
1466
-
1467
- const outputs = await this.forward(model_inputs);
1504
+ outputs = await this.forward(model_inputs);
1468
1505
 
1469
1506
  if (generation_config.output_attentions && generation_config.return_dict_in_generate) {
1470
1507
  // Get attentions if they are present
@@ -1511,10 +1548,6 @@ export class PreTrainedModel extends Callable {
1511
1548
 
1512
1549
  const stop = prepared_stopping_criteria(all_input_ids);
1513
1550
  if (stop.every(x => x)) {
1514
- if (generation_config.return_dict_in_generate) {
1515
- // Get past key values without disposing buffers
1516
- past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, false);
1517
- }
1518
1551
  break;
1519
1552
  }
1520
1553
 
@@ -1527,6 +1560,9 @@ export class PreTrainedModel extends Callable {
1527
1560
  streamer.end();
1528
1561
  }
1529
1562
 
1563
+ // Retrieve and dispose all final past key values (including encoder attentions)
1564
+ const past_key_values = this.getPastKeyValues(outputs, model_inputs.past_key_values, true);
1565
+
1530
1566
  // TODO: ensure all_input_ids is padded correctly...
1531
1567
  const sequences = new Tensor('int64', all_input_ids.flat(), [all_input_ids.length, all_input_ids[0].length]);
1532
1568
 
@@ -1540,6 +1576,12 @@ export class PreTrainedModel extends Callable {
1540
1576
  // logits,
1541
1577
  }
1542
1578
  } else {
1579
+ // Dispose all remaining tensors
1580
+ for (const tensor of Object.values(outputs)) {
1581
+ if (tensor.location === 'gpu-buffer') {
1582
+ tensor.dispose();
1583
+ }
1584
+ }
1543
1585
  return sequences;
1544
1586
  }
1545
1587
  }
@@ -1549,31 +1591,32 @@ export class PreTrainedModel extends Callable {
1549
1591
  *
1550
1592
  * @param {Object} decoderResults The decoder results object.
1551
1593
  * @param {Object} pastKeyValues The previous past key values.
1552
- * @param {boolean} [dispose=true] Whether to dispose of the old gpu buffer.
1553
1594
  * @returns {Object} An object containing past key values.
1554
1595
  */
1555
- getPastKeyValues(decoderResults, pastKeyValues, dispose = true) {
1596
+ getPastKeyValues(decoderResults, pastKeyValues, disposeEncoderPKVs = false) {
1556
1597
  const pkvs = Object.create(null);
1557
1598
 
1558
1599
  for (const name in decoderResults) {
1559
1600
  if (name.startsWith('present')) {
1560
1601
  const newName = name.replace('present', 'past_key_values');
1561
-
1562
- if (pastKeyValues && name.includes('encoder')) {
1563
- // Optimization introduced by optimum to reuse past key values. So, we just replace the constant
1564
- // outputs with the previous past key values.
1602
+ const is_encoder_pkv = name.includes('encoder');
1603
+ if (is_encoder_pkv && pastKeyValues) {
1604
+ // Optimization introduced by optimum to reuse past key values.
1605
+ // So, we just replace the constant outputs (`decoderResults[name]`) with the previous past key values.
1565
1606
  // https://github.com/huggingface/optimum/blob/0bf2c05fb7e1182b52d21b703cfc95fd9e4ea3dc/optimum/onnxruntime/base.py#L677-L704
1566
1607
  pkvs[newName] = pastKeyValues[newName];
1567
- } else {
1568
- if (dispose && pastKeyValues) {
1569
- // Free old gpu buffer
1570
- const t = pastKeyValues[newName];
1571
- if (t.location === 'gpu-buffer') {
1572
- t.dispose();
1573
- }
1574
- }
1608
+ } else { // decoder or using first encoder PKVs
1575
1609
  pkvs[newName] = decoderResults[name];
1576
1610
  }
1611
+
1612
+ if (pastKeyValues && (!is_encoder_pkv || disposeEncoderPKVs)) {
1613
+ // - Always dispose decoder PKVs
1614
+ // - Only dispose encoder past key values when requested (after generation)
1615
+ const t = pastKeyValues[newName];
1616
+ if (t.location === 'gpu-buffer') {
1617
+ t.dispose();
1618
+ }
1619
+ }
1577
1620
  }
1578
1621
  }
1579
1622
  return pkvs;
@@ -1611,9 +1654,8 @@ export class PreTrainedModel extends Callable {
1611
1654
  if (pastKeyValues) {
1612
1655
  Object.assign(decoderFeeds, pastKeyValues)
1613
1656
  } else {
1614
-
1615
- /** @type {import('./transformers.js').DataType} */
1616
- const dtype = this.custom_config.kv_cache_dtype ?? 'float32';
1657
+ const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
1658
+ const dtype = session?.config?.kv_cache_dtype ?? 'float32';
1617
1659
  const empty = (dtype === 'float16') ? new Uint16Array() : [];
1618
1660
 
1619
1661
  const shapes = getKeyValueShapes(this.config);
@@ -2506,17 +2548,6 @@ export class T5PreTrainedModel extends PreTrainedModel {
2506
2548
  'decoder_attention_mask',
2507
2549
  'past_key_values',
2508
2550
  ];
2509
-
2510
- /**
2511
- * Creates a new instance of the `T5PreTrainedModel` class.
2512
- * @param {Object} config The model configuration.
2513
- * @param {Record<string, any>} sessions The inference sessions for the model.
2514
- * @param {GenerationConfig} generation_config The generation configuration.
2515
- */
2516
- constructor(config, sessions, generation_config) {
2517
- super(config, sessions);
2518
- this.generation_config = generation_config;
2519
- }
2520
2551
  };
2521
2552
 
2522
2553
  export class T5Model extends T5PreTrainedModel { }
@@ -2533,18 +2564,7 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { }
2533
2564
  /**
2534
2565
  * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
2535
2566
  */
2536
- export class LongT5PreTrainedModel extends PreTrainedModel {
2537
- /**
2538
- * Creates a new instance of the `LongT5ForConditionalGeneration` class.
2539
- * @param {Object} config The model configuration.
2540
- * @param {Record<string, any>} sessions The inference sessions for the model.
2541
- * @param {GenerationConfig} generation_config The generation configuration.
2542
- */
2543
- constructor(config, sessions, generation_config) {
2544
- super(config, sessions);
2545
- this.generation_config = generation_config;
2546
- }
2547
- };
2567
+ export class LongT5PreTrainedModel extends PreTrainedModel { };
2548
2568
 
2549
2569
  /**
2550
2570
  * The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.
@@ -2560,19 +2580,7 @@ export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel { }
2560
2580
 
2561
2581
  //////////////////////////////////////////////////
2562
2582
  // MT5 models
2563
- export class MT5PreTrainedModel extends PreTrainedModel {
2564
-
2565
- /**
2566
- * Creates a new instance of the `MT5ForConditionalGeneration` class.
2567
- * @param {Object} config The model configuration.
2568
- * @param {Record<string, any>} sessions The inference sessions for the model.
2569
- * @param {GenerationConfig} generation_config The generation configuration.
2570
- */
2571
- constructor(config, sessions, generation_config) {
2572
- super(config, sessions);
2573
- this.generation_config = generation_config;
2574
- }
2575
- };
2583
+ export class MT5PreTrainedModel extends PreTrainedModel { };
2576
2584
 
2577
2585
  export class MT5Model extends MT5PreTrainedModel { }
2578
2586
 
@@ -2584,19 +2592,7 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { }
2584
2592
 
2585
2593
  //////////////////////////////////////////////////
2586
2594
  // Bart models
2587
- export class BartPretrainedModel extends PreTrainedModel {
2588
-
2589
- /**
2590
- * Creates a new instance of the `BartForConditionalGeneration` class.
2591
- * @param {Object} config The model configuration.
2592
- * @param {Record<string, any>} sessions The inference sessions for the model.
2593
- * @param {GenerationConfig} generation_config The generation configuration.
2594
- */
2595
- constructor(config, sessions, generation_config) {
2596
- super(config, sessions);
2597
- this.generation_config = generation_config;
2598
- }
2599
- };
2595
+ export class BartPretrainedModel extends PreTrainedModel { };
2600
2596
 
2601
2597
  /**
2602
2598
  * The bare BART Model outputting raw hidden-states without any specific head on top.
@@ -2627,19 +2623,7 @@ export class BartForSequenceClassification extends BartPretrainedModel {
2627
2623
 
2628
2624
  //////////////////////////////////////////////////
2629
2625
  // MBart models
2630
- export class MBartPreTrainedModel extends PreTrainedModel {
2631
-
2632
- /**
2633
- * Creates a new instance of the `MBartForConditionalGeneration` class.
2634
- * @param {Object} config The model configuration.
2635
- * @param {Record<string, any>} sessions The inference sessions for the model.
2636
- * @param {GenerationConfig} generation_config The generation configuration.
2637
- */
2638
- constructor(config, sessions, generation_config) {
2639
- super(config, sessions);
2640
- this.generation_config = generation_config;
2641
- }
2642
- };
2626
+ export class MBartPreTrainedModel extends PreTrainedModel { };
2643
2627
 
2644
2628
  /**
2645
2629
  * The bare MBART Model outputting raw hidden-states without any specific head on top.
@@ -2673,19 +2657,7 @@ export class MBartForCausalLM extends MBartPreTrainedModel { }
2673
2657
 
2674
2658
  //////////////////////////////////////////////////
2675
2659
  // Blenderbot models
2676
- export class BlenderbotPreTrainedModel extends PreTrainedModel {
2677
-
2678
- /**
2679
- * Creates a new instance of the `BlenderbotForConditionalGeneration` class.
2680
- * @param {Object} config The model configuration.
2681
- * @param {Record<string, any>} sessions The inference sessions for the model.
2682
- * @param {GenerationConfig} generation_config The generation configuration.
2683
- */
2684
- constructor(config, sessions, generation_config) {
2685
- super(config, sessions);
2686
- this.generation_config = generation_config;
2687
- }
2688
- };
2660
+ export class BlenderbotPreTrainedModel extends PreTrainedModel { };
2689
2661
 
2690
2662
  /**
2691
2663
  * The bare Blenderbot Model outputting raw hidden-states without any specific head on top.
@@ -2701,19 +2673,7 @@ export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedMode
2701
2673
 
2702
2674
  //////////////////////////////////////////////////
2703
2675
  // BlenderbotSmall models
2704
- export class BlenderbotSmallPreTrainedModel extends PreTrainedModel {
2705
-
2706
- /**
2707
- * Creates a new instance of the `BlenderbotForConditionalGeneration` class.
2708
- * @param {Object} config The model configuration.
2709
- * @param {Record<string, any>} sessions The inference sessions for the model.
2710
- * @param {GenerationConfig} generation_config The generation configuration.
2711
- */
2712
- constructor(config, sessions, generation_config) {
2713
- super(config, sessions);
2714
- this.generation_config = generation_config;
2715
- }
2716
- };
2676
+ export class BlenderbotSmallPreTrainedModel extends PreTrainedModel { };
2717
2677
 
2718
2678
  /**
2719
2679
  * The bare BlenderbotSmall Model outputting raw hidden-states without any specific head on top.
@@ -2962,17 +2922,6 @@ export class WhisperPreTrainedModel extends PreTrainedModel {
2962
2922
  'decoder_attention_mask',
2963
2923
  'past_key_values',
2964
2924
  ];
2965
-
2966
- /**
2967
- * Creates a new instance of the `WhisperPreTrainedModel` class.
2968
- * @param {Object} config The model configuration.
2969
- * @param {Record<string, any>} sessions The inference sessions for the model.
2970
- * @param {GenerationConfig} generation_config The generation configuration.
2971
- */
2972
- constructor(config, sessions, generation_config) {
2973
- super(config, sessions);
2974
- this.generation_config = generation_config;
2975
- }
2976
2925
  };
2977
2926
 
2978
2927
  /**
@@ -3238,21 +3187,14 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
3238
3187
  export class VisionEncoderDecoderModel extends PreTrainedModel {
3239
3188
  main_input_name = 'pixel_values';
3240
3189
  forward_params = [
3190
+ // Encoder inputs
3241
3191
  'pixel_values',
3242
- 'input_ids',
3192
+
3193
+ // Decoder inputs
3194
+ 'decoder_input_ids',
3243
3195
  'encoder_hidden_states',
3244
3196
  'past_key_values',
3245
3197
  ];
3246
- /**
3247
- * Creates a new instance of the `VisionEncoderDecoderModel` class.
3248
- * @param {Object} config The model configuration.
3249
- * @param {Record<string, any>} sessions The inference sessions for the model.
3250
- * @param {GenerationConfig} generation_config The generation configuration.
3251
- */
3252
- constructor(config, sessions, generation_config) {
3253
- super(config, sessions);
3254
- this.generation_config = generation_config;
3255
- }
3256
3198
  }
3257
3199
  //////////////////////////////////////////////////
3258
3200
 
@@ -3267,11 +3209,6 @@ export class LlavaPreTrainedModel extends PreTrainedModel {
3267
3209
  'position_ids',
3268
3210
  'past_key_values',
3269
3211
  ];
3270
-
3271
- constructor(config, sessions, generation_config) {
3272
- super(config, sessions);
3273
- this.generation_config = generation_config;
3274
- }
3275
3212
  }
3276
3213
 
3277
3214
  /**
@@ -3358,11 +3295,6 @@ export class Florence2PreTrainedModel extends PreTrainedModel {
3358
3295
  'past_key_values',
3359
3296
  ];
3360
3297
  main_input_name = 'inputs_embeds';
3361
-
3362
- constructor(config, sessions, generation_config) {
3363
- super(config, sessions);
3364
- this.generation_config = generation_config;
3365
- }
3366
3298
  }
3367
3299
 
3368
3300
  export class Florence2ForConditionalGeneration extends Florence2PreTrainedModel {
@@ -3501,6 +3433,18 @@ export class CLIPPreTrainedModel extends PreTrainedModel { }
3501
3433
  */
3502
3434
  export class CLIPModel extends CLIPPreTrainedModel { }
3503
3435
 
3436
+ /**
3437
+ * The text model from CLIP without any head or projection on top.
3438
+ */
3439
+ export class CLIPTextModel extends CLIPPreTrainedModel {
3440
+ /** @type {PreTrainedModel.from_pretrained} */
3441
+ static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3442
+ // Update default model file name if not provided
3443
+ options.model_file_name ??= 'text_model';
3444
+ return super.from_pretrained(pretrained_model_name_or_path, options);
3445
+ }
3446
+ }
3447
+
3504
3448
  /**
3505
3449
  * CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output)
3506
3450
  *
@@ -3528,7 +3472,6 @@ export class CLIPModel extends CLIPPreTrainedModel { }
3528
3472
  * ```
3529
3473
  */
3530
3474
  export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
3531
-
3532
3475
  /** @type {PreTrainedModel.from_pretrained} */
3533
3476
  static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3534
3477
  // Update default model file name if not provided
@@ -3537,6 +3480,18 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
3537
3480
  }
3538
3481
  }
3539
3482
 
3483
+ /**
3484
+ * The vision model from CLIP without any head or projection on top.
3485
+ */
3486
+ export class CLIPVisionModel extends CLIPPreTrainedModel {
3487
+ /** @type {PreTrainedModel.from_pretrained} */
3488
+ static async from_pretrained(pretrained_model_name_or_path, options = {}) {
3489
+ // Update default model file name if not provided
3490
+ options.model_file_name ??= 'vision_model';
3491
+ return super.from_pretrained(pretrained_model_name_or_path, options);
3492
+ }
3493
+ }
3494
+
3540
3495
  /**
3541
3496
  * CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output)
3542
3497
  *
@@ -3759,18 +3714,7 @@ export class CLIPSegForImageSegmentation extends CLIPSegPreTrainedModel { }
3759
3714
 
3760
3715
  //////////////////////////////////////////////////
3761
3716
  // GPT2 models
3762
- export class GPT2PreTrainedModel extends PreTrainedModel {
3763
- /**
3764
- * Creates a new instance of the `GPT2PreTrainedModel` class.
3765
- * @param {Object} config The model configuration.
3766
- * @param {Record<string, any>} sessions The inference sessions for the model.
3767
- * @param {GenerationConfig} generation_config The generation configuration.
3768
- */
3769
- constructor(config, sessions, generation_config) {
3770
- super(config, sessions);
3771
- this.generation_config = generation_config;
3772
- }
3773
- }
3717
+ export class GPT2PreTrainedModel extends PreTrainedModel { }
3774
3718
 
3775
3719
  export class GPT2Model extends GPT2PreTrainedModel { }
3776
3720
 
@@ -3783,20 +3727,25 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { }
3783
3727
  // }
3784
3728
  //////////////////////////////////////////////////
3785
3729
 
3730
+ //////////////////////////////////////////////////
3731
+ // JAIS models
3732
+ export class JAISPreTrainedModel extends PreTrainedModel { }
3733
+
3734
+ /**
3735
+ * The bare JAIS Model transformer outputting raw hidden-states without any specific head on top.
3736
+ */
3737
+ export class JAISModel extends JAISPreTrainedModel { }
3738
+
3739
+ /**
3740
+ * The JAIS Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
3741
+ */
3742
+ export class JAISLMHeadModel extends JAISPreTrainedModel { }
3743
+ //////////////////////////////////////////////////
3744
+
3745
+
3786
3746
  //////////////////////////////////////////////////
3787
3747
  // GPTNeo models
3788
- export class GPTNeoPreTrainedModel extends PreTrainedModel {
3789
- /**
3790
- * Creates a new instance of the `GPTNeoPreTrainedModel` class.
3791
- * @param {Object} config The model configuration.
3792
- * @param {Record<string, any>} sessions The inference sessions for the model.
3793
- * @param {GenerationConfig} generation_config The generation configuration.
3794
- */
3795
- constructor(config, sessions, generation_config) {
3796
- super(config, sessions);
3797
- this.generation_config = generation_config;
3798
- }
3799
- }
3748
+ export class GPTNeoPreTrainedModel extends PreTrainedModel { }
3800
3749
  export class GPTNeoModel extends GPTNeoPreTrainedModel { }
3801
3750
 
3802
3751
  export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { }
@@ -3804,18 +3753,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { }
3804
3753
 
3805
3754
  //////////////////////////////////////////////////
3806
3755
  // GPTNeoX models
3807
- export class GPTNeoXPreTrainedModel extends PreTrainedModel {
3808
- /**
3809
- * Creates a new instance of the `GPTNeoXPreTrainedModel` class.
3810
- * @param {Object} config The model configuration.
3811
- * @param {Record<string, any>} sessions The inference sessions for the model.
3812
- * @param {GenerationConfig} generation_config The generation configuration.
3813
- */
3814
- constructor(config, sessions, generation_config) {
3815
- super(config, sessions);
3816
- this.generation_config = generation_config;
3817
- }
3818
- }
3756
+ export class GPTNeoXPreTrainedModel extends PreTrainedModel { }
3819
3757
  export class GPTNeoXModel extends GPTNeoXPreTrainedModel { }
3820
3758
 
3821
3759
  export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { }
@@ -3824,18 +3762,7 @@ export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { }
3824
3762
 
3825
3763
  //////////////////////////////////////////////////
3826
3764
  // GPT-J models
3827
- export class GPTJPreTrainedModel extends PreTrainedModel {
3828
- /**
3829
- * Creates a new instance of the `GPTJPreTrainedModel` class.
3830
- * @param {Object} config The model configuration.
3831
- * @param {Record<string, any>} sessions The inference sessions for the model.
3832
- * @param {GenerationConfig} generation_config The generation configuration.
3833
- */
3834
- constructor(config, sessions, generation_config) {
3835
- super(config, sessions);
3836
- this.generation_config = generation_config;
3837
- }
3838
- }
3765
+ export class GPTJPreTrainedModel extends PreTrainedModel { }
3839
3766
 
3840
3767
  export class GPTJModel extends GPTJPreTrainedModel { }
3841
3768
 
@@ -3845,18 +3772,7 @@ export class GPTJForCausalLM extends GPTJPreTrainedModel { }
3845
3772
 
3846
3773
  //////////////////////////////////////////////////
3847
3774
  // GPTBigCode models
3848
- export class GPTBigCodePreTrainedModel extends PreTrainedModel {
3849
- /**
3850
- * Creates a new instance of the `GPTBigCodePreTrainedModel` class.
3851
- * @param {Object} config The model configuration.
3852
- * @param {Record<string, any>} sessions The inference sessions for the model.
3853
- * @param {GenerationConfig} generation_config The generation configuration.
3854
- */
3855
- constructor(config, sessions, generation_config) {
3856
- super(config, sessions);
3857
- this.generation_config = generation_config;
3858
- }
3859
- }
3775
+ export class GPTBigCodePreTrainedModel extends PreTrainedModel { }
3860
3776
 
3861
3777
  export class GPTBigCodeModel extends GPTBigCodePreTrainedModel { }
3862
3778
 
@@ -3865,18 +3781,7 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { }
3865
3781
 
3866
3782
  //////////////////////////////////////////////////
3867
3783
  // CodeGen models
3868
- export class CodeGenPreTrainedModel extends PreTrainedModel {
3869
- /**
3870
- * Creates a new instance of the `CodeGenPreTrainedModel` class.
3871
- * @param {Object} config The model configuration.
3872
- * @param {Record<string, any>} sessions The inference sessions for the model.
3873
- * @param {GenerationConfig} generation_config The generation configuration.
3874
- */
3875
- constructor(config, sessions, generation_config) {
3876
- super(config, sessions);
3877
- this.generation_config = generation_config;
3878
- }
3879
- }
3784
+ export class CodeGenPreTrainedModel extends PreTrainedModel { }
3880
3785
  /**
3881
3786
  * CodeGenModel is a class representing a code generation model without a language model head.
3882
3787
  */
@@ -3895,18 +3800,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { }
3895
3800
  /**
3896
3801
  * The bare LLaMA Model outputting raw hidden-states without any specific head on top.
3897
3802
  */
3898
- export class LlamaPreTrainedModel extends PreTrainedModel {
3899
- /**
3900
- * Creates a new instance of the `LlamaPreTrainedModel` class.
3901
- * @param {Object} config The model configuration.
3902
- * @param {Record<string, any>} sessions The inference sessions for the model.
3903
- * @param {GenerationConfig} generation_config The generation configuration.
3904
- */
3905
- constructor(config, sessions, generation_config) {
3906
- super(config, sessions);
3907
- this.generation_config = generation_config;
3908
- }
3909
- }
3803
+ export class LlamaPreTrainedModel extends PreTrainedModel { }
3910
3804
  /**
3911
3805
  * The bare LLaMA Model outputting raw hidden-states without any specific head on top.
3912
3806
  */
@@ -3915,24 +3809,22 @@ export class LlamaModel extends LlamaPreTrainedModel { }
3915
3809
  export class LlamaForCausalLM extends LlamaPreTrainedModel { }
3916
3810
  //////////////////////////////////////////////////
3917
3811
 
3812
+
3813
+ //////////////////////////////////////////////////
3814
+ // Granite models
3815
+ export class GranitePreTrainedModel extends PreTrainedModel { }
3816
+ export class GraniteModel extends GranitePreTrainedModel { }
3817
+ export class GraniteForCausalLM extends GranitePreTrainedModel { }
3818
+ //////////////////////////////////////////////////
3819
+
3820
+
3918
3821
  //////////////////////////////////////////////////
3919
3822
  // Cohere models
3920
3823
 
3921
3824
  /**
3922
3825
  * The bare Cohere Model outputting raw hidden-states without any specific head on top.
3923
3826
  */
3924
- export class CoherePreTrainedModel extends PreTrainedModel {
3925
- /**
3926
- * Creates a new instance of the `CoherePreTrainedModel` class.
3927
- * @param {Object} config The model configuration.
3928
- * @param {Record<string, any>} sessions The inference sessions for the model.
3929
- * @param {GenerationConfig} generation_config The generation configuration.
3930
- */
3931
- constructor(config, sessions, generation_config) {
3932
- super(config, sessions);
3933
- this.generation_config = generation_config;
3934
- }
3935
- }
3827
+ export class CoherePreTrainedModel extends PreTrainedModel { }
3936
3828
  export class CohereModel extends CoherePreTrainedModel { }
3937
3829
 
3938
3830
  export class CohereForCausalLM extends CoherePreTrainedModel { }
@@ -3944,18 +3836,7 @@ export class CohereForCausalLM extends CoherePreTrainedModel { }
3944
3836
  /**
3945
3837
  * The bare Gemma Model outputting raw hidden-states without any specific head on top.
3946
3838
  */
3947
- export class GemmaPreTrainedModel extends PreTrainedModel {
3948
- /**
3949
- * Creates a new instance of the `GemmaPreTrainedModel` class.
3950
- * @param {Object} config The model configuration.
3951
- * @param {Record<string, any>} sessions The inference sessions for the model.
3952
- * @param {GenerationConfig} generation_config The generation configuration.
3953
- */
3954
- constructor(config, sessions, generation_config) {
3955
- super(config, sessions);
3956
- this.generation_config = generation_config;
3957
- }
3958
- }
3839
+ export class GemmaPreTrainedModel extends PreTrainedModel { }
3959
3840
  /**
3960
3841
  * The bare Gemma Model outputting raw hidden-states without any specific head on top.
3961
3842
  */
@@ -3970,18 +3851,7 @@ export class GemmaForCausalLM extends GemmaPreTrainedModel { }
3970
3851
  /**
3971
3852
  * The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
3972
3853
  */
3973
- export class Gemma2PreTrainedModel extends PreTrainedModel {
3974
- /**
3975
- * Creates a new instance of the `Gemma2PreTrainedModel` class.
3976
- * @param {Object} config The model configuration.
3977
- * @param {Record<string, any>} sessions The inference sessions for the model.
3978
- * @param {GenerationConfig} generation_config The generation configuration.
3979
- */
3980
- constructor(config, sessions, generation_config) {
3981
- super(config, sessions);
3982
- this.generation_config = generation_config;
3983
- }
3984
- }
3854
+ export class Gemma2PreTrainedModel extends PreTrainedModel { }
3985
3855
  /**
3986
3856
  * The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
3987
3857
  */
@@ -3991,18 +3861,7 @@ export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { }
3991
3861
  //////////////////////////////////////////////////
3992
3862
 
3993
3863
  //////////////////////////////////////////////////
3994
- export class OpenELMPreTrainedModel extends PreTrainedModel {
3995
- /**
3996
- * Creates a new instance of the `OpenELMPreTrainedModel` class.
3997
- * @param {Object} config The model configuration.
3998
- * @param {Record<string, any>} sessions The inference sessions for the model.
3999
- * @param {GenerationConfig} generation_config The generation configuration.
4000
- */
4001
- constructor(config, sessions, generation_config) {
4002
- super(config, sessions);
4003
- this.generation_config = generation_config;
4004
- }
4005
- }
3864
+ export class OpenELMPreTrainedModel extends PreTrainedModel { }
4006
3865
  export class OpenELMModel extends OpenELMPreTrainedModel { }
4007
3866
 
4008
3867
  export class OpenELMForCausalLM extends OpenELMPreTrainedModel { }
@@ -4014,18 +3873,7 @@ export class OpenELMForCausalLM extends OpenELMPreTrainedModel { }
4014
3873
  /**
4015
3874
  * The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
4016
3875
  */
4017
- export class Qwen2PreTrainedModel extends PreTrainedModel {
4018
- /**
4019
- * Creates a new instance of the `Qwen2PreTrainedModel` class.
4020
- * @param {Object} config The model configuration.
4021
- * @param {Record<string, any>} sessions The inference sessions for the model.
4022
- * @param {GenerationConfig} generation_config The generation configuration.
4023
- */
4024
- constructor(config, sessions, generation_config) {
4025
- super(config, sessions);
4026
- this.generation_config = generation_config;
4027
- }
4028
- }
3876
+ export class Qwen2PreTrainedModel extends PreTrainedModel { }
4029
3877
  /**
4030
3878
  * The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
4031
3879
  */
@@ -4037,18 +3885,7 @@ export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { }
4037
3885
 
4038
3886
  //////////////////////////////////////////////////
4039
3887
  // Phi models
4040
- export class PhiPreTrainedModel extends PreTrainedModel {
4041
- /**
4042
- * Creates a new instance of the `PhiPreTrainedModel` class.
4043
- * @param {Object} config The model configuration.
4044
- * @param {Record<string, any>} sessions The inference sessions for the model.
4045
- * @param {GenerationConfig} generation_config The generation configuration.
4046
- */
4047
- constructor(config, sessions, generation_config) {
4048
- super(config, sessions);
4049
- this.generation_config = generation_config;
4050
- }
4051
- }
3888
+ export class PhiPreTrainedModel extends PreTrainedModel { }
4052
3889
  /**
4053
3890
  * The bare Phi Model outputting raw hidden-states without any specific head on top.
4054
3891
  */
@@ -4059,18 +3896,7 @@ export class PhiForCausalLM extends PhiPreTrainedModel { }
4059
3896
 
4060
3897
  //////////////////////////////////////////////////
4061
3898
  // Phi3 models
4062
- export class Phi3PreTrainedModel extends PreTrainedModel {
4063
- /**
4064
- * Creates a new instance of the `Phi3PreTrainedModel` class.
4065
- * @param {Object} config The model configuration.
4066
- * @param {Record<string, any>} sessions The inference sessions for the model.
4067
- * @param {GenerationConfig} generation_config The generation configuration.
4068
- */
4069
- constructor(config, sessions, generation_config) {
4070
- super(config, sessions);
4071
- this.generation_config = generation_config;
4072
- }
4073
- }
3899
+ export class Phi3PreTrainedModel extends PreTrainedModel { }
4074
3900
 
4075
3901
  /**
4076
3902
  * The bare Phi3 Model outputting raw hidden-states without any specific head on top.
@@ -4086,18 +3912,7 @@ export class Phi3ForCausalLM extends Phi3PreTrainedModel { }
4086
3912
  /**
4087
3913
  * The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
4088
3914
  */
4089
- export class BloomPreTrainedModel extends PreTrainedModel {
4090
- /**
4091
- * Creates a new instance of the `BloomPreTrainedModel` class.
4092
- * @param {Object} config The model configuration.
4093
- * @param {Record<string, any>} sessions The inference sessions for the model.
4094
- * @param {GenerationConfig} generation_config The generation configuration.
4095
- */
4096
- constructor(config, sessions, generation_config) {
4097
- super(config, sessions);
4098
- this.generation_config = generation_config;
4099
- }
4100
- }
3915
+ export class BloomPreTrainedModel extends PreTrainedModel { }
4101
3916
 
4102
3917
  /**
4103
3918
  * The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.
@@ -4112,18 +3927,7 @@ export class BloomForCausalLM extends BloomPreTrainedModel { }
4112
3927
 
4113
3928
  //////////////////////////////////////////////////
4114
3929
  // MPT models
4115
- export class MptPreTrainedModel extends PreTrainedModel {
4116
- /**
4117
- * Creates a new instance of the `MptPreTrainedModel` class.
4118
- * @param {Object} config The model configuration.
4119
- * @param {Record<string, any>} sessions The inference sessions for the model.
4120
- * @param {GenerationConfig} generation_config The generation configuration.
4121
- */
4122
- constructor(config, sessions, generation_config) {
4123
- super(config, sessions);
4124
- this.generation_config = generation_config;
4125
- }
4126
- }
3930
+ export class MptPreTrainedModel extends PreTrainedModel { }
4127
3931
 
4128
3932
  /**
4129
3933
  * The bare Mpt Model transformer outputting raw hidden-states without any specific head on top.
@@ -4139,18 +3943,7 @@ export class MptForCausalLM extends MptPreTrainedModel { }
4139
3943
 
4140
3944
  //////////////////////////////////////////////////
4141
3945
  // OPT models
4142
- export class OPTPreTrainedModel extends PreTrainedModel {
4143
- /**
4144
- * Creates a new instance of the `OPTPreTrainedModel` class.
4145
- * @param {Object} config The model configuration.
4146
- * @param {Record<string, any>} sessions The inference sessions for the model.
4147
- * @param {GenerationConfig} generation_config The generation configuration.
4148
- */
4149
- constructor(config, sessions, generation_config) {
4150
- super(config, sessions);
4151
- this.generation_config = generation_config;
4152
- }
4153
- }
3946
+ export class OPTPreTrainedModel extends PreTrainedModel { }
4154
3947
 
4155
3948
  /**
4156
3949
  * The bare OPT Model outputting raw hidden-states without any specific head on top.
@@ -4176,6 +3969,43 @@ export class ViTForImageClassification extends ViTPreTrainedModel {
4176
3969
  }
4177
3970
  //////////////////////////////////////////////////
4178
3971
 
3972
+ //////////////////////////////////////////////////
3973
+ export class PvtPreTrainedModel extends PreTrainedModel { }
3974
+ export class PvtModel extends PvtPreTrainedModel { }
3975
+ export class PvtForImageClassification extends PvtPreTrainedModel {
3976
+ /**
3977
+ * @param {any} model_inputs
3978
+ */
3979
+ async _call(model_inputs) {
3980
+ return new SequenceClassifierOutput(await super._call(model_inputs));
3981
+ }
3982
+ }
3983
+ //////////////////////////////////////////////////
3984
+
3985
+ //////////////////////////////////////////////////
3986
+ export class ViTMAEPreTrainedModel extends PreTrainedModel { }
3987
+ export class ViTMAEModel extends ViTMAEPreTrainedModel { }
3988
+ //////////////////////////////////////////////////
3989
+
3990
+
3991
+ //////////////////////////////////////////////////
3992
+ export class ViTMSNPreTrainedModel extends PreTrainedModel { }
3993
+ export class ViTMSNModel extends ViTMSNPreTrainedModel { }
3994
+ export class ViTMSNForImageClassification extends ViTMSNPreTrainedModel {
3995
+ /**
3996
+ * @param {any} model_inputs
3997
+ */
3998
+ async _call(model_inputs) {
3999
+ return new SequenceClassifierOutput(await super._call(model_inputs));
4000
+ }
4001
+ }
4002
+ //////////////////////////////////////////////////
4003
+
4004
+ //////////////////////////////////////////////////
4005
+ export class GroupViTPreTrainedModel extends PreTrainedModel { }
4006
+ export class GroupViTModel extends GroupViTPreTrainedModel { }
4007
+ //////////////////////////////////////////////////
4008
+
4179
4009
 
4180
4010
  //////////////////////////////////////////////////
4181
4011
  export class FastViTPreTrainedModel extends PreTrainedModel { }
@@ -4429,6 +4259,19 @@ export class DeiTForImageClassification extends DeiTPreTrainedModel {
4429
4259
  }
4430
4260
  //////////////////////////////////////////////////
4431
4261
 
4262
+ //////////////////////////////////////////////////
4263
+ export class HieraPreTrainedModel extends PreTrainedModel { }
4264
+ export class HieraModel extends HieraPreTrainedModel { }
4265
+ export class HieraForImageClassification extends HieraPreTrainedModel {
4266
+ /**
4267
+ * @param {any} model_inputs
4268
+ */
4269
+ async _call(model_inputs) {
4270
+ return new SequenceClassifierOutput(await super._call(model_inputs));
4271
+ }
4272
+ }
4273
+ //////////////////////////////////////////////////
4274
+
4432
4275
 
4433
4276
  //////////////////////////////////////////////////
4434
4277
  /**
@@ -4568,6 +4411,24 @@ export class DepthAnythingForDepthEstimation extends DepthAnythingPreTrainedMode
4568
4411
  //////////////////////////////////////////////////
4569
4412
 
4570
4413
 
4414
+ //////////////////////////////////////////////////
4415
+ export class SapiensPreTrainedModel extends PreTrainedModel { }
4416
+ export class SapiensForSemanticSegmentation extends SapiensPreTrainedModel { }
4417
+ export class SapiensForDepthEstimation extends SapiensPreTrainedModel { }
4418
+ export class SapiensForNormalEstimation extends SapiensPreTrainedModel { }
4419
+ //////////////////////////////////////////////////
4420
+
4421
+ //////////////////////////////////////////////////
4422
+ export class DepthProPreTrainedModel extends PreTrainedModel { }
4423
+ export class DepthProForDepthEstimation extends DepthProPreTrainedModel { }
4424
+ //////////////////////////////////////////////////
4425
+
4426
+ //////////////////////////////////////////////////
4427
+ export class MaskFormerPreTrainedModel extends PreTrainedModel { }
4428
+ export class MaskFormerModel extends MaskFormerPreTrainedModel { }
4429
+ export class MaskFormerForInstanceSegmentation extends MaskFormerPreTrainedModel { }
4430
+ //////////////////////////////////////////////////
4431
+
4571
4432
  //////////////////////////////////////////////////
4572
4433
  export class GLPNPreTrainedModel extends PreTrainedModel { }
4573
4434
 
@@ -4944,19 +4805,7 @@ export class SamImageSegmentationOutput extends ModelOutput {
4944
4805
 
4945
4806
  //////////////////////////////////////////////////
4946
4807
  // MarianMT models
4947
- export class MarianPreTrainedModel extends PreTrainedModel {
4948
-
4949
- /**
4950
- * Creates a new instance of the `MarianMTModel` class.
4951
- * @param {Object} config The model configuration.
4952
- * @param {Record<string, any>} sessions The inference sessions for the model.
4953
- * @param {GenerationConfig} generation_config The generation configuration.
4954
- */
4955
- constructor(config, sessions, generation_config) {
4956
- super(config, sessions);
4957
- this.generation_config = generation_config;
4958
- }
4959
- };
4808
+ export class MarianPreTrainedModel extends PreTrainedModel { };
4960
4809
 
4961
4810
  export class MarianModel extends MarianPreTrainedModel { }
4962
4811
 
@@ -4965,19 +4814,7 @@ export class MarianMTModel extends MarianPreTrainedModel { }
4965
4814
 
4966
4815
  //////////////////////////////////////////////////
4967
4816
  // M2M100 models
4968
- export class M2M100PreTrainedModel extends PreTrainedModel {
4969
-
4970
- /**
4971
- * Creates a new instance of the `M2M100ForConditionalGeneration` class.
4972
- * @param {Object} config The model configuration.
4973
- * @param {Record<string, any>} sessions The inference sessions for the model.
4974
- * @param {GenerationConfig} generation_config The generation configuration.
4975
- */
4976
- constructor(config, sessions, generation_config) {
4977
- super(config, sessions);
4978
- this.generation_config = generation_config;
4979
- }
4980
- };
4817
+ export class M2M100PreTrainedModel extends PreTrainedModel { };
4981
4818
 
4982
4819
  export class M2M100Model extends M2M100PreTrainedModel { }
4983
4820
 
@@ -5069,7 +4906,7 @@ export class PyAnnoteModel extends PyAnnotePreTrainedModel { }
5069
4906
  * **Example:** Load and run a `PyAnnoteForAudioFrameClassification` for speaker diarization.
5070
4907
  *
5071
4908
  * ```javascript
5072
- * import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';
4909
+ * import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@huggingface/transformers';
5073
4910
  *
5074
4911
  * // Load model and processor
5075
4912
  * const model_id = 'onnx-community/pyannote-segmentation-3.0';
@@ -5487,19 +5324,7 @@ export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
5487
5324
  /**
5488
5325
  * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
5489
5326
  */
5490
- export class SpeechT5PreTrainedModel extends PreTrainedModel {
5491
-
5492
- /**
5493
- * Creates a new instance of the `SpeechT5ForTextToSpeech` class.
5494
- * @param {Object} config The model configuration.
5495
- * @param {Record<string, any>} sessions The inference sessions for the model.
5496
- * @param {GenerationConfig} generation_config The generation configuration.
5497
- */
5498
- constructor(config, sessions, generation_config) {
5499
- super(config, sessions);
5500
- this.generation_config = generation_config;
5501
- }
5502
- };
5327
+ export class SpeechT5PreTrainedModel extends PreTrainedModel { };
5503
5328
 
5504
5329
  /**
5505
5330
  * The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets.
@@ -5660,18 +5485,7 @@ export class SpeechT5HifiGan extends PreTrainedModel {
5660
5485
 
5661
5486
  //////////////////////////////////////////////////
5662
5487
  // TrOCR models
5663
- export class TrOCRPreTrainedModel extends PreTrainedModel {
5664
- /**
5665
- * Creates a new instance of the `TrOCRPreTrainedModel` class.
5666
- * @param {Object} config The configuration of the model.
5667
- * @param {any} session The ONNX session containing the model weights.
5668
- * @param {GenerationConfig} generation_config The generation configuration.
5669
- */
5670
- constructor(config, session, generation_config) {
5671
- super(config, session);
5672
- this.generation_config = generation_config;
5673
- }
5674
- }
5488
+ export class TrOCRPreTrainedModel extends PreTrainedModel { }
5675
5489
 
5676
5490
  /**
5677
5491
  * The TrOCR Decoder with a language modeling head.
@@ -5686,18 +5500,7 @@ export class TrOCRForCausalLM extends TrOCRPreTrainedModel { }
5686
5500
  /**
5687
5501
  * The bare Mistral Model outputting raw hidden-states without any specific head on top.
5688
5502
  */
5689
- export class MistralPreTrainedModel extends PreTrainedModel {
5690
- /**
5691
- * Creates a new instance of the `MistralPreTrainedModel` class.
5692
- * @param {Object} config The configuration of the model.
5693
- * @param {any} session The ONNX session containing the model weights.
5694
- * @param {GenerationConfig} generation_config The generation configuration.
5695
- */
5696
- constructor(config, session, generation_config) {
5697
- super(config, session);
5698
- this.generation_config = generation_config;
5699
- }
5700
- }
5503
+ export class MistralPreTrainedModel extends PreTrainedModel { }
5701
5504
 
5702
5505
  export class MistralModel extends MistralPreTrainedModel { }
5703
5506
 
@@ -5710,18 +5513,7 @@ export class MistralForCausalLM extends MistralPreTrainedModel { }
5710
5513
  /**
5711
5514
  * The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.
5712
5515
  */
5713
- export class Starcoder2PreTrainedModel extends PreTrainedModel {
5714
- /**
5715
- * Creates a new instance of the `Starcoder2PreTrainedModel` class.
5716
- * @param {Object} config The configuration of the model.
5717
- * @param {any} session The ONNX session containing the model weights.
5718
- * @param {GenerationConfig} generation_config The generation configuration.
5719
- */
5720
- constructor(config, session, generation_config) {
5721
- super(config, session);
5722
- this.generation_config = generation_config;
5723
- }
5724
- }
5516
+ export class Starcoder2PreTrainedModel extends PreTrainedModel { }
5725
5517
 
5726
5518
  export class Starcoder2Model extends Starcoder2PreTrainedModel { }
5727
5519
 
@@ -5734,18 +5526,7 @@ export class Starcoder2ForCausalLM extends Starcoder2PreTrainedModel { }
5734
5526
  /**
5735
5527
  * The bare Falcon Model outputting raw hidden-states without any specific head on top.
5736
5528
  */
5737
- export class FalconPreTrainedModel extends PreTrainedModel {
5738
- /**
5739
- * Creates a new instance of the `FalconPreTrainedModel` class.
5740
- * @param {Object} config The configuration of the model.
5741
- * @param {any} session The ONNX session containing the model weights.
5742
- * @param {GenerationConfig} generation_config The generation configuration.
5743
- */
5744
- constructor(config, session, generation_config) {
5745
- super(config, session);
5746
- this.generation_config = generation_config;
5747
- }
5748
- }
5529
+ export class FalconPreTrainedModel extends PreTrainedModel { }
5749
5530
 
5750
5531
  export class FalconModel extends FalconPreTrainedModel { }
5751
5532
 
@@ -5895,18 +5676,7 @@ export class SegformerForSemanticSegmentation extends SegformerPreTrainedModel {
5895
5676
 
5896
5677
  //////////////////////////////////////////////////
5897
5678
  // StableLm models
5898
- export class StableLmPreTrainedModel extends PreTrainedModel {
5899
- /**
5900
- * Creates a new instance of the `StableLmPreTrainedModel` class.
5901
- * @param {Object} config The configuration of the model.
5902
- * @param {any} session The ONNX session containing the model weights.
5903
- * @param {GenerationConfig} generation_config The generation configuration.
5904
- */
5905
- constructor(config, session, generation_config) {
5906
- super(config, session);
5907
- this.generation_config = generation_config;
5908
- }
5909
- }
5679
+ export class StableLmPreTrainedModel extends PreTrainedModel { }
5910
5680
 
5911
5681
  /**
5912
5682
  * The bare StableLm Model transformer outputting raw hidden-states without any specific head on top.
@@ -6000,17 +5770,6 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE:
6000
5770
  'past_key_values',
6001
5771
  ];
6002
5772
 
6003
- /**
6004
- * Creates a new instance of the `MusicgenForConditionalGeneration` class.
6005
- * @param {Object} config The model configuration.
6006
- * @param {Record<string, any>} sessions The inference sessions for the model.
6007
- * @param {GenerationConfig} generation_config The generation configuration.
6008
- */
6009
- constructor(config, sessions, generation_config) {
6010
- super(config, sessions);
6011
- this.generation_config = generation_config;
6012
- }
6013
-
6014
5773
  /**
6015
5774
  * Apply the pattern mask to the final ids,
6016
5775
  * then revert the pattern delay mask by filtering the pad token id in a single step.
@@ -6089,6 +5848,7 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE:
6089
5848
  return audio_values;
6090
5849
  }
6091
5850
  }
5851
+ //////////////////////////////////////////////////
6092
5852
 
6093
5853
  //////////////////////////////////////////////////
6094
5854
  // MobileNetV1 models
@@ -6182,6 +5942,17 @@ export class MobileNetV4ForImageClassification extends MobileNetV4PreTrainedMode
6182
5942
  }
6183
5943
  //////////////////////////////////////////////////
6184
5944
 
5945
+ //////////////////////////////////////////////////
5946
+ // Decision Transformer models
5947
+ export class DecisionTransformerPreTrainedModel extends PreTrainedModel { }
5948
+
5949
+ /**
5950
+ * The model builds upon the GPT2 architecture to perform autoregressive prediction of actions in an offline RL setting.
5951
+ * Refer to the paper for more details: https://arxiv.org/abs/2106.01345
5952
+ */
5953
+ export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel { }
5954
+
5955
+ //////////////////////////////////////////////////
6185
5956
 
6186
5957
  //////////////////////////////////////////////////
6187
5958
  // AutoModels, used to simplify construction of PreTrainedModels
@@ -6220,7 +5991,7 @@ export class PretrainedMixin {
6220
5991
  session_options = {},
6221
5992
  } = {}) {
6222
5993
 
6223
- let options = {
5994
+ const options = {
6224
5995
  progress_callback,
6225
5996
  config,
6226
5997
  cache_dir,
@@ -6239,7 +6010,7 @@ export class PretrainedMixin {
6239
6010
  throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name);
6240
6011
  }
6241
6012
 
6242
- for (let MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
6013
+ for (const MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) {
6243
6014
  const modelInfo = MODEL_CLASS_MAPPING.get(options.config.model_type);
6244
6015
  if (!modelInfo) {
6245
6016
  continue; // Item not found in this mapping
@@ -6294,6 +6065,10 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
6294
6065
  ['rt_detr', ['RTDetrModel', RTDetrModel]],
6295
6066
  ['table-transformer', ['TableTransformerModel', TableTransformerModel]],
6296
6067
  ['vit', ['ViTModel', ViTModel]],
6068
+ ['pvt', ['PvtModel', PvtModel]],
6069
+ ['vit_msn', ['ViTMSNModel', ViTMSNModel]],
6070
+ ['vit_mae', ['ViTMAEModel', ViTMAEModel]],
6071
+ ['groupvit', ['GroupViTModel', GroupViTModel]],
6297
6072
  ['fastvit', ['FastViTModel', FastViTModel]],
6298
6073
  ['mobilevit', ['MobileViTModel', MobileViTModel]],
6299
6074
  ['mobilevitv2', ['MobileViTV2Model', MobileViTV2Model]],
@@ -6301,6 +6076,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
6301
6076
  ['owlv2', ['Owlv2Model', Owlv2Model]],
6302
6077
  ['beit', ['BeitModel', BeitModel]],
6303
6078
  ['deit', ['DeiTModel', DeiTModel]],
6079
+ ['hiera', ['HieraModel', HieraModel]],
6304
6080
  ['convnext', ['ConvNextModel', ConvNextModel]],
6305
6081
  ['convnextv2', ['ConvNextV2Model', ConvNextV2Model]],
6306
6082
  ['dinov2', ['Dinov2Model', Dinov2Model]],
@@ -6315,10 +6091,14 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
6315
6091
  ['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]],
6316
6092
  ['efficientnet', ['EfficientNetModel', EfficientNetModel]],
6317
6093
 
6094
+ ['decision_transformer', ['DecisionTransformerModel', DecisionTransformerModel]],
6095
+
6318
6096
  ['mobilenet_v1', ['MobileNetV1Model', MobileNetV1Model]],
6319
6097
  ['mobilenet_v2', ['MobileNetV2Model', MobileNetV2Model]],
6320
6098
  ['mobilenet_v3', ['MobileNetV3Model', MobileNetV3Model]],
6321
6099
  ['mobilenet_v4', ['MobileNetV4Model', MobileNetV4Model]],
6100
+
6101
+ ['maskformer', ['MaskFormerModel', MaskFormerModel]],
6322
6102
  ]);
6323
6103
 
6324
6104
  const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
@@ -6337,6 +6117,7 @@ const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
6337
6117
 
6338
6118
  const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
6339
6119
  ['bloom', ['BloomModel', BloomModel]],
6120
+ ['jais', ['JAISModel', JAISModel]],
6340
6121
  ['gpt2', ['GPT2Model', GPT2Model]],
6341
6122
  ['gptj', ['GPTJModel', GPTJModel]],
6342
6123
  ['gpt_bigcode', ['GPTBigCodeModel', GPTBigCodeModel]],
@@ -6344,6 +6125,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
6344
6125
  ['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]],
6345
6126
  ['codegen', ['CodeGenModel', CodeGenModel]],
6346
6127
  ['llama', ['LlamaModel', LlamaModel]],
6128
+ ['granite', ['GraniteModel', GraniteModel]],
6347
6129
  ['cohere', ['CohereModel', CohereModel]],
6348
6130
  ['gemma', ['GemmaModel', GemmaModel]],
6349
6131
  ['gemma2', ['Gemma2Model', Gemma2Model]],
@@ -6425,12 +6207,14 @@ const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([
6425
6207
  const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
6426
6208
  ['bloom', ['BloomForCausalLM', BloomForCausalLM]],
6427
6209
  ['gpt2', ['GPT2LMHeadModel', GPT2LMHeadModel]],
6210
+ ['jais', ['JAISLMHeadModel', JAISLMHeadModel]],
6428
6211
  ['gptj', ['GPTJForCausalLM', GPTJForCausalLM]],
6429
6212
  ['gpt_bigcode', ['GPTBigCodeForCausalLM', GPTBigCodeForCausalLM]],
6430
6213
  ['gpt_neo', ['GPTNeoForCausalLM', GPTNeoForCausalLM]],
6431
6214
  ['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]],
6432
6215
  ['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]],
6433
6216
  ['llama', ['LlamaForCausalLM', LlamaForCausalLM]],
6217
+ ['granite', ['GraniteForCausalLM', GraniteForCausalLM]],
6434
6218
  ['cohere', ['CohereForCausalLM', CohereForCausalLM]],
6435
6219
  ['gemma', ['GemmaForCausalLM', GemmaForCausalLM]],
6436
6220
  ['gemma2', ['Gemma2ForCausalLM', Gemma2ForCausalLM]],
@@ -6501,11 +6285,14 @@ const MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
6501
6285
 
6502
6286
  const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([
6503
6287
  ['vit', ['ViTForImageClassification', ViTForImageClassification]],
6288
+ ['pvt', ['PvtForImageClassification', PvtForImageClassification]],
6289
+ ['vit_msn', ['ViTMSNForImageClassification', ViTMSNForImageClassification]],
6504
6290
  ['fastvit', ['FastViTForImageClassification', FastViTForImageClassification]],
6505
6291
  ['mobilevit', ['MobileViTForImageClassification', MobileViTForImageClassification]],
6506
6292
  ['mobilevitv2', ['MobileViTV2ForImageClassification', MobileViTV2ForImageClassification]],
6507
6293
  ['beit', ['BeitForImageClassification', BeitForImageClassification]],
6508
6294
  ['deit', ['DeiTForImageClassification', DeiTForImageClassification]],
6295
+ ['hiera', ['HieraForImageClassification', HieraForImageClassification]],
6509
6296
  ['convnext', ['ConvNextForImageClassification', ConvNextForImageClassification]],
6510
6297
  ['convnextv2', ['ConvNextV2ForImageClassification', ConvNextV2ForImageClassification]],
6511
6298
  ['dinov2', ['Dinov2ForImageClassification', Dinov2ForImageClassification]],
@@ -6532,12 +6319,19 @@ const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([
6532
6319
  ]);
6533
6320
 
6534
6321
  const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
6322
+ // TODO: Do not add new models here
6535
6323
  ['detr', ['DetrForSegmentation', DetrForSegmentation]],
6536
6324
  ['clipseg', ['CLIPSegForImageSegmentation', CLIPSegForImageSegmentation]],
6537
6325
  ]);
6538
6326
 
6539
6327
  const MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = new Map([
6540
6328
  ['segformer', ['SegformerForSemanticSegmentation', SegformerForSemanticSegmentation]],
6329
+ ['sapiens', ['SapiensForSemanticSegmentation', SapiensForSemanticSegmentation]],
6330
+ ]);
6331
+
6332
+ const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
6333
+ ['detr', ['DetrForSegmentation', DetrForSegmentation]],
6334
+ ['maskformer', ['MaskFormerForInstanceSegmentation', MaskFormerForInstanceSegmentation]],
6541
6335
  ]);
6542
6336
 
6543
6337
  const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([
@@ -6586,6 +6380,12 @@ const MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES = new Map([
6586
6380
  ['dpt', ['DPTForDepthEstimation', DPTForDepthEstimation]],
6587
6381
  ['depth_anything', ['DepthAnythingForDepthEstimation', DepthAnythingForDepthEstimation]],
6588
6382
  ['glpn', ['GLPNForDepthEstimation', GLPNForDepthEstimation]],
6383
+ ['sapiens', ['SapiensForDepthEstimation', SapiensForDepthEstimation]],
6384
+ ['depth_pro', ['DepthProForDepthEstimation', DepthProForDepthEstimation]],
6385
+ ])
6386
+
6387
+ const MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES = new Map([
6388
+ ['sapiens', ['SapiensForNormalEstimation', SapiensForNormalEstimation]],
6589
6389
  ])
6590
6390
 
6591
6391
  // NOTE: This is custom to Transformers.js, and is necessary because certain models
@@ -6610,10 +6410,12 @@ const MODEL_CLASS_TYPE_MAPPING = [
6610
6410
  [MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES, MODEL_TYPES.ImageTextToText],
6611
6411
  [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6612
6412
  [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6413
+ [MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6613
6414
  [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6614
6415
  [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6615
6416
  [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6616
6417
  [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6418
+ [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6617
6419
  [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6618
6420
  [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
6619
6421
  [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
@@ -6811,6 +6613,17 @@ export class AutoModelForSemanticSegmentation extends PretrainedMixin {
6811
6613
  static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES];
6812
6614
  }
6813
6615
 
6616
+ /**
6617
+ * Helper class which is used to instantiate pretrained universal image segmentation models with the `from_pretrained` function.
6618
+ * The chosen model class is determined by the type specified in the model config.
6619
+ *
6620
+ * @example
6621
+ * let model = await AutoModelForUniversalSegmentation.from_pretrained('hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation');
6622
+ */
6623
+ export class AutoModelForUniversalSegmentation extends PretrainedMixin {
6624
+ static MODEL_CLASS_MAPPINGS = [MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES];
6625
+ }
6626
+
6814
6627
  /**
6815
6628
  * Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
6816
6629
  * The chosen model class is determined by the type specified in the model config.
@@ -6870,6 +6683,10 @@ export class AutoModelForDepthEstimation extends PretrainedMixin {
6870
6683
  static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES];
6871
6684
  }
6872
6685
 
6686
+ export class AutoModelForNormalEstimation extends PretrainedMixin {
6687
+ static MODEL_CLASS_MAPPINGS = [MODEL_FOR_NORMAL_ESTIMATION_MAPPING_NAMES];
6688
+ }
6689
+
6873
6690
  export class AutoModelForImageFeatureExtraction extends PretrainedMixin {
6874
6691
  static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_FEATURE_EXTRACTION_MAPPING_NAMES];
6875
6692
  }