@huggingface/transformers 3.0.0-alpha.21 → 3.0.0-alpha.22

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -101,7 +101,7 @@ npm i @huggingface/transformers
101
101
  Alternatively, you can use it in vanilla JS, without any bundler, by using a CDN or static hosting. For example, using [ES Modules](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Guide/Modules), you can import the library with:
102
102
  ```html
103
103
  <script type="module">
104
- import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.21';
104
+ import { pipeline } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.22';
105
105
  </script>
106
106
  ```
107
107
 
@@ -134,7 +134,7 @@ Check out the Transformers.js [template](https://huggingface.co/new-space?templa
134
134
 
135
135
 
136
136
 
137
- By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.21/dist/), which should work out-of-the-box. You can customize this as follows:
137
+ By default, Transformers.js uses [hosted pretrained models](https://huggingface.co/models?library=transformers.js) and [precompiled WASM binaries](https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.0-alpha.22/dist/), which should work out-of-the-box. You can customize this as follows:
138
138
 
139
139
  ### Settings
140
140
 
@@ -3955,9 +3955,10 @@ let wasmInitPromise = null;
3955
3955
  * Create an ONNX inference session.
3956
3956
  * @param {Uint8Array} buffer The ONNX model buffer.
3957
3957
  * @param {import('onnxruntime-common').InferenceSession.SessionOptions} session_options ONNX inference session options.
3958
- * @returns {Promise<import('onnxruntime-common').InferenceSession>} The ONNX inference session.
3958
+ * @param {Object} session_config ONNX inference session configuration.
3959
+ * @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object}>} The ONNX inference session.
3959
3960
  */
3960
- async function createInferenceSession(buffer, session_options) {
3961
+ async function createInferenceSession(buffer, session_options, session_config) {
3961
3962
  if (wasmInitPromise) {
3962
3963
  // A previous session has already initialized the WASM runtime
3963
3964
  // so we wait for it to resolve before creating this new session.
@@ -3966,7 +3967,9 @@ async function createInferenceSession(buffer, session_options) {
3966
3967
 
3967
3968
  const sessionPromise = InferenceSession.create(buffer, session_options);
3968
3969
  wasmInitPromise ??= sessionPromise;
3969
- return await sessionPromise;
3970
+ const session = await sessionPromise;
3971
+ session.config = session_config;
3972
+ return session;
3970
3973
  }
3971
3974
 
3972
3975
  /**
@@ -4402,7 +4405,7 @@ class AutoConfig {
4402
4405
  /**
4403
4406
  * Transformers.js-specific configuration, possibly present in config.json under the key `transformers.js_config`.
4404
4407
  * @typedef {Object} TransformersJSConfig
4405
- * @property {import('./utils/tensor.js').DataType} [kv_cache_dtype] The data type of the key-value cache.
4408
+ * @property {import('./utils/tensor.js').DataType|Record<import('./utils/dtypes.js').DataType, import('./utils/tensor.js').DataType>} [kv_cache_dtype] The data type of the key-value cache.
4406
4409
  * @property {Record<string, number>} [free_dimension_overrides] Override the free dimensions of the model.
4407
4410
  * See https://onnxruntime.ai/docs/tutorials/web/env-flags-and-session-options.html#freedimensionoverrides
4408
4411
  * for more information.
@@ -4457,7 +4460,7 @@ __webpack_require__.r(__webpack_exports__);
4457
4460
 
4458
4461
 
4459
4462
 
4460
- const VERSION = '3.0.0-alpha.21';
4463
+ const VERSION = '3.0.0-alpha.22';
4461
4464
 
4462
4465
  // Check if various APIs are available (depends on environment)
4463
4466
  const IS_BROWSER_ENV = typeof self !== 'undefined';
@@ -6910,7 +6913,7 @@ const MODEL_CLASS_TO_NAME_MAPPING = new Map();
6910
6913
  * @param {string} pretrained_model_name_or_path The path to the directory containing the model file.
6911
6914
  * @param {string} fileName The name of the model file.
6912
6915
  * @param {import('./utils/hub.js').PretrainedModelOptions} options Additional options for loading the model.
6913
- * @returns {Promise<{buffer: Uint8Array, session_options: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
6916
+ * @returns {Promise<{buffer: Uint8Array, session_options: Object, session_config: Object}>} A Promise that resolves to the data needed to create an InferenceSession object.
6914
6917
  * @private
6915
6918
  */
6916
6919
  async function getSession(pretrained_model_name_or_path, fileName, options) {
@@ -6951,6 +6954,22 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
6951
6954
  throw new Error(`The device (${selectedDevice}) does not support fp16.`);
6952
6955
  }
6953
6956
 
6957
+ // Only valid for models with a decoder
6958
+ const kv_cache_dtype = custom_config.kv_cache_dtype
6959
+ ? (typeof custom_config.kv_cache_dtype === 'string'
6960
+ ? custom_config.kv_cache_dtype
6961
+ : custom_config.kv_cache_dtype[selectedDtype] ?? 'float32')
6962
+ : undefined;
6963
+
6964
+ if (kv_cache_dtype && !['float32', 'float16'].includes(kv_cache_dtype)) {
6965
+ throw new Error(`Invalid kv_cache_dtype: ${kv_cache_dtype}. Should be one of: float32, float16`);
6966
+ }
6967
+
6968
+ const session_config = {
6969
+ dtype: selectedDtype,
6970
+ kv_cache_dtype,
6971
+ }
6972
+
6954
6973
  // Construct the model file name
6955
6974
  const suffix = _utils_dtypes_js__WEBPACK_IMPORTED_MODULE_2__.DEFAULT_DTYPE_SUFFIX_MAPPING[selectedDtype];
6956
6975
  const modelFileName = `${options.subfolder ?? ''}/${fileName}${suffix}.onnx`;
@@ -7026,7 +7045,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
7026
7045
  }
7027
7046
 
7028
7047
  const buffer = await bufferPromise;
7029
- return { buffer, session_options };
7048
+
7049
+ return { buffer, session_options, session_config };
7030
7050
  }
7031
7051
 
7032
7052
  /**
@@ -7041,8 +7061,8 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
7041
7061
  async function constructSessions(pretrained_model_name_or_path, names, options) {
7042
7062
  return Object.fromEntries(await Promise.all(
7043
7063
  Object.keys(names).map(async (name) => {
7044
- const { buffer, session_options } = await getSession(pretrained_model_name_or_path, names[name], options);
7045
- const session = await (0,_backends_onnx_js__WEBPACK_IMPORTED_MODULE_1__.createInferenceSession)(buffer, session_options);
7064
+ const { buffer, session_options, session_config } = await getSession(pretrained_model_name_or_path, names[name], options);
7065
+ const session = await (0,_backends_onnx_js__WEBPACK_IMPORTED_MODULE_1__.createInferenceSession)(buffer, session_options, session_config);
7046
7066
  return [name, session];
7047
7067
  })
7048
7068
  ));
@@ -8399,9 +8419,8 @@ class PreTrainedModel extends _utils_generic_js__WEBPACK_IMPORTED_MODULE_3__.Cal
8399
8419
  if (pastKeyValues) {
8400
8420
  Object.assign(decoderFeeds, pastKeyValues)
8401
8421
  } else {
8402
-
8403
- /** @type {import('./transformers.js').DataType} */
8404
- const dtype = this.custom_config.kv_cache_dtype ?? 'float32';
8422
+ const session = this.sessions['decoder_model_merged'] ?? this.sessions['model'];
8423
+ const dtype = session?.config?.kv_cache_dtype ?? 'float32';
8405
8424
  const empty = (dtype === 'float16') ? new Uint16Array() : [];
8406
8425
 
8407
8426
  const shapes = (0,_configs_js__WEBPACK_IMPORTED_MODULE_0__.getKeyValueShapes)(this.config);