@huggingface/transformers 3.2.4 → 3.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. package/README.md +5 -3
  2. package/dist/ort-wasm-simd-threaded.jsep.mjs +135 -0
  3. package/dist/ort-wasm-simd-threaded.jsep.wasm +0 -0
  4. package/dist/transformers.cjs +598 -247
  5. package/dist/transformers.cjs.map +1 -1
  6. package/dist/transformers.js +956 -573
  7. package/dist/transformers.js.map +1 -1
  8. package/dist/transformers.min.cjs +1 -1
  9. package/dist/transformers.min.cjs.map +1 -1
  10. package/dist/transformers.min.js +1 -1
  11. package/dist/transformers.min.js.map +1 -1
  12. package/dist/transformers.min.mjs +1 -1
  13. package/dist/transformers.min.mjs.map +1 -1
  14. package/dist/transformers.mjs +604 -248
  15. package/dist/transformers.mjs.map +1 -1
  16. package/package.json +3 -3
  17. package/src/base/image_processors_utils.js +1 -1
  18. package/src/base/processing_utils.js +11 -0
  19. package/src/env.js +1 -2
  20. package/src/generation/streamers.js +5 -2
  21. package/src/models/grounding_dino/image_processing_grounding_dino.js +29 -0
  22. package/src/models/grounding_dino/processing_grounding_dino.js +101 -0
  23. package/src/models/image_processors.js +1 -0
  24. package/src/models/processors.js +3 -2
  25. package/src/models.js +22 -5
  26. package/src/pipelines.js +39 -16
  27. package/src/utils/audio.js +113 -1
  28. package/src/utils/core.js +26 -0
  29. package/src/utils/image.js +5 -18
  30. package/src/utils/tensor.js +100 -112
  31. package/types/base/image_processors_utils.d.ts +7 -0
  32. package/types/base/image_processors_utils.d.ts.map +1 -1
  33. package/types/base/processing_utils.d.ts +8 -0
  34. package/types/base/processing_utils.d.ts.map +1 -1
  35. package/types/generation/streamers.d.ts +3 -1
  36. package/types/generation/streamers.d.ts.map +1 -1
  37. package/types/models/auto/image_processing_auto.d.ts.map +1 -1
  38. package/types/models/grounding_dino/image_processing_grounding_dino.d.ts +20 -0
  39. package/types/models/grounding_dino/image_processing_grounding_dino.d.ts.map +1 -0
  40. package/types/models/grounding_dino/processing_grounding_dino.d.ts +27 -0
  41. package/types/models/grounding_dino/processing_grounding_dino.d.ts.map +1 -0
  42. package/types/models/image_processors.d.ts +1 -0
  43. package/types/models/processors.d.ts +3 -2
  44. package/types/models.d.ts +8 -0
  45. package/types/models.d.ts.map +1 -1
  46. package/types/pipelines.d.ts +5 -10
  47. package/types/pipelines.d.ts.map +1 -1
  48. package/types/tsconfig.tsbuildinfo +1 -1
  49. package/types/utils/audio.d.ts +25 -0
  50. package/types/utils/audio.d.ts.map +1 -1
  51. package/types/utils/core.d.ts +6 -0
  52. package/types/utils/core.d.ts.map +1 -1
  53. package/types/utils/image.d.ts.map +1 -1
  54. package/types/utils/tensor.d.ts +14 -2
  55. package/types/utils/tensor.d.ts.map +1 -1
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@huggingface/transformers",
3
- "version": "3.2.4",
3
+ "version": "3.3.1",
4
4
  "description": "State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!",
5
5
  "main": "./src/transformers.js",
6
6
  "types": "./types/transformers.d.ts",
@@ -26,7 +26,7 @@
26
26
  "format:check": "prettier --check .",
27
27
  "typegen": "tsc --build",
28
28
  "dev": "webpack serve --no-client-overlay",
29
- "build": "webpack && npm run typegen",
29
+ "build": "webpack && npm run typegen && rm ./dist/ort.bundle.min.mjs && cp ./node_modules/onnxruntime-web/dist/ort-wasm-simd-threaded.jsep.mjs ./dist",
30
30
  "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose",
31
31
  "readme": "python ./docs/scripts/build_readme.py",
32
32
  "docs-api": "node ./docs/scripts/generate.js",
@@ -57,7 +57,7 @@
57
57
  "dependencies": {
58
58
  "@huggingface/jinja": "^0.3.2",
59
59
  "onnxruntime-node": "1.20.1",
60
- "onnxruntime-web": "1.21.0-dev.20241205-d27fecd3d3",
60
+ "onnxruntime-web": "1.21.0-dev.20250114-228dd16893",
61
61
  "sharp": "^0.33.5"
62
62
  },
63
63
  "devDependencies": {
@@ -68,7 +68,7 @@ function enforce_size_divisibility([width, height], divisor) {
68
68
  * @param {number[]} arr The coordinate for the center of the box and its width, height dimensions (center_x, center_y, width, height)
69
69
  * @returns {number[]} The coordinates for the top-left and bottom-right corners of the box (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
70
70
  */
71
- function center_to_corners_format([centerX, centerY, width, height]) {
71
+ export function center_to_corners_format([centerX, centerY, width, height]) {
72
72
  return [
73
73
  centerX - width / 2,
74
74
  centerY - height / 2,
@@ -101,6 +101,17 @@ export class Processor extends Callable {
101
101
  return this.tokenizer.batch_decode(...args);
102
102
  }
103
103
 
104
+ /**
105
+ * @param {Parameters<PreTrainedTokenizer['decode']>} args
106
+ * @returns {ReturnType<PreTrainedTokenizer['decode']>}
107
+ */
108
+ decode(...args) {
109
+ if (!this.tokenizer) {
110
+ throw new Error('Unable to decode without a tokenizer.');
111
+ }
112
+ return this.tokenizer.decode(...args);
113
+ }
114
+
104
115
 
105
116
  /**
106
117
  * Calls the feature_extractor function with the given input.
package/src/env.js CHANGED
@@ -26,7 +26,7 @@ import fs from 'fs';
26
26
  import path from 'path';
27
27
  import url from 'url';
28
28
 
29
- const VERSION = '3.2.4';
29
+ const VERSION = '3.3.1';
30
30
 
31
31
  // Check if various APIs are available (depends on environment)
32
32
  const IS_BROWSER_ENV = typeof window !== "undefined" && typeof window.document !== "undefined";
@@ -160,4 +160,3 @@ export const env = {
160
160
  function isEmpty(obj) {
161
161
  return Object.keys(obj).length === 0;
162
162
  }
163
-
@@ -37,6 +37,7 @@ export class TextStreamer extends BaseStreamer {
37
37
  * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer
38
38
  * @param {Object} options
39
39
  * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
40
+ * @param {boolean} [options.skip_special_tokens=true] Whether to skip special tokens when decoding
40
41
  * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
41
42
  * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
42
43
  * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method
@@ -45,6 +46,7 @@ export class TextStreamer extends BaseStreamer {
45
46
  skip_prompt = false,
46
47
  callback_function = null,
47
48
  token_callback_function = null,
49
+ skip_special_tokens = true,
48
50
  decode_kwargs = {},
49
51
  ...kwargs
50
52
  } = {}) {
@@ -53,7 +55,7 @@ export class TextStreamer extends BaseStreamer {
53
55
  this.skip_prompt = skip_prompt;
54
56
  this.callback_function = callback_function ?? stdout_write;
55
57
  this.token_callback_function = token_callback_function;
56
- this.decode_kwargs = { ...decode_kwargs, ...kwargs };
58
+ this.decode_kwargs = { skip_special_tokens, ...decode_kwargs, ...kwargs };
57
59
 
58
60
  // variables used in the streaming process
59
61
  this.token_cache = [];
@@ -169,9 +171,10 @@ export class WhisperTextStreamer extends TextStreamer {
169
171
  } = {}) {
170
172
  super(tokenizer, {
171
173
  skip_prompt,
174
+ skip_special_tokens,
172
175
  callback_function,
173
176
  token_callback_function,
174
- decode_kwargs: { skip_special_tokens, ...decode_kwargs },
177
+ decode_kwargs,
175
178
  });
176
179
  this.timestamp_begin = tokenizer.timestamp_begin;
177
180
 
@@ -0,0 +1,29 @@
1
+
2
+ import {
3
+ ImageProcessor,
4
+ } from "../../base/image_processors_utils.js";
5
+ import { ones } from '../../utils/tensor.js';
6
+
7
+
8
+ /**
9
+ * @typedef {object} GroundingDinoFeatureExtractorResultProps
10
+ * @property {import('../../utils/tensor.js').Tensor} pixel_mask
11
+ * @typedef {import('../../base/image_processors_utils.js').ImageProcessorResult & GroundingDinoFeatureExtractorResultProps} GroundingDinoFeatureExtractorResult
12
+ */
13
+
14
+ export class GroundingDinoImageProcessor extends ImageProcessor {
15
+ /**
16
+ * Calls the feature extraction process on an array of images, preprocesses
17
+ * each image, and concatenates the resulting features into a single Tensor.
18
+ * @param {import('../../utils/image.js').RawImage[]} images The image(s) to extract features from.
19
+ * @returns {Promise<GroundingDinoFeatureExtractorResult>} An object containing the concatenated pixel values of the preprocessed images.
20
+ */
21
+ async _call(images) {
22
+ const result = await super._call(images);
23
+
24
+ const dims = result.pixel_values.dims;
25
+ const pixel_mask = ones([dims[0], dims[2], dims[3]]);
26
+
27
+ return { ...result, pixel_mask };
28
+ }
29
+ }
@@ -0,0 +1,101 @@
1
+ import { Processor } from "../../base/processing_utils.js";
2
+ import { AutoImageProcessor } from "../auto/image_processing_auto.js";
3
+ import { AutoTokenizer } from "../../tokenizers.js";
4
+ import { center_to_corners_format } from "../../base/image_processors_utils.js";
5
+
6
+ /**
7
+ * Get token ids of phrases from posmaps and input_ids.
8
+ * @param {import('../../utils/tensor.js').Tensor} posmaps A boolean tensor of unbatched text-thresholded logits related to the detected bounding boxes of shape `(hidden_size, )`.
9
+ * @param {import('../../utils/tensor.js').Tensor} input_ids A tensor of token ids of shape `(sequence_length, )`.
10
+ */
11
+ function get_phrases_from_posmap(posmaps, input_ids) {
12
+
13
+ const left_idx = 0;
14
+ const right_idx = posmaps.dims.at(-1) - 1;
15
+
16
+ const posmaps_list = posmaps.tolist();
17
+ posmaps_list.fill(false, 0, left_idx + 1);
18
+ posmaps_list.fill(false, right_idx);
19
+
20
+ const input_ids_list = input_ids.tolist();
21
+ return posmaps_list
22
+ .map((val, idx) => val ? idx : null)
23
+ .filter(idx => idx !== null)
24
+ .map(i => input_ids_list[i]);
25
+ }
26
+
27
+ export class GroundingDinoProcessor extends Processor {
28
+ static tokenizer_class = AutoTokenizer
29
+ static image_processor_class = AutoImageProcessor
30
+
31
+ /**
32
+ * @typedef {import('../../utils/image.js').RawImage} RawImage
33
+ */
34
+ /**
35
+ *
36
+ * @param {RawImage|RawImage[]|RawImage[][]} images
37
+ * @param {string|string[]} text
38
+ * @returns {Promise<any>}
39
+ */
40
+ async _call(images, text, options = {}) {
41
+
42
+ const image_inputs = images ? await this.image_processor(images, options) : {};
43
+ const text_inputs = text ? this.tokenizer(text, options) : {};
44
+
45
+ return {
46
+ ...text_inputs,
47
+ ...image_inputs,
48
+ }
49
+ }
50
+ post_process_grounded_object_detection(outputs, input_ids, {
51
+ box_threshold = 0.25,
52
+ text_threshold = 0.25,
53
+ target_sizes = null
54
+ } = {}) {
55
+ const { logits, pred_boxes } = outputs;
56
+ const batch_size = logits.dims[0];
57
+
58
+ if (target_sizes !== null && target_sizes.length !== batch_size) {
59
+ throw Error("Make sure that you pass in as many target sizes as the batch dimension of the logits")
60
+ }
61
+ const num_queries = logits.dims.at(1);
62
+
63
+ const probs = logits.sigmoid(); // (batch_size, num_queries, 256)
64
+ const scores = probs.max(-1).tolist(); // (batch_size, num_queries)
65
+
66
+ // Convert to [x0, y0, x1, y1] format
67
+ const boxes = pred_boxes.tolist() // (batch_size, num_queries, 4)
68
+ .map(batch => batch.map(box => center_to_corners_format(box)));
69
+
70
+ const results = [];
71
+ for (let i = 0; i < batch_size; ++i) {
72
+ const target_size = target_sizes !== null ? target_sizes[i] : null;
73
+
74
+ // Convert from relative [0, 1] to absolute [0, height] coordinates
75
+ if (target_size !== null) {
76
+ boxes[i] = boxes[i].map(box => box.map((x, j) => x * target_size[(j + 1) % 2]));
77
+ }
78
+
79
+ const batch_scores = scores[i];
80
+ const final_scores = [];
81
+ const final_phrases = [];
82
+ const final_boxes = [];
83
+ for (let j = 0; j < num_queries; ++j) {
84
+ const score = batch_scores[j];
85
+ if (score <= box_threshold) {
86
+ continue;
87
+ }
88
+ const box = boxes[i][j];
89
+ const prob = probs[i][j];
90
+
91
+ final_scores.push(score);
92
+ final_boxes.push(box);
93
+
94
+ const phrases = get_phrases_from_posmap(prob.gt(text_threshold), input_ids[i]);
95
+ final_phrases.push(phrases);
96
+ }
97
+ results.push({ scores: final_scores, boxes: final_boxes, labels: this.batch_decode(final_phrases) });
98
+ }
99
+ return results;
100
+ }
101
+ }
@@ -10,6 +10,7 @@ export * from './donut/image_processing_donut.js'
10
10
  export * from './dpt/image_processing_dpt.js'
11
11
  export * from './efficientnet/image_processing_efficientnet.js'
12
12
  export * from './glpn/image_processing_glpn.js'
13
+ export * from './grounding_dino/image_processing_grounding_dino.js'
13
14
  export * from './idefics3/image_processing_idefics3.js'
14
15
  export * from './janus/image_processing_janus.js'
15
16
  export * from './jina_clip/image_processing_jina_clip.js'
@@ -1,9 +1,10 @@
1
1
  export * from './florence2/processing_florence2.js';
2
- export * from './mgp_str/processing_mgp_str.js';
3
- export * from './moonshine/processing_moonshine.js';
2
+ export * from './grounding_dino/processing_grounding_dino.js';
4
3
  export * from './idefics3/processing_idefics3.js';
5
4
  export * from './janus/processing_janus.js';
6
5
  export * from './jina_clip/processing_jina_clip.js';
6
+ export * from './mgp_str/processing_mgp_str.js';
7
+ export * from './moonshine/processing_moonshine.js';
7
8
  export * from './owlvit/processing_owlvit.js';
8
9
  export * from './phi3_v/processing_phi3_v.js';
9
10
  export * from './paligemma/processing_paligemma.js';
package/src/models.js CHANGED
@@ -532,14 +532,23 @@ async function encoderForward(self, model_inputs) {
532
532
  encoderFeeds.inputs_embeds = await self.encode_text({ input_ids: model_inputs.input_ids });
533
533
  }
534
534
  if (session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
535
+ if (!encoderFeeds.input_ids) {
536
+ throw new Error('Both `input_ids` and `token_type_ids` are missing in the model inputs.');
537
+ }
535
538
  // Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it,
536
539
  // but they weren't created by the tokenizer.
537
- encoderFeeds.token_type_ids = new Tensor(
538
- 'int64',
539
- new BigInt64Array(encoderFeeds.input_ids.data.length),
540
- encoderFeeds.input_ids.dims
541
- )
540
+ encoderFeeds.token_type_ids = zeros_like(encoderFeeds.input_ids);
541
+ }
542
+ if (session.inputNames.includes('pixel_mask') && !encoderFeeds.pixel_mask) {
543
+ if (!encoderFeeds.pixel_values) {
544
+ throw new Error('Both `pixel_values` and `pixel_mask` are missing in the model inputs.');
545
+ }
546
+ // Assign default `pixel_mask` (all ones) to the `encoderFeeds` if the model expects it,
547
+ // but they weren't created by the processor.
548
+ const dims = encoderFeeds.pixel_values.dims;
549
+ encoderFeeds.pixel_mask = ones([dims[0], dims[2], dims[3]]);
542
550
  }
551
+
543
552
  return await sessionRun(session, encoderFeeds);
544
553
  }
545
554
 
@@ -5428,6 +5437,8 @@ export class Dinov2WithRegistersForImageClassification extends Dinov2WithRegiste
5428
5437
  }
5429
5438
  }
5430
5439
  //////////////////////////////////////////////////
5440
+ export class GroundingDinoPreTrainedModel extends PreTrainedModel { }
5441
+ export class GroundingDinoForObjectDetection extends GroundingDinoPreTrainedModel { }
5431
5442
 
5432
5443
  //////////////////////////////////////////////////
5433
5444
  export class YolosPreTrainedModel extends PreTrainedModel { }
@@ -6126,6 +6137,9 @@ export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
6126
6137
  }
6127
6138
  }
6128
6139
 
6140
+ export class StyleTextToSpeech2PreTrainedModel extends PreTrainedModel { }
6141
+ export class StyleTextToSpeech2Model extends StyleTextToSpeech2PreTrainedModel { }
6142
+
6129
6143
  //////////////////////////////////////////////////
6130
6144
  // SpeechT5 models
6131
6145
  /**
@@ -7089,6 +7103,8 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
7089
7103
 
7090
7104
  ['maskformer', ['MaskFormerModel', MaskFormerModel]],
7091
7105
  ['mgp-str', ['MgpstrForSceneTextRecognition', MgpstrForSceneTextRecognition]],
7106
+
7107
+ ['style_text_to_speech_2', ['StyleTextToSpeech2Model', StyleTextToSpeech2Model]],
7092
7108
  ]);
7093
7109
 
7094
7110
  const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
@@ -7333,6 +7349,7 @@ const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([
7333
7349
  const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([
7334
7350
  ['owlvit', ['OwlViTForObjectDetection', OwlViTForObjectDetection]],
7335
7351
  ['owlv2', ['Owlv2ForObjectDetection', Owlv2ForObjectDetection]],
7352
+ ['grounding-dino', ['GroundingDinoForObjectDetection', GroundingDinoForObjectDetection]],
7336
7353
  ]);
7337
7354
 
7338
7355
  const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
package/src/pipelines.js CHANGED
@@ -64,7 +64,8 @@ import {
64
64
  round,
65
65
  } from './utils/maths.js';
66
66
  import {
67
- read_audio
67
+ read_audio,
68
+ RawAudio
68
69
  } from './utils/audio.js';
69
70
  import {
70
71
  Tensor,
@@ -2552,13 +2553,35 @@ export class ZeroShotObjectDetectionPipeline extends (/** @type {new (options: T
2552
2553
  // Run model with both text and pixel inputs
2553
2554
  const output = await this.model({ ...text_inputs, pixel_values });
2554
2555
 
2555
- // @ts-ignore
2556
- const processed = this.processor.image_processor.post_process_object_detection(output, threshold, imageSize, true)[0];
2557
- let result = processed.boxes.map((box, i) => ({
2558
- score: processed.scores[i],
2559
- label: candidate_labels[processed.classes[i]],
2560
- box: get_bounding_box(box, !percentage),
2561
- })).sort((a, b) => b.score - a.score);
2556
+ let result;
2557
+ if('post_process_grounded_object_detection' in this.processor) {
2558
+ // @ts-ignore
2559
+ const processed = this.processor.post_process_grounded_object_detection(
2560
+ output,
2561
+ text_inputs.input_ids,
2562
+ {
2563
+ // TODO: support separate threshold values
2564
+ box_threshold: threshold,
2565
+ text_threshold: threshold,
2566
+ target_sizes: imageSize,
2567
+ },
2568
+ )[0];
2569
+ result = processed.boxes.map((box, i) => ({
2570
+ score: processed.scores[i],
2571
+ label: processed.labels[i],
2572
+ box: get_bounding_box(box, !percentage),
2573
+ }))
2574
+ } else {
2575
+ // @ts-ignore
2576
+ const processed = this.processor.image_processor.post_process_object_detection(output, threshold, imageSize, true)[0];
2577
+ result = processed.boxes.map((box, i) => ({
2578
+ score: processed.scores[i],
2579
+ label: candidate_labels[processed.classes[i]],
2580
+ box: get_bounding_box(box, !percentage),
2581
+ }))
2582
+ }
2583
+ result.sort((a, b) => b.score - a.score);
2584
+
2562
2585
  if (top_k !== null) {
2563
2586
  result = result.slice(0, top_k);
2564
2587
  }
@@ -2678,7 +2701,7 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options:
2678
2701
  * const synthesizer = await pipeline('text-to-speech', 'Xenova/speecht5_tts', { quantized: false });
2679
2702
  * const speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin';
2680
2703
  * const out = await synthesizer('Hello, my dog is cute', { speaker_embeddings });
2681
- * // {
2704
+ * // RawAudio {
2682
2705
  * // audio: Float32Array(26112) [-0.00005657337896991521, 0.00020583874720614403, ...],
2683
2706
  * // sampling_rate: 16000
2684
2707
  * // }
@@ -2698,7 +2721,7 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options:
2698
2721
  * ```javascript
2699
2722
  * const synthesizer = await pipeline('text-to-speech', 'Xenova/mms-tts-fra');
2700
2723
  * const out = await synthesizer('Bonjour');
2701
- * // {
2724
+ * // RawAudio {
2702
2725
  * // audio: Float32Array(23808) [-0.00037693005288019776, 0.0003325853613205254, ...],
2703
2726
  * // sampling_rate: 16000
2704
2727
  * // }
@@ -2745,10 +2768,10 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi
2745
2768
 
2746
2769
  // @ts-expect-error TS2339
2747
2770
  const sampling_rate = this.model.config.sampling_rate;
2748
- return {
2749
- audio: waveform.data,
2771
+ return new RawAudio(
2772
+ waveform.data,
2750
2773
  sampling_rate,
2751
- }
2774
+ )
2752
2775
  }
2753
2776
 
2754
2777
  async _call_text_to_spectrogram(text_inputs, { speaker_embeddings }) {
@@ -2788,10 +2811,10 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi
2788
2811
  const { waveform } = await this.model.generate_speech(input_ids, speaker_embeddings, { vocoder: this.vocoder });
2789
2812
 
2790
2813
  const sampling_rate = this.processor.feature_extractor.config.sampling_rate;
2791
- return {
2792
- audio: waveform.data,
2814
+ return new RawAudio(
2815
+ waveform.data,
2793
2816
  sampling_rate,
2794
- }
2817
+ )
2795
2818
  }
2796
2819
  }
2797
2820
 
@@ -12,8 +12,10 @@ import {
12
12
  } from './hub.js';
13
13
  import { FFT, max } from './maths.js';
14
14
  import {
15
- calculateReflectOffset,
15
+ calculateReflectOffset, saveBlob,
16
16
  } from './core.js';
17
+ import { apis } from '../env.js';
18
+ import fs from 'fs';
17
19
  import { Tensor, matmul } from './tensor.js';
18
20
 
19
21
 
@@ -702,3 +704,113 @@ export function window_function(window_length, name, {
702
704
 
703
705
  return window;
704
706
  }
707
+
708
+ /**
709
+ * Encode audio data to a WAV file.
710
+ * WAV file specs : https://en.wikipedia.org/wiki/WAV#WAV_File_header
711
+ *
712
+ * Adapted from https://www.npmjs.com/package/audiobuffer-to-wav
713
+ * @param {Float32Array} samples The audio samples.
714
+ * @param {number} rate The sample rate.
715
+ * @returns {ArrayBuffer} The WAV audio buffer.
716
+ */
717
+ function encodeWAV(samples, rate) {
718
+ let offset = 44;
719
+ const buffer = new ArrayBuffer(offset + samples.length * 4);
720
+ const view = new DataView(buffer);
721
+
722
+ /* RIFF identifier */
723
+ writeString(view, 0, "RIFF");
724
+ /* RIFF chunk length */
725
+ view.setUint32(4, 36 + samples.length * 4, true);
726
+ /* RIFF type */
727
+ writeString(view, 8, "WAVE");
728
+ /* format chunk identifier */
729
+ writeString(view, 12, "fmt ");
730
+ /* format chunk length */
731
+ view.setUint32(16, 16, true);
732
+ /* sample format (raw) */
733
+ view.setUint16(20, 3, true);
734
+ /* channel count */
735
+ view.setUint16(22, 1, true);
736
+ /* sample rate */
737
+ view.setUint32(24, rate, true);
738
+ /* byte rate (sample rate * block align) */
739
+ view.setUint32(28, rate * 4, true);
740
+ /* block align (channel count * bytes per sample) */
741
+ view.setUint16(32, 4, true);
742
+ /* bits per sample */
743
+ view.setUint16(34, 32, true);
744
+ /* data chunk identifier */
745
+ writeString(view, 36, "data");
746
+ /* data chunk length */
747
+ view.setUint32(40, samples.length * 4, true);
748
+
749
+ for (let i = 0; i < samples.length; ++i, offset += 4) {
750
+ view.setFloat32(offset, samples[i], true);
751
+ }
752
+
753
+ return buffer;
754
+ }
755
+
756
+ function writeString(view, offset, string) {
757
+ for (let i = 0; i < string.length; ++i) {
758
+ view.setUint8(offset + i, string.charCodeAt(i));
759
+ }
760
+ }
761
+
762
+
763
+ export class RawAudio {
764
+
765
+ /**
766
+ * Create a new `RawAudio` object.
767
+ * @param {Float32Array} audio Audio data
768
+ * @param {number} sampling_rate Sampling rate of the audio data
769
+ */
770
+ constructor(audio, sampling_rate) {
771
+ this.audio = audio
772
+ this.sampling_rate = sampling_rate
773
+ }
774
+
775
+ /**
776
+ * Convert the audio to a wav file buffer.
777
+ * @returns {ArrayBuffer} The WAV file.
778
+ */
779
+ toWav() {
780
+ return encodeWAV(this.audio, this.sampling_rate)
781
+ }
782
+
783
+ /**
784
+ * Convert the audio to a blob.
785
+ * @returns {Blob}
786
+ */
787
+ toBlob() {
788
+ const wav = this.toWav();
789
+ const blob = new Blob([wav], { type: 'audio/wav' });
790
+ return blob;
791
+ }
792
+
793
+ /**
794
+ * Save the audio to a wav file.
795
+ * @param {string} path
796
+ */
797
+ async save(path) {
798
+ let fn;
799
+
800
+ if (apis.IS_BROWSER_ENV) {
801
+ if (apis.IS_WEBWORKER_ENV) {
802
+ throw new Error('Unable to save a file from a Web Worker.')
803
+ }
804
+ fn = saveBlob;
805
+ } else if (apis.IS_FS_AVAILABLE) {
806
+ fn = async (/** @type {string} */ path, /** @type {Blob} */ blob) => {
807
+ let buffer = await blob.arrayBuffer();
808
+ fs.writeFileSync(path, Buffer.from(buffer));
809
+ }
810
+ } else {
811
+ throw new Error('Unable to save because filesystem is disabled in this environment.')
812
+ }
813
+
814
+ await fn(path, this.toBlob())
815
+ }
816
+ }
package/src/utils/core.js CHANGED
@@ -189,6 +189,32 @@ export function calculateReflectOffset(i, w) {
189
189
  return Math.abs((i + w) % (2 * w) - w);
190
190
  }
191
191
 
192
+ /**
193
+ * Save blob file on the web.
194
+ * @param {string} path The path to save the blob to
195
+ * @param {Blob} blob The blob to save
196
+ */
197
+ export function saveBlob(path, blob){
198
+ // Convert the canvas content to a data URL
199
+ const dataURL = URL.createObjectURL(blob);
200
+
201
+ // Create an anchor element with the data URL as the href attribute
202
+ const downloadLink = document.createElement('a');
203
+ downloadLink.href = dataURL;
204
+
205
+ // Set the download attribute to specify the desired filename for the downloaded image
206
+ downloadLink.download = path;
207
+
208
+ // Trigger the download
209
+ downloadLink.click();
210
+
211
+ // Clean up: remove the anchor element from the DOM
212
+ downloadLink.remove();
213
+
214
+ // Revoke the Object URL to free up memory
215
+ URL.revokeObjectURL(dataURL);
216
+ }
217
+
192
218
  /**
193
219
  *
194
220
  * @param {Object} o
@@ -8,9 +8,9 @@
8
8
  * @module utils/image
9
9
  */
10
10
 
11
- import { isNullishDimension } from './core.js';
11
+ import { isNullishDimension, saveBlob } from './core.js';
12
12
  import { getFile } from './hub.js';
13
- import { env, apis } from '../env.js';
13
+ import { apis } from '../env.js';
14
14
  import { Tensor } from './tensor.js';
15
15
 
16
16
  // Will be empty (or not used) if running in browser or web-worker
@@ -793,23 +793,9 @@ export class RawImage {
793
793
  // Convert image to Blob
794
794
  const blob = await this.toBlob(mime);
795
795
 
796
- // Convert the canvas content to a data URL
797
- const dataURL = URL.createObjectURL(blob);
796
+ saveBlob(path, blob)
798
797
 
799
- // Create an anchor element with the data URL as the href attribute
800
- const downloadLink = document.createElement('a');
801
- downloadLink.href = dataURL;
802
-
803
- // Set the download attribute to specify the desired filename for the downloaded image
804
- downloadLink.download = path;
805
-
806
- // Trigger the download
807
- downloadLink.click();
808
-
809
- // Clean up: remove the anchor element from the DOM
810
- downloadLink.remove();
811
-
812
- } else if (!env.useFS) {
798
+ } else if (!apis.IS_FS_AVAILABLE) {
813
799
  throw new Error('Unable to save the image because filesystem is disabled in this environment.')
814
800
 
815
801
  } else {
@@ -837,3 +823,4 @@ export class RawImage {
837
823
  * Helper function to load an image from a URL, path, etc.
838
824
  */
839
825
  export const load_image = RawImage.read.bind(RawImage);
826
+