@drakulavich/parakeet-cli 0.1.4 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@drakulavich/parakeet-cli",
3
- "version": "0.1.4",
3
+ "version": "0.2.0",
4
4
  "description": "Fast multilingual speech-to-text CLI powered by NVIDIA Parakeet ONNX models",
5
5
  "type": "module",
6
6
  "bin": {
package/src/__tests__/decoder.test.ts CHANGED
@@ -1,5 +1,5 @@
1
1
  import { describe, test, expect } from "bun:test";
2
- import { greedyDecode, type DecoderSession } from "../decoder";
2
+ import { beamDecode, type DecoderSession } from "../decoder";
3
3
 
4
4
  function mockSession(responses: Array<{ tokenLogits: number[]; durationLogits: number[] }>): DecoderSession {
5
5
  let callIndex = 0;
@@ -24,7 +24,8 @@ describe("decoder", () => {
24
24
  { tokenLogits: [0, 10, 0, -10], durationLogits: [10, 0] },
25
25
  { tokenLogits: [0, 0, 0, 10], durationLogits: [10, 0] },
26
26
  ]);
27
- const tokens = await greedyDecode(session, 3);
27
+ const encoderData = new Float32Array(3);
28
+ const tokens = await beamDecode(session, 3, encoderData, 1, 1);
28
29
  expect(tokens).toEqual([0, 1]);
29
30
  });
30
31
 
@@ -34,24 +35,16 @@ describe("decoder", () => {
34
35
  { tokenLogits: [0, 10, 0, -10], durationLogits: [10, 0, 0] },
35
36
  { tokenLogits: [0, 0, 0, 10], durationLogits: [10, 0, 0] },
36
37
  ]);
37
- const tokens = await greedyDecode(session, 5);
38
+ const encoderData = new Float32Array(5);
39
+ const tokens = await beamDecode(session, 5, encoderData, 1, 1);
38
40
  expect(tokens).toEqual([0, 1]);
39
41
  });
40
42
 
41
- test("handles max_tokens_per_step limit", async () => {
42
- const session = mockSession([
43
- { tokenLogits: [10, 0, 0, -10], durationLogits: [10, 0] },
44
- ]);
45
- const tokens = await greedyDecode(session, 2);
46
- expect(tokens.length).toBeLessThanOrEqual(20);
47
- expect(tokens.length).toBeGreaterThan(0);
48
- });
49
-
50
43
  test("returns empty for zero-length encoder output", async () => {
51
44
  const session = mockSession([
52
45
  { tokenLogits: [0, 0, 0, 10], durationLogits: [10, 0] },
53
46
  ]);
54
- const tokens = await greedyDecode(session, 0);
47
+ const tokens = await beamDecode(session, 0, new Float32Array(0), 1);
55
48
  expect(tokens).toEqual([]);
56
49
  });
57
50
  });
package/src/decoder.ts CHANGED
@@ -20,69 +20,87 @@ export interface DecoderSession {
20
20
  stateDims: { layers: number; hidden: number };
21
21
  }
22
22
 
23
- export async function greedyDecode(
23
+ const DEFAULT_BEAM_WIDTH = 4;
24
+
25
+ interface Beam {
26
+ tokens: number[];
27
+ score: number;
28
+ lastToken: number;
29
+ state1: F32;
30
+ state2: F32;
31
+ t: number;
32
+ }
33
+
34
+ export async function beamDecode(
24
35
  session: DecoderSession,
25
36
  encoderLength: number,
26
- encoderData?: Float32Array,
27
- encoderDim?: number
37
+ encoderData: Float32Array,
38
+ encoderDim: number,
39
+ beamWidth: number = DEFAULT_BEAM_WIDTH,
28
40
  ): Promise<number[]> {
29
41
  if (encoderLength === 0) return [];
30
42
 
31
- const tokens: number[] = [];
32
43
  const stateSize = session.stateDims.layers * session.stateDims.hidden;
33
- let state1: F32 = new Float32Array(stateSize);
34
- let state2: F32 = new Float32Array(stateSize);
35
- let lastToken = session.blankId;
36
-
37
- let t = 0;
38
- while (t < encoderLength) {
39
- let tokensThisStep = 0;
40
-
41
- while (tokensThisStep < MAX_TOKENS_PER_STEP) {
42
- let frame: Float32Array;
43
- if (encoderData && encoderDim) {
44
- // Must copy — ort.Tensor doesn't work with subarray views under Bun
45
- frame = encoderData.slice(t * encoderDim, (t + 1) * encoderDim);
46
- } else {
47
- frame = new Float32Array(1);
48
- }
49
44
 
50
- const result = await session.decode(frame, [lastToken], 1, state1, state2);
45
+ let beams: Beam[] = [{
46
+ tokens: [],
47
+ score: 0,
48
+ lastToken: session.blankId,
49
+ state1: new Float32Array(stateSize),
50
+ state2: new Float32Array(stateSize),
51
+ t: 0,
52
+ }];
53
+
54
+ const maxSteps = encoderLength * MAX_TOKENS_PER_STEP;
55
+
56
+ for (let step = 0; step < maxSteps; step++) {
57
+ const active = beams.filter(b => b.t < encoderLength);
58
+ if (active.length === 0) break;
59
+
60
+ const candidates: Beam[] = [];
61
+
62
+ for (const beam of active) {
63
+ // Must copy — ort.Tensor doesn't work with subarray views under Bun
64
+ const frame = encoderData.slice(beam.t * encoderDim, (beam.t + 1) * encoderDim);
65
+ const result = await session.decode(frame, [beam.lastToken], 1, beam.state1, beam.state2);
51
66
  const output = result.output;
52
67
 
53
68
  const tokenLogits = output.slice(0, session.vocabSize);
54
69
  const durationLogits = output.slice(session.vocabSize);
55
-
56
- const tokenId = argmax(tokenLogits);
57
70
  const duration = argmax(durationLogits);
58
71
 
59
- state1 = result.state1;
60
- state2 = result.state2;
61
-
62
- if (tokenId === session.blankId) {
63
- t += 1;
64
- break;
65
- }
66
-
67
- tokens.push(tokenId);
68
- lastToken = tokenId;
69
- tokensThisStep++;
72
+ // Blank option: advance one frame, keep same tokens
73
+ candidates.push({
74
+ tokens: beam.tokens,
75
+ score: beam.score + tokenLogits[session.blankId],
76
+ lastToken: beam.lastToken,
77
+ state1: result.state1,
78
+ state2: result.state2,
79
+ t: beam.t + 1,
80
+ });
70
81
 
71
- if (duration > 0) {
72
- t += duration;
73
- break;
82
+ // Top non-blank token options
83
+ const topK = topKIndices(tokenLogits, beamWidth, session.blankId);
84
+ for (const tokenId of topK) {
85
+ candidates.push({
86
+ tokens: [...beam.tokens, tokenId],
87
+ score: beam.score + tokenLogits[tokenId],
88
+ lastToken: tokenId,
89
+ state1: result.state1,
90
+ state2: result.state2,
91
+ t: duration > 0 ? beam.t + duration : beam.t,
92
+ });
74
93
  }
75
94
  }
76
95
 
77
- if (tokensThisStep >= MAX_TOKENS_PER_STEP) {
78
- t += 1;
79
- }
96
+ candidates.sort((a, b) => b.score - a.score);
97
+ beams = candidates.slice(0, beamWidth);
80
98
  }
81
99
 
82
- return tokens;
100
+ return beams[0].tokens;
83
101
  }
84
102
 
85
- function argmax(arr: Float32Array): number {
103
+ function argmax(arr: F32): number {
86
104
  let maxIdx = 0;
87
105
  let maxVal = arr[0];
88
106
  for (let i = 1; i < arr.length; i++) {
@@ -94,6 +112,15 @@ function argmax(arr: Float32Array): number {
94
112
  return maxIdx;
95
113
  }
96
114
 
115
+ function topKIndices(arr: F32, k: number, excludeId: number): number[] {
116
+ const indexed: [number, number][] = [];
117
+ for (let i = 0; i < arr.length; i++) {
118
+ if (i !== excludeId) indexed.push([arr[i], i]);
119
+ }
120
+ indexed.sort((a, b) => b[0] - a[0]);
121
+ return indexed.slice(0, k).map(([, i]) => i);
122
+ }
123
+
97
124
  let onnxSession: ort.InferenceSession | null = null;
98
125
 
99
126
  export async function initDecoder(modelDir: string): Promise<void> {
package/src/transcribe.ts CHANGED
@@ -5,7 +5,7 @@ import { initEncoder, encode } from "./encoder";
5
5
  import {
6
6
  initDecoder,
7
7
  createOnnxDecoderSession,
8
- greedyDecode,
8
+ beamDecode,
9
9
  } from "./decoder";
10
10
  import { Tokenizer } from "./tokenizer";
11
11
  import { join } from "path";
@@ -63,6 +63,6 @@ export async function transcribe(audioPath: string, opts: TranscribeOptions = {}
63
63
  DECODER_HIDDEN,
64
64
  );
65
65
 
66
- const tokens = await greedyDecode(session, encodedLength, transposed, D);
66
+ const tokens = await beamDecode(session, encodedLength, transposed, D);
67
67
  return tokenizer.detokenize(tokens);
68
68
  }