@huggingface/transformers 3.0.0-alpha.21 → 3.0.0-alpha.22

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -3930,9 +3930,10 @@ let wasmInitPromise = null;
3930
3930
  * Create an ONNX inference session.
3931
3931
  * @param {Uint8Array} buffer The ONNX model buffer.
3932
3932
  * @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
3933
- * @returns {Promise<import('onnxruntime-common').InferenceSession>} The ONNX inference session.
3933
+ * @param {Object} session_config ONNX inference session configuration.
3934
+ * @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object}>} The ONNX inference session.
3934
3935
  */
3935
- async function createInferenceSession(buffer, session_options) {
3936
+ async function createInferenceSession(buffer, session_options, session_config) {
3936
3937
  if (wasmInitPromise) {
3937
3938
  // A previous session has already initialized the WASM runtime
3938
3939
  // so we wait for it to resolve before creating this new session.
@@ -3941,7 +3942,9 @@ async function createInferenceSession(buffer, session_options) {
3941
3942
 
3942
3943
  const sessionPromise = InferenceSession.create(buffer, session_options);
3943
3944
  wasmInitPromise ??= sessionPromise;
3944
- return await sessionPromise;
3945
+ const session = await sessionPromise;
3946
+ session.config = session_config;
3947
+ return session;
3945
3948
  }
3946
3949
 
3947
3950
  /**
@@ -4376,7 +4379,7 @@ class AutoConfig {
4376
4379
  /**
4377
4380
  * Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
4378
4381
  * @typedef {Object} TransformersJSConfig
4379
- * @property {import('./utils/tensor.js').DataType} [kv_cache_dtype] The data type of the key-value cache.
4382
+ * @property {import('./utils/tensor.js').DataType|Record<import('./utils/dtypes.js').DataType, import('./utils/tensor.js').DataType>} [kv_cache_dtype] The data type of the key-value cache.
4380
4383
  * @property {Record<string, number>} [free_dimension_overrides] Override the free dimensions of the model.
4381
4384
  * See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides
4382
4385
  * for more information.
@@ -4430,7 +4433,7 @@ __webpack_require__.r(__webpack_exports__);
4430
4433
 
4431
4434
 
4432
4435
 
4433
- const VERSION = '3.0.0-alpha.21';
4436
+ const VERSION = '3.0.0-alpha.22';
4434
4437
 
4435
4438
  // Check if various APIs are available (depends on environment)
4436
4439
  const IS_BROWSER_ENV = typeof self !== 'undefined';
@@ -6877,7 +6880,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
6877
6880
  * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
6878
6881
  * @param {string} fileName The name of the model file.
6879
6882
  * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
6880
- * @returns {Promise<{buffer: Uint8Array, session_options: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
6883
+ * @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
6881
6884
  * @private
6882
6885
  */
6883
6886
  async function getSession(pretrained_model_name_or_path, fileName, options) {
@@ -6918,6 +6921,22 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
6918
6921
  throw new Error(`The device (${selectedDevice}) does not support fp16.`);
6919
6922
  }
6920
6923
 
6924
+ // Only valid for models with a decoder
6925
+ const kv_cache_dtype = custom_config.kv_cache_dtype
6926
+ ? (typeof custom_config.kv_cache_dtype === 'string'
6927
+ ? custom_config.kv_cache_dtype
6928
+ : custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
6929
+ : undefined;
6930
+
6931
+ if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
6932
+ throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
6933
+ }
6934
+
6935
+ const session_config = {
6936
+ dtype: selectedDtype,
6937
+ kv_cache_dtype,
6938
+ }
6939
+
6921
6940
  // Construct the model file name
6922
6941
  const suffix = _utils_dtypes_js__WEBPACK_IMPORTED_MODULE_2__.DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
6923
6942
  const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
@@ -6993,7 +7012,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
6993
7012
  }
6994
7013
 
6995
7014
  const buffer = await bufferPromise;
6996
- return { buffer, session_options };
7015
+
7016
+ return { buffer, session_options, session_config };
6997
7017
  }
6998
7018
 
6999
7019
  /**
@@ -7008,8 +7028,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
7008
7028
  async function constructSessions(pretrained_model_name_or_path, names, options) {
7009
7029
  return Object.fromEntries(await Promise.all(
7010
7030
  Object.keys(names).map(async (name) => {
7011
- const { buffer, session_options } = await getSession(pretrained_model_name_or_path, names[name], options);
7012
- const session = await (0,_backends_onnx_js__WEBPACK_IMPORTED_MODULE_1__.createInferenceSession)(buffer, session_options);
7031
+ const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
7032
+ const session = await (0,_backends_onnx_js__WEBPACK_IMPORTED_MODULE_1__.createInferenceSession)(buffer, session_options, session_config);
7013
7033
  return [name, session];
7014
7034
  })
7015
7035
  ));
@@ -8366,9 +8386,8 @@ class PreTrainedModel extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_3__.Cal
8366
8386
  if (pastKeyValues) {
8367
8387
  Object.assign(decoderFeeds, pastKeyValues)
8368
8388
  } else {
8369
-
8370
- /** @type {import('./transformers.js').DataType} */
8371
- const dtype = this.custom_config.kv_cache_dtype ?? 'float32';
8389
+ const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
8390
+ const dtype = session?.config?.kv_cache_dtype ?? 'float32';
8372
8391
  const empty = (dtype === 'float16') ? new Uint16Array() : [];
8373
8392
 
8374
8393
  const shapes = (0,_configs_js__WEBPACK_IMPORTED_MODULE_0__.getKeyValueShapes)(this.config);