@huggingface/transformers 3.0.0-alpha.20 → 3.0.0-alpha.22

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
package/src/models.js CHANGED
@@ -142,7 +142,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
142
142
  * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
143
143
  * @param {string} fileName The name of the model file.
144
144
  * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
145
- * @returns {Promise<{buffer: Uint8Array, session_options: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
145
+ * @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
146
146
  * @private
147
147
  */
148
148
  async function getSession(pretrained_model_name_or_path, fileName, options) {
@@ -183,6 +183,22 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
183
183
  throw new Error(`The device (${selectedDevice}) does not support fp16.`);
184
184
  }
185
185
 
186
+ // Only valid for models with a decoder
187
+ const kv_cache_dtype = custom_config.kv_cache_dtype
188
+ ? (typeof custom_config.kv_cache_dtype === 'string'
189
+ ? custom_config.kv_cache_dtype
190
+ : custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
191
+ : undefined;
192
+
193
+ if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
194
+ throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
195
+ }
196
+
197
+ const session_config = {
198
+ dtype: selectedDtype,
199
+ kv_cache_dtype,
200
+ }
201
+
186
202
  // Construct the model file name
187
203
  const suffix = DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
188
204
  const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
@@ -258,7 +274,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
258
274
  }
259
275
 
260
276
  const buffer = await bufferPromise;
261
- return { buffer, session_options };
277
+
278
+ return { buffer, session_options, session_config };
262
279
  }
263
280
 
264
281
  /**
@@ -273,13 +290,30 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
273
290
  async function constructSessions(pretrained_model_name_or_path, names, options) {
274
291
  return Object.fromEntries(await Promise.all(
275
292
  Object.keys(names).map(async (name) => {
276
- const { buffer, session_options } = await getSession(pretrained_model_name_or_path, names[name], options);
277
- const session = await createInferenceSession(buffer, session_options);
293
+ const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
294
+ const session = await createInferenceSession(buffer, session_options, session_config);
278
295
  return [name, session];
279
296
  })
280
297
  ));
281
298
  }
282
299
 
300
+ /**
301
+ * Helper function to load multiple optional configuration files
302
+ * @param {string} pretrained_model_name_or_path The path to the directory containing the config file.
303
+ * @param {Record<string, string>} names The names of the config files to load.
304
+ * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the configs.
305
+ * @returns {Promise<Record<string, any>>} A Promise that resolves to a dictionary of configuration objects.
306
+ * @private
307
+ */
308
+ async function getOptionalConfigs(pretrained_model_name_or_path, names, options) {
309
+ return Object.fromEntries(await Promise.all(
310
+ Object.keys(names).map(async (name) => {
311
+ const config = await getModelJSON(pretrained_model_name_or_path, names[name], false, options);
312
+ return [name, config];
313
+ })
314
+ ));
315
+ }
316
+
283
317
  /**
284
318
  * Validate model inputs
285
319
  * @param {Object} session The InferenceSession object that will be run.
@@ -691,12 +725,14 @@ export class PreTrainedModel extends Callable {
691
725
  * Creates a new instance of the `PreTrainedModel` class.
692
726
  * @param {import('./configs.js').PretrainedConfig} config The model configuration.
693
727
  * @param {Record<string, any>} sessions The inference sessions for the model.
728
+ * @param {Record<string, Object>} configs Additional configuration files (e.g., generation_config.json).
694
729
  */
695
- constructor(config, sessions) {
730
+ constructor(config, sessions, configs) {
696
731
  super();
697
732
 
698
733
  this.config = config;
699
734
  this.sessions = sessions;
735
+ this.configs = configs;
700
736
 
701
737
  const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this.constructor);
702
738
  const modelType = MODEL_TYPE_MAPPING.get(modelName);
@@ -812,7 +848,9 @@ export class PreTrainedModel extends Callable {
812
848
  constructSessions(pretrained_model_name_or_path, {
813
849
  model: options.model_file_name ?? 'model',
814
850
  }, options),
815
- getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
851
+ getOptionalConfigs(pretrained_model_name_or_path, {
852
+ generation_config: 'generation_config.json',
853
+ }, options),
816
854
  ]);
817
855
 
818
856
  } else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
@@ -821,7 +859,9 @@ export class PreTrainedModel extends Callable {
821
859
  model: 'encoder_model',
822
860
  decoder_model_merged: 'decoder_model_merged',
823
861
  }, options),
824
- getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
862
+ getOptionalConfigs(pretrained_model_name_or_path, {
863
+ generation_config: 'generation_config.json',
864
+ }, options),
825
865
  ]);
826
866
 
827
867
  } else if (modelType === MODEL_TYPES.MaskGeneration) {
@@ -851,7 +891,9 @@ export class PreTrainedModel extends Callable {
851
891
  }
852
892
  info = await Promise.all([
853
893
  constructSessions(pretrained_model_name_or_path, sessions, options),
854
- getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
894
+ getOptionalConfigs(pretrained_model_name_or_path, {
895
+ generation_config: 'generation_config.json',
896
+ }, options),
855
897
  ]);
856
898
 
857
899
  } else if (modelType === MODEL_TYPES.Musicgen) {
@@ -861,7 +903,9 @@ export class PreTrainedModel extends Callable {
861
903
  decoder_model_merged: 'decoder_model_merged',
862
904
  encodec_decode: 'encodec_decode',
863
905
  }, options),
864
- getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
906
+ getOptionalConfigs(pretrained_model_name_or_path, {
907
+ generation_config: 'generation_config.json',
908
+ }, options),
865
909
  ]);
866
910
 
867
911
  } else { // should be MODEL_TYPES.EncoderOnly
@@ -899,6 +943,14 @@ export class PreTrainedModel extends Callable {
899
943
  return await this._forward(this, model_inputs);
900
944
  }
901
945
 
946
+ /**
947
+ * Get the model's generation config, if it exists.
948
+ * @returns {GenerationConfig|null} The model's generation config if it exists, otherwise `null`.
949
+ */
950
+ get generation_config() {
951
+ return this.configs?.generation_config ?? null;
952
+ }
953
+
902
954
  /**
903
955
  * This function returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`]
904
956
  * instances used for multinomial sampling.
@@ -1078,9 +1130,7 @@ export class PreTrainedModel extends Callable {
1078
1130
  const gen_config = new cls(config);
1079
1131
 
1080
1132
  // Apply model's generation config, if it exists
1081
- if ('generation_config' in this) {
1082
- Object.assign(gen_config, this.generation_config);
1083
- }
1133
+ Object.assign(gen_config, this.generation_config ?? {});
1084
1134
 
1085
1135
  // Next, use any generation config specified by the user
1086
1136
  // when calling `generate`
@@ -1598,9 +1648,8 @@ export class PreTrainedModel extends Callable {
1598
1648
  if (pastKeyValues) {
1599
1649
  Object.assign(decoderFeeds, pastKeyValues)
1600
1650
  } else {
1601
-
1602
- /** @type {import('./transformers.js').DataType} */
1603
- const dtype = this.custom_config.kv_cache_dtype ?? 'float32';
1651
+ const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
1652
+ const dtype = session?.config?.kv_cache_dtype ?? 'float32';
1604
1653
  const empty = (dtype === 'float16') ? new Uint16Array() : [];
1605
1654
 
1606
1655
  const shapes = getKeyValueShapes(this.config);
@@ -2493,17 +2542,6 @@ export class T5PreTrainedModel extends PreTrainedModel {
2493
2542
  'decoder_attention_mask',
2494
2543
  'past_key_values',
2495
2544
  ];
2496
-
2497
- /**
2498
- * Creates a new instance of the `T5PreTrainedModel` class.
2499
- * @param {Object} config The model configuration.
2500
- * @param {Record<string, any>} sessions The inference sessions for the model.
2501
- * @param {GenerationConfig} generation_config The generation configuration.
2502
- */
2503
- constructor(config, sessions, generation_config) {
2504
- super(config, sessions);
2505
- this.generation_config = generation_config;
2506
- }
2507
2545
  };
2508
2546
 
2509
2547
  export class T5Model extends T5PreTrainedModel { }
@@ -2520,18 +2558,7 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { }
2520
2558
  /**
2521
2559
  * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
2522
2560
  */
2523
- export class LongT5PreTrainedModel extends PreTrainedModel {
2524
- /**
2525
- * Creates a new instance of the `LongT5ForConditionalGeneration` class.
2526
- * @param {Object} config The model configuration.
2527
- * @param {Record<string, any>} sessions The inference sessions for the model.
2528
- * @param {GenerationConfig} generation_config The generation configuration.
2529
- */
2530
- constructor(config, sessions, generation_config) {
2531
- super(config, sessions);
2532
- this.generation_config = generation_config;
2533
- }
2534
- };
2561
+ export class LongT5PreTrainedModel extends PreTrainedModel { };
2535
2562
 
2536
2563
  /**
2537
2564
  * The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.
@@ -2547,19 +2574,7 @@ export class LongT5ForConditionalGeneration extends LongT5PreTrainedModel { }
2547
2574
 
2548
2575
  //////////////////////////////////////////////////
2549
2576
  // MT5 models
2550
- export class MT5PreTrainedModel extends PreTrainedModel {
2551
-
2552
- /**
2553
- * Creates a new instance of the `MT5ForConditionalGeneration` class.
2554
- * @param {Object} config The model configuration.
2555
- * @param {Record<string, any>} sessions The inference sessions for the model.
2556
- * @param {GenerationConfig} generation_config The generation configuration.
2557
- */
2558
- constructor(config, sessions, generation_config) {
2559
- super(config, sessions);
2560
- this.generation_config = generation_config;
2561
- }
2562
- };
2577
+ export class MT5PreTrainedModel extends PreTrainedModel { };
2563
2578
 
2564
2579
  export class MT5Model extends MT5PreTrainedModel { }
2565
2580
 
@@ -2571,19 +2586,7 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { }
2571
2586
 
2572
2587
  //////////////////////////////////////////////////
2573
2588
  // Bart models
2574
- export class BartPretrainedModel extends PreTrainedModel {
2575
-
2576
- /**
2577
- * Creates a new instance of the `BartForConditionalGeneration` class.
2578
- * @param {Object} config The model configuration.
2579
- * @param {Record<string, any>} sessions The inference sessions for the model.
2580
- * @param {GenerationConfig} generation_config The generation configuration.
2581
- */
2582
- constructor(config, sessions, generation_config) {
2583
- super(config, sessions);
2584
- this.generation_config = generation_config;
2585
- }
2586
- };
2589
+ export class BartPretrainedModel extends PreTrainedModel { };
2587
2590
 
2588
2591
  /**
2589
2592
  * The bare BART Model outputting raw hidden-states without any specific head on top.
@@ -2614,19 +2617,7 @@ export class BartForSequenceClassification extends BartPretrainedModel {
2614
2617
 
2615
2618
  //////////////////////////////////////////////////
2616
2619
  // MBart models
2617
- export class MBartPreTrainedModel extends PreTrainedModel {
2618
-
2619
- /**
2620
- * Creates a new instance of the `MBartForConditionalGeneration` class.
2621
- * @param {Object} config The model configuration.
2622
- * @param {Record<string, any>} sessions The inference sessions for the model.
2623
- * @param {GenerationConfig} generation_config The generation configuration.
2624
- */
2625
- constructor(config, sessions, generation_config) {
2626
- super(config, sessions);
2627
- this.generation_config = generation_config;
2628
- }
2629
- };
2620
+ export class MBartPreTrainedModel extends PreTrainedModel { };
2630
2621
 
2631
2622
  /**
2632
2623
  * The bare MBART Model outputting raw hidden-states without any specific head on top.
@@ -2660,19 +2651,7 @@ export class MBartForCausalLM extends MBartPreTrainedModel { }
2660
2651
 
2661
2652
  //////////////////////////////////////////////////
2662
2653
  // Blenderbot models
2663
- export class BlenderbotPreTrainedModel extends PreTrainedModel {
2664
-
2665
- /**
2666
- * Creates a new instance of the `BlenderbotForConditionalGeneration` class.
2667
- * @param {Object} config The model configuration.
2668
- * @param {Record<string, any>} sessions The inference sessions for the model.
2669
- * @param {GenerationConfig} generation_config The generation configuration.
2670
- */
2671
- constructor(config, sessions, generation_config) {
2672
- super(config, sessions);
2673
- this.generation_config = generation_config;
2674
- }
2675
- };
2654
+ export class BlenderbotPreTrainedModel extends PreTrainedModel { };
2676
2655
 
2677
2656
  /**
2678
2657
  * The bare Blenderbot Model outputting raw hidden-states without any specific head on top.
@@ -2688,19 +2667,7 @@ export class BlenderbotForConditionalGeneration extends BlenderbotPreTrainedMode
2688
2667
 
2689
2668
  //////////////////////////////////////////////////
2690
2669
  // Blenderbot models
2691
- export class BlenderbotSmallPreTrainedModel extends PreTrainedModel {
2692
-
2693
- /**
2694
- * Creates a new instance of the `BlenderbotForConditionalGeneration` class.
2695
- * @param {Object} config The model configuration.
2696
- * @param {Record<string, any>} sessions The inference sessions for the model.
2697
- * @param {GenerationConfig} generation_config The generation configuration.
2698
- */
2699
- constructor(config, sessions, generation_config) {
2700
- super(config, sessions);
2701
- this.generation_config = generation_config;
2702
- }
2703
- };
2670
+ export class BlenderbotSmallPreTrainedModel extends PreTrainedModel { };
2704
2671
 
2705
2672
  /**
2706
2673
  * The bare BlenderbotSmall Model outputting raw hidden-states without any specific head on top.
@@ -2949,17 +2916,6 @@ export class WhisperPreTrainedModel extends PreTrainedModel {
2949
2916
  'decoder_attention_mask',
2950
2917
  'past_key_values',
2951
2918
  ];
2952
-
2953
- /**
2954
- * Creates a new instance of the `WhisperPreTrainedModel` class.
2955
- * @param {Object} config The model configuration.
2956
- * @param {Record<string, any>} sessions The inference sessions for the model.
2957
- * @param {GenerationConfig} generation_config The generation configuration.
2958
- */
2959
- constructor(config, sessions, generation_config) {
2960
- super(config, sessions);
2961
- this.generation_config = generation_config;
2962
- }
2963
2919
  };
2964
2920
 
2965
2921
  /**
@@ -3230,16 +3186,6 @@ export class VisionEncoderDecoderModel extends PreTrainedModel {
3230
3186
  'encoder_hidden_states',
3231
3187
  'past_key_values',
3232
3188
  ];
3233
- /**
3234
- * Creates a new instance of the `VisionEncoderDecoderModel` class.
3235
- * @param {Object} config The model configuration.
3236
- * @param {Record<string, any>} sessions The inference sessions for the model.
3237
- * @param {GenerationConfig} generation_config The generation configuration.
3238
- */
3239
- constructor(config, sessions, generation_config) {
3240
- super(config, sessions);
3241
- this.generation_config = generation_config;
3242
- }
3243
3189
  }
3244
3190
  //////////////////////////////////////////////////
3245
3191
 
@@ -3254,11 +3200,6 @@ export class LlavaPreTrainedModel extends PreTrainedModel {
3254
3200
  'position_ids',
3255
3201
  'past_key_values',
3256
3202
  ];
3257
-
3258
- constructor(config, sessions, generation_config) {
3259
- super(config, sessions);
3260
- this.generation_config = generation_config;
3261
- }
3262
3203
  }
3263
3204
 
3264
3205
  /**
@@ -3345,11 +3286,6 @@ export class Florence2PreTrainedModel extends PreTrainedModel {
3345
3286
  'past_key_values',
3346
3287
  ];
3347
3288
  main_input_name = 'inputs_embeds';
3348
-
3349
- constructor(config, sessions, generation_config) {
3350
- super(config, sessions);
3351
- this.generation_config = generation_config;
3352
- }
3353
3289
  }
3354
3290
 
3355
3291
  export class Florence2ForConditionalGeneration extends Florence2PreTrainedModel {
@@ -3769,18 +3705,7 @@ export class CLIPSegForImageSegmentation extends CLIPSegPreTrainedModel { }
3769
3705
 
3770
3706
  //////////////////////////////////////////////////
3771
3707
  // GPT2 models
3772
- export class GPT2PreTrainedModel extends PreTrainedModel {
3773
- /**
3774
- * Creates a new instance of the `GPT2PreTrainedModel` class.
3775
- * @param {Object} config The model configuration.
3776
- * @param {Record<string, any>} sessions The inference sessions for the model.
3777
- * @param {GenerationConfig} generation_config The generation configuration.
3778
- */
3779
- constructor(config, sessions, generation_config) {
3780
- super(config, sessions);
3781
- this.generation_config = generation_config;
3782
- }
3783
- }
3708
+ export class GPT2PreTrainedModel extends PreTrainedModel { }
3784
3709
 
3785
3710
  export class GPT2Model extends GPT2PreTrainedModel { }
3786
3711
 
@@ -3795,18 +3720,7 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { }
3795
3720
 
3796
3721
  //////////////////////////////////////////////////
3797
3722
  // JAIS models
3798
- export class JAISPreTrainedModel extends PreTrainedModel {
3799
- /**
3800
- * Creates a new instance of the `JAISPreTrainedModel` class.
3801
- * @param {Object} config The model configuration.
3802
- * @param {Record<string, any>} sessions The inference sessions for the model.
3803
- * @param {GenerationConfig} generation_config The generation configuration.
3804
- */
3805
- constructor(config, sessions, generation_config) {
3806
- super(config, sessions);
3807
- this.generation_config = generation_config;
3808
- }
3809
- }
3723
+ export class JAISPreTrainedModel extends PreTrainedModel { }
3810
3724
 
3811
3725
  /**
3812
3726
  * The bare JAIS Model transformer outputting raw hidden-states without any specific head on top.
@@ -3822,18 +3736,7 @@ export class JAISLMHeadModel extends JAISPreTrainedModel { }
3822
3736
 
3823
3737
  //////////////////////////////////////////////////
3824
3738
  // GPTNeo models
3825
- export class GPTNeoPreTrainedModel extends PreTrainedModel {
3826
- /**
3827
- * Creates a new instance of the `GPTNeoPreTrainedModel` class.
3828
- * @param {Object} config The model configuration.
3829
- * @param {Record<string, any>} sessions The inference sessions for the model.
3830
- * @param {GenerationConfig} generation_config The generation configuration.
3831
- */
3832
- constructor(config, sessions, generation_config) {
3833
- super(config, sessions);
3834
- this.generation_config = generation_config;
3835
- }
3836
- }
3739
+ export class GPTNeoPreTrainedModel extends PreTrainedModel { }
3837
3740
  export class GPTNeoModel extends GPTNeoPreTrainedModel { }
3838
3741
 
3839
3742
  export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { }
@@ -3841,18 +3744,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { }
3841
3744
 
3842
3745
  //////////////////////////////////////////////////
3843
3746
  // GPTNeoX models
3844
- export class GPTNeoXPreTrainedModel extends PreTrainedModel {
3845
- /**
3846
- * Creates a new instance of the `GPTNeoXPreTrainedModel` class.
3847
- * @param {Object} config The model configuration.
3848
- * @param {Record<string, any>} sessions The inference sessions for the model.
3849
- * @param {GenerationConfig} generation_config The generation configuration.
3850
- */
3851
- constructor(config, sessions, generation_config) {
3852
- super(config, sessions);
3853
- this.generation_config = generation_config;
3854
- }
3855
- }
3747
+ export class GPTNeoXPreTrainedModel extends PreTrainedModel { }
3856
3748
  export class GPTNeoXModel extends GPTNeoXPreTrainedModel { }
3857
3749
 
3858
3750
  export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { }
@@ -3861,18 +3753,7 @@ export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { }
3861
3753
 
3862
3754
  //////////////////////////////////////////////////
3863
3755
  // GPT-J models
3864
- export class GPTJPreTrainedModel extends PreTrainedModel {
3865
- /**
3866
- * Creates a new instance of the `GPTJPreTrainedModel` class.
3867
- * @param {Object} config The model configuration.
3868
- * @param {Record<string, any>} sessions The inference sessions for the model.
3869
- * @param {GenerationConfig} generation_config The generation configuration.
3870
- */
3871
- constructor(config, sessions, generation_config) {
3872
- super(config, sessions);
3873
- this.generation_config = generation_config;
3874
- }
3875
- }
3756
+ export class GPTJPreTrainedModel extends PreTrainedModel { }
3876
3757
 
3877
3758
  export class GPTJModel extends GPTJPreTrainedModel { }
3878
3759
 
@@ -3882,18 +3763,7 @@ export class GPTJForCausalLM extends GPTJPreTrainedModel { }
3882
3763
 
3883
3764
  //////////////////////////////////////////////////
3884
3765
  // GPTBigCode models
3885
- export class GPTBigCodePreTrainedModel extends PreTrainedModel {
3886
- /**
3887
- * Creates a new instance of the `GPTBigCodePreTrainedModel` class.
3888
- * @param {Object} config The model configuration.
3889
- * @param {Record<string, any>} sessions The inference sessions for the model.
3890
- * @param {GenerationConfig} generation_config The generation configuration.
3891
- */
3892
- constructor(config, sessions, generation_config) {
3893
- super(config, sessions);
3894
- this.generation_config = generation_config;
3895
- }
3896
- }
3766
+ export class GPTBigCodePreTrainedModel extends PreTrainedModel { }
3897
3767
 
3898
3768
  export class GPTBigCodeModel extends GPTBigCodePreTrainedModel { }
3899
3769
 
@@ -3902,18 +3772,7 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { }
3902
3772
 
3903
3773
  //////////////////////////////////////////////////
3904
3774
  // CodeGen models
3905
- export class CodeGenPreTrainedModel extends PreTrainedModel {
3906
- /**
3907
- * Creates a new instance of the `CodeGenPreTrainedModel` class.
3908
- * @param {Object} config The model configuration.
3909
- * @param {Record<string, any>} sessions The inference sessions for the model.
3910
- * @param {GenerationConfig} generation_config The generation configuration.
3911
- */
3912
- constructor(config, sessions, generation_config) {
3913
- super(config, sessions);
3914
- this.generation_config = generation_config;
3915
- }
3916
- }
3775
+ export class CodeGenPreTrainedModel extends PreTrainedModel { }
3917
3776
  /**
3918
3777
  * CodeGenModel is a class representing a code generation model without a language model head.
3919
3778
  */
@@ -3932,18 +3791,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { }
3932
3791
  /**
3933
3792
  * The bare LLama Model outputting raw hidden-states without any specific head on top.
3934
3793
  */
3935
- export class LlamaPreTrainedModel extends PreTrainedModel {
3936
- /**
3937
- * Creates a new instance of the `LlamaPreTrainedModel` class.
3938
- * @param {Object} config The model configuration.
3939
- * @param {Record<string, any>} sessions The inference sessions for the model.
3940
- * @param {GenerationConfig} generation_config The generation configuration.
3941
- */
3942
- constructor(config, sessions, generation_config) {
3943
- super(config, sessions);
3944
- this.generation_config = generation_config;
3945
- }
3946
- }
3794
+ export class LlamaPreTrainedModel extends PreTrainedModel { }
3947
3795
  /**
3948
3796
  * The bare LLaMA Model outputting raw hidden-states without any specific head on top.
3949
3797
  */
@@ -3952,24 +3800,22 @@ export class LlamaModel extends LlamaPreTrainedModel { }
3952
3800
  export class LlamaForCausalLM extends LlamaPreTrainedModel { }
3953
3801
  //////////////////////////////////////////////////
3954
3802
 
3803
+
3804
+ //////////////////////////////////////////////////
3805
+ // Granite models
3806
+ export class GranitePreTrainedModel extends PreTrainedModel { }
3807
+ export class GraniteModel extends GranitePreTrainedModel { }
3808
+ export class GraniteForCausalLM extends GranitePreTrainedModel { }
3809
+ //////////////////////////////////////////////////
3810
+
3811
+
3955
3812
  //////////////////////////////////////////////////
3956
3813
  // Cohere models
3957
3814
 
3958
3815
  /**
3959
3816
  * The bare Cohere Model outputting raw hidden-states without any specific head on top.
3960
3817
  */
3961
- export class CoherePreTrainedModel extends PreTrainedModel {
3962
- /**
3963
- * Creates a new instance of the `CoherePreTrainedModel` class.
3964
- * @param {Object} config The model configuration.
3965
- * @param {Record<string, any>} sessions The inference sessions for the model.
3966
- * @param {GenerationConfig} generation_config The generation configuration.
3967
- */
3968
- constructor(config, sessions, generation_config) {
3969
- super(config, sessions);
3970
- this.generation_config = generation_config;
3971
- }
3972
- }
3818
+ export class CoherePreTrainedModel extends PreTrainedModel { }
3973
3819
  export class CohereModel extends CoherePreTrainedModel { }
3974
3820
 
3975
3821
  export class CohereForCausalLM extends CoherePreTrainedModel { }
@@ -3981,18 +3827,7 @@ export class CohereForCausalLM extends CoherePreTrainedModel { }
3981
3827
  /**
3982
3828
  * The bare Gemma Model outputting raw hidden-states without any specific head on top.
3983
3829
  */
3984
- export class GemmaPreTrainedModel extends PreTrainedModel {
3985
- /**
3986
- * Creates a new instance of the `GemmaPreTrainedModel` class.
3987
- * @param {Object} config The model configuration.
3988
- * @param {Record<string, any>} sessions The inference sessions for the model.
3989
- * @param {GenerationConfig} generation_config The generation configuration.
3990
- */
3991
- constructor(config, sessions, generation_config) {
3992
- super(config, sessions);
3993
- this.generation_config = generation_config;
3994
- }
3995
- }
3830
+ export class GemmaPreTrainedModel extends PreTrainedModel { }
3996
3831
  /**
3997
3832
  * The bare Gemma Model outputting raw hidden-states without any specific head on top.
3998
3833
  */
@@ -4007,18 +3842,7 @@ export class GemmaForCausalLM extends GemmaPreTrainedModel { }
4007
3842
  /**
4008
3843
  * The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
4009
3844
  */
4010
- export class Gemma2PreTrainedModel extends PreTrainedModel {
4011
- /**
4012
- * Creates a new instance of the `Gemma2PreTrainedModel` class.
4013
- * @param {Object} config The model configuration.
4014
- * @param {Record<string, any>} sessions The inference sessions for the model.
4015
- * @param {GenerationConfig} generation_config The generation configuration.
4016
- */
4017
- constructor(config, sessions, generation_config) {
4018
- super(config, sessions);
4019
- this.generation_config = generation_config;
4020
- }
4021
- }
3845
+ export class Gemma2PreTrainedModel extends PreTrainedModel { }
4022
3846
  /**
4023
3847
  * The bare Gemma2 Model outputting raw hidden-states without any specific head on top.
4024
3848
  */
@@ -4028,18 +3852,7 @@ export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { }
4028
3852
  //////////////////////////////////////////////////
4029
3853
 
4030
3854
  //////////////////////////////////////////////////
4031
- export class OpenELMPreTrainedModel extends PreTrainedModel {
4032
- /**
4033
- * Creates a new instance of the `OpenELMPreTrainedModel` class.
4034
- * @param {Object} config The model configuration.
4035
- * @param {Record<string, any>} sessions The inference sessions for the model.
4036
- * @param {GenerationConfig} generation_config The generation configuration.
4037
- */
4038
- constructor(config, sessions, generation_config) {
4039
- super(config, sessions);
4040
- this.generation_config = generation_config;
4041
- }
4042
- }
3855
+ export class OpenELMPreTrainedModel extends PreTrainedModel { }
4043
3856
  export class OpenELMModel extends OpenELMPreTrainedModel { }
4044
3857
 
4045
3858
  export class OpenELMForCausalLM extends OpenELMPreTrainedModel { }
@@ -4051,18 +3864,7 @@ export class OpenELMForCausalLM extends OpenELMPreTrainedModel { }
4051
3864
  /**
4052
3865
  * The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
4053
3866
  */
4054
- export class Qwen2PreTrainedModel extends PreTrainedModel {
4055
- /**
4056
- * Creates a new instance of the `Qwen2PreTrainedModel` class.
4057
- * @param {Object} config The model configuration.
4058
- * @param {Record<string, any>} sessions The inference sessions for the model.
4059
- * @param {GenerationConfig} generation_config The generation configuration.
4060
- */
4061
- constructor(config, sessions, generation_config) {
4062
- super(config, sessions);
4063
- this.generation_config = generation_config;
4064
- }
4065
- }
3867
+ export class Qwen2PreTrainedModel extends PreTrainedModel { }
4066
3868
  /**
4067
3869
  * The bare Qwen2 Model outputting raw hidden-states without any specific head on top.
4068
3870
  */
@@ -4074,18 +3876,7 @@ export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { }
4074
3876
 
4075
3877
  //////////////////////////////////////////////////
4076
3878
  // Phi models
4077
- export class PhiPreTrainedModel extends PreTrainedModel {
4078
- /**
4079
- * Creates a new instance of the `PhiPreTrainedModel` class.
4080
- * @param {Object} config The model configuration.
4081
- * @param {Record<string, any>} sessions The inference sessions for the model.
4082
- * @param {GenerationConfig} generation_config The generation configuration.
4083
- */
4084
- constructor(config, sessions, generation_config) {
4085
- super(config, sessions);
4086
- this.generation_config = generation_config;
4087
- }
4088
- }
3879
+ export class PhiPreTrainedModel extends PreTrainedModel { }
4089
3880
  /**
4090
3881
  * The bare Phi Model outputting raw hidden-states without any specific head on top.
4091
3882
  */
@@ -4096,18 +3887,7 @@ export class PhiForCausalLM extends PhiPreTrainedModel { }
4096
3887
 
4097
3888
  //////////////////////////////////////////////////
4098
3889
  // Phi3 models
4099
- export class Phi3PreTrainedModel extends PreTrainedModel {
4100
- /**
4101
- * Creates a new instance of the `Phi3PreTrainedModel` class.
4102
- * @param {Object} config The model configuration.
4103
- * @param {Record<string, any>} sessions The inference sessions for the model.
4104
- * @param {GenerationConfig} generation_config The generation configuration.
4105
- */
4106
- constructor(config, sessions, generation_config) {
4107
- super(config, sessions);
4108
- this.generation_config = generation_config;
4109
- }
4110
- }
3890
+ export class Phi3PreTrainedModel extends PreTrainedModel { }
4111
3891
 
4112
3892
  /**
4113
3893
  * The bare Phi3 Model outputting raw hidden-states without any specific head on top.
@@ -4123,18 +3903,7 @@ export class Phi3ForCausalLM extends Phi3PreTrainedModel { }
4123
3903
  /**
4124
3904
  * The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).
4125
3905
  */
4126
- export class BloomPreTrainedModel extends PreTrainedModel {
4127
- /**
4128
- * Creates a new instance of the `BloomPreTrainedModel` class.
4129
- * @param {Object} config The model configuration.
4130
- * @param {Record<string, any>} sessions The inference sessions for the model.
4131
- * @param {GenerationConfig} generation_config The generation configuration.
4132
- */
4133
- constructor(config, sessions, generation_config) {
4134
- super(config, sessions);
4135
- this.generation_config = generation_config;
4136
- }
4137
- }
3906
+ export class BloomPreTrainedModel extends PreTrainedModel { }
4138
3907
 
4139
3908
  /**
4140
3909
  * The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.
@@ -4149,18 +3918,7 @@ export class BloomForCausalLM extends BloomPreTrainedModel { }
4149
3918
 
4150
3919
  //////////////////////////////////////////////////
4151
3920
  // MPT models
4152
- export class MptPreTrainedModel extends PreTrainedModel {
4153
- /**
4154
- * Creates a new instance of the `MptPreTrainedModel` class.
4155
- * @param {Object} config The model configuration.
4156
- * @param {Record<string, any>} sessions The inference sessions for the model.
4157
- * @param {GenerationConfig} generation_config The generation configuration.
4158
- */
4159
- constructor(config, sessions, generation_config) {
4160
- super(config, sessions);
4161
- this.generation_config = generation_config;
4162
- }
4163
- }
3921
+ export class MptPreTrainedModel extends PreTrainedModel { }
4164
3922
 
4165
3923
  /**
4166
3924
  * The bare Mpt Model transformer outputting raw hidden-states without any specific head on top.
@@ -4176,18 +3934,7 @@ export class MptForCausalLM extends MptPreTrainedModel { }
4176
3934
 
4177
3935
  //////////////////////////////////////////////////
4178
3936
  // OPT models
4179
- export class OPTPreTrainedModel extends PreTrainedModel {
4180
- /**
4181
- * Creates a new instance of the `OPTPreTrainedModel` class.
4182
- * @param {Object} config The model configuration.
4183
- * @param {Record<string, any>} sessions The inference sessions for the model.
4184
- * @param {GenerationConfig} generation_config The generation configuration.
4185
- */
4186
- constructor(config, sessions, generation_config) {
4187
- super(config, sessions);
4188
- this.generation_config = generation_config;
4189
- }
4190
- }
3937
+ export class OPTPreTrainedModel extends PreTrainedModel { }
4191
3938
 
4192
3939
  /**
4193
3940
  * The bare OPT Model outputting raw hidden-states without any specific head on top.
@@ -5049,19 +4796,7 @@ export class SamImageSegmentationOutput extends ModelOutput {
5049
4796
 
5050
4797
  //////////////////////////////////////////////////
5051
4798
  // MarianMT models
5052
- export class MarianPreTrainedModel extends PreTrainedModel {
5053
-
5054
- /**
5055
- * Creates a new instance of the `MarianMTModel` class.
5056
- * @param {Object} config The model configuration.
5057
- * @param {Record<string, any>} sessions The inference sessions for the model.
5058
- * @param {GenerationConfig} generation_config The generation configuration.
5059
- */
5060
- constructor(config, sessions, generation_config) {
5061
- super(config, sessions);
5062
- this.generation_config = generation_config;
5063
- }
5064
- };
4799
+ export class MarianPreTrainedModel extends PreTrainedModel { };
5065
4800
 
5066
4801
  export class MarianModel extends MarianPreTrainedModel { }
5067
4802
 
@@ -5070,19 +4805,7 @@ export class MarianMTModel extends MarianPreTrainedModel { }
5070
4805
 
5071
4806
  //////////////////////////////////////////////////
5072
4807
  // M2M100 models
5073
- export class M2M100PreTrainedModel extends PreTrainedModel {
5074
-
5075
- /**
5076
- * Creates a new instance of the `M2M100ForConditionalGeneration` class.
5077
- * @param {Object} config The model configuration.
5078
- * @param {Record<string, any>} sessions The inference sessions for the model.
5079
- * @param {GenerationConfig} generation_config The generation configuration.
5080
- */
5081
- constructor(config, sessions, generation_config) {
5082
- super(config, sessions);
5083
- this.generation_config = generation_config;
5084
- }
5085
- };
4808
+ export class M2M100PreTrainedModel extends PreTrainedModel { };
5086
4809
 
5087
4810
  export class M2M100Model extends M2M100PreTrainedModel { }
5088
4811
 
@@ -5592,19 +5315,7 @@ export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
5592
5315
  /**
5593
5316
  * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
5594
5317
  */
5595
- export class SpeechT5PreTrainedModel extends PreTrainedModel {
5596
-
5597
- /**
5598
- * Creates a new instance of the `SpeechT5ForTextToSpeech` class.
5599
- * @param {Object} config The model configuration.
5600
- * @param {Record<string, any>} sessions The inference sessions for the model.
5601
- * @param {GenerationConfig} generation_config The generation configuration.
5602
- */
5603
- constructor(config, sessions, generation_config) {
5604
- super(config, sessions);
5605
- this.generation_config = generation_config;
5606
- }
5607
- };
5318
+ export class SpeechT5PreTrainedModel extends PreTrainedModel { };
5608
5319
 
5609
5320
  /**
5610
5321
  * The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets.
@@ -5765,18 +5476,7 @@ export class SpeechT5HifiGan extends PreTrainedModel {
5765
5476
 
5766
5477
  //////////////////////////////////////////////////
5767
5478
  // TrOCR models
5768
- export class TrOCRPreTrainedModel extends PreTrainedModel {
5769
- /**
5770
- * Creates a new instance of the `TrOCRPreTrainedModel` class.
5771
- * @param {Object} config The configuration of the model.
5772
- * @param {any} session The ONNX session containing the model weights.
5773
- * @param {GenerationConfig} generation_config The generation configuration.
5774
- */
5775
- constructor(config, session, generation_config) {
5776
- super(config, session);
5777
- this.generation_config = generation_config;
5778
- }
5779
- }
5479
+ export class TrOCRPreTrainedModel extends PreTrainedModel { }
5780
5480
 
5781
5481
  /**
5782
5482
  * The TrOCR Decoder with a language modeling head.
@@ -5791,18 +5491,7 @@ export class TrOCRForCausalLM extends TrOCRPreTrainedModel { }
5791
5491
  /**
5792
5492
  * The bare Mistral Model outputting raw hidden-states without any specific head on top.
5793
5493
  */
5794
- export class MistralPreTrainedModel extends PreTrainedModel {
5795
- /**
5796
- * Creates a new instance of the `MistralPreTrainedModel` class.
5797
- * @param {Object} config The configuration of the model.
5798
- * @param {any} session The ONNX session containing the model weights.
5799
- * @param {GenerationConfig} generation_config The generation configuration.
5800
- */
5801
- constructor(config, session, generation_config) {
5802
- super(config, session);
5803
- this.generation_config = generation_config;
5804
- }
5805
- }
5494
+ export class MistralPreTrainedModel extends PreTrainedModel { }
5806
5495
 
5807
5496
  export class MistralModel extends MistralPreTrainedModel { }
5808
5497
 
@@ -5815,18 +5504,7 @@ export class MistralForCausalLM extends MistralPreTrainedModel { }
5815
5504
  /**
5816
5505
  * The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.
5817
5506
  */
5818
- export class Starcoder2PreTrainedModel extends PreTrainedModel {
5819
- /**
5820
- * Creates a new instance of the `Starcoder2PreTrainedModel` class.
5821
- * @param {Object} config The configuration of the model.
5822
- * @param {any} session The ONNX session containing the model weights.
5823
- * @param {GenerationConfig} generation_config The generation configuration.
5824
- */
5825
- constructor(config, session, generation_config) {
5826
- super(config, session);
5827
- this.generation_config = generation_config;
5828
- }
5829
- }
5507
+ export class Starcoder2PreTrainedModel extends PreTrainedModel { }
5830
5508
 
5831
5509
  export class Starcoder2Model extends Starcoder2PreTrainedModel { }
5832
5510
 
@@ -5839,18 +5517,7 @@ export class Starcoder2ForCausalLM extends Starcoder2PreTrainedModel { }
5839
5517
  /**
5840
5518
  * The bare Falcon Model outputting raw hidden-states without any specific head on top.
5841
5519
  */
5842
- export class FalconPreTrainedModel extends PreTrainedModel {
5843
- /**
5844
- * Creates a new instance of the `FalconPreTrainedModel` class.
5845
- * @param {Object} config The configuration of the model.
5846
- * @param {any} session The ONNX session containing the model weights.
5847
- * @param {GenerationConfig} generation_config The generation configuration.
5848
- */
5849
- constructor(config, session, generation_config) {
5850
- super(config, session);
5851
- this.generation_config = generation_config;
5852
- }
5853
- }
5520
+ export class FalconPreTrainedModel extends PreTrainedModel { }
5854
5521
 
5855
5522
  export class FalconModel extends FalconPreTrainedModel { }
5856
5523
 
@@ -6000,18 +5667,7 @@ export class SegformerForSemanticSegmentation extends SegformerPreTrainedModel {
6000
5667
 
6001
5668
  //////////////////////////////////////////////////
6002
5669
  // StableLm models
6003
- export class StableLmPreTrainedModel extends PreTrainedModel {
6004
- /**
6005
- * Creates a new instance of the `StableLmPreTrainedModel` class.
6006
- * @param {Object} config The configuration of the model.
6007
- * @param {any} session The ONNX session containing the model weights.
6008
- * @param {GenerationConfig} generation_config The generation configuration.
6009
- */
6010
- constructor(config, session, generation_config) {
6011
- super(config, session);
6012
- this.generation_config = generation_config;
6013
- }
6014
- }
5670
+ export class StableLmPreTrainedModel extends PreTrainedModel { }
6015
5671
 
6016
5672
  /**
6017
5673
  * The bare StableLm Model transformer outputting raw hidden-states without any specific head on top.
@@ -6105,17 +5761,6 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { // NOTE:
6105
5761
  'past_key_values',
6106
5762
  ];
6107
5763
 
6108
- /**
6109
- * Creates a new instance of the `MusicgenForConditionalGeneration` class.
6110
- * @param {Object} config The model configuration.
6111
- * @param {Record<string, any>} sessions The inference sessions for the model.
6112
- * @param {GenerationConfig} generation_config The generation configuration.
6113
- */
6114
- constructor(config, sessions, generation_config) {
6115
- super(config, sessions);
6116
- this.generation_config = generation_config;
6117
- }
6118
-
6119
5764
  /**
6120
5765
  * Apply the pattern mask to the final ids,
6121
5766
  * then revert the pattern delay mask by filtering the pad token id in a single step.
@@ -6471,6 +6116,7 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
6471
6116
  ['gpt_neox', ['GPTNeoXModel', GPTNeoXModel]],
6472
6117
  ['codegen', ['CodeGenModel', CodeGenModel]],
6473
6118
  ['llama', ['LlamaModel', LlamaModel]],
6119
+ ['granite', ['GraniteModel', GraniteModel]],
6474
6120
  ['cohere', ['CohereModel', CohereModel]],
6475
6121
  ['gemma', ['GemmaModel', GemmaModel]],
6476
6122
  ['gemma2', ['Gemma2Model', Gemma2Model]],
@@ -6559,6 +6205,7 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
6559
6205
  ['gpt_neox', ['GPTNeoXForCausalLM', GPTNeoXForCausalLM]],
6560
6206
  ['codegen', ['CodeGenForCausalLM', CodeGenForCausalLM]],
6561
6207
  ['llama', ['LlamaForCausalLM', LlamaForCausalLM]],
6208
+ ['granite', ['GraniteForCausalLM', GraniteForCausalLM]],
6562
6209
  ['cohere', ['CohereForCausalLM', CohereForCausalLM]],
6563
6210
  ['gemma', ['GemmaForCausalLM', GemmaForCausalLM]],
6564
6211
  ['gemma2', ['Gemma2ForCausalLM', Gemma2ForCausalLM]],