edgeflowjs 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. package/README.md +200 -66
  2. package/dist/backends/index.d.ts +9 -2
  3. package/dist/backends/index.d.ts.map +1 -1
  4. package/dist/backends/index.js +13 -13
  5. package/dist/backends/index.js.map +1 -1
  6. package/dist/backends/onnx.d.ts +11 -4
  7. package/dist/backends/onnx.d.ts.map +1 -1
  8. package/dist/backends/onnx.js +97 -78
  9. package/dist/backends/onnx.js.map +1 -1
  10. package/dist/backends/transformers-adapter.d.ts +99 -0
  11. package/dist/backends/transformers-adapter.d.ts.map +1 -0
  12. package/dist/backends/transformers-adapter.js +171 -0
  13. package/dist/backends/transformers-adapter.js.map +1 -0
  14. package/dist/backends/webgpu.d.ts +7 -5
  15. package/dist/backends/webgpu.d.ts.map +1 -1
  16. package/dist/backends/webgpu.js +7 -5
  17. package/dist/backends/webgpu.js.map +1 -1
  18. package/dist/backends/webnn.d.ts +6 -5
  19. package/dist/backends/webnn.d.ts.map +1 -1
  20. package/dist/backends/webnn.js +6 -5
  21. package/dist/backends/webnn.js.map +1 -1
  22. package/dist/core/composer.d.ts +118 -0
  23. package/dist/core/composer.d.ts.map +1 -0
  24. package/dist/core/composer.js +163 -0
  25. package/dist/core/composer.js.map +1 -0
  26. package/dist/core/device-profiler.d.ts +75 -0
  27. package/dist/core/device-profiler.d.ts.map +1 -0
  28. package/dist/core/device-profiler.js +131 -0
  29. package/dist/core/device-profiler.js.map +1 -0
  30. package/dist/core/index.d.ts +4 -0
  31. package/dist/core/index.d.ts.map +1 -1
  32. package/dist/core/index.js +8 -0
  33. package/dist/core/index.js.map +1 -1
  34. package/dist/core/memory.d.ts +22 -2
  35. package/dist/core/memory.d.ts.map +1 -1
  36. package/dist/core/memory.js +49 -13
  37. package/dist/core/memory.js.map +1 -1
  38. package/dist/core/plugin.d.ts +100 -0
  39. package/dist/core/plugin.d.ts.map +1 -0
  40. package/dist/core/plugin.js +106 -0
  41. package/dist/core/plugin.js.map +1 -0
  42. package/dist/core/runtime.d.ts +4 -0
  43. package/dist/core/runtime.d.ts.map +1 -1
  44. package/dist/core/runtime.js +18 -0
  45. package/dist/core/runtime.js.map +1 -1
  46. package/dist/core/scheduler.d.ts +17 -0
  47. package/dist/core/scheduler.d.ts.map +1 -1
  48. package/dist/core/scheduler.js +101 -3
  49. package/dist/core/scheduler.js.map +1 -1
  50. package/dist/core/types.d.ts +14 -0
  51. package/dist/core/types.d.ts.map +1 -1
  52. package/dist/core/types.js.map +1 -1
  53. package/dist/core/worker.d.ts +202 -0
  54. package/dist/core/worker.d.ts.map +1 -0
  55. package/dist/core/worker.js +477 -0
  56. package/dist/core/worker.js.map +1 -0
  57. package/dist/edgeflow.browser.js +9770 -4383
  58. package/dist/edgeflow.browser.js.map +4 -4
  59. package/dist/edgeflow.browser.min.js +435 -5
  60. package/dist/edgeflow.browser.min.js.map +4 -4
  61. package/dist/index.d.ts +7 -4
  62. package/dist/index.d.ts.map +1 -1
  63. package/dist/index.js +28 -10
  64. package/dist/index.js.map +1 -1
  65. package/dist/pipelines/automatic-speech-recognition.d.ts +63 -0
  66. package/dist/pipelines/automatic-speech-recognition.d.ts.map +1 -0
  67. package/dist/pipelines/automatic-speech-recognition.js +269 -0
  68. package/dist/pipelines/automatic-speech-recognition.js.map +1 -0
  69. package/dist/pipelines/base.d.ts +6 -1
  70. package/dist/pipelines/base.d.ts.map +1 -1
  71. package/dist/pipelines/base.js +12 -2
  72. package/dist/pipelines/base.js.map +1 -1
  73. package/dist/pipelines/feature-extraction.d.ts +5 -40
  74. package/dist/pipelines/feature-extraction.d.ts.map +1 -1
  75. package/dist/pipelines/feature-extraction.js +44 -63
  76. package/dist/pipelines/feature-extraction.js.map +1 -1
  77. package/dist/pipelines/image-classification.d.ts +4 -36
  78. package/dist/pipelines/image-classification.d.ts.map +1 -1
  79. package/dist/pipelines/image-classification.js +22 -60
  80. package/dist/pipelines/image-classification.js.map +1 -1
  81. package/dist/pipelines/image-segmentation.d.ts +221 -0
  82. package/dist/pipelines/image-segmentation.d.ts.map +1 -0
  83. package/dist/pipelines/image-segmentation.js +535 -0
  84. package/dist/pipelines/image-segmentation.js.map +1 -0
  85. package/dist/pipelines/index.d.ts +18 -0
  86. package/dist/pipelines/index.d.ts.map +1 -1
  87. package/dist/pipelines/index.js +51 -2
  88. package/dist/pipelines/index.js.map +1 -1
  89. package/dist/pipelines/object-detection.d.ts +44 -0
  90. package/dist/pipelines/object-detection.d.ts.map +1 -0
  91. package/dist/pipelines/object-detection.js +218 -0
  92. package/dist/pipelines/object-detection.js.map +1 -0
  93. package/dist/pipelines/question-answering.d.ts +41 -0
  94. package/dist/pipelines/question-answering.d.ts.map +1 -0
  95. package/dist/pipelines/question-answering.js +164 -0
  96. package/dist/pipelines/question-answering.js.map +1 -0
  97. package/dist/pipelines/text-classification.d.ts +3 -39
  98. package/dist/pipelines/text-classification.d.ts.map +1 -1
  99. package/dist/pipelines/text-classification.js +29 -67
  100. package/dist/pipelines/text-classification.js.map +1 -1
  101. package/dist/pipelines/text-generation.d.ts +281 -0
  102. package/dist/pipelines/text-generation.d.ts.map +1 -0
  103. package/dist/pipelines/text-generation.js +766 -0
  104. package/dist/pipelines/text-generation.js.map +1 -0
  105. package/dist/pipelines/zero-shot-classification.d.ts +45 -0
  106. package/dist/pipelines/zero-shot-classification.d.ts.map +1 -0
  107. package/dist/pipelines/zero-shot-classification.js +140 -0
  108. package/dist/pipelines/zero-shot-classification.js.map +1 -0
  109. package/dist/tools/benchmark.d.ts +92 -0
  110. package/dist/tools/benchmark.d.ts.map +1 -0
  111. package/dist/tools/benchmark.js +213 -0
  112. package/dist/tools/benchmark.js.map +1 -0
  113. package/dist/tools/debugger.d.ts +258 -0
  114. package/dist/tools/debugger.d.ts.map +1 -0
  115. package/dist/tools/debugger.js +624 -0
  116. package/dist/tools/debugger.js.map +1 -0
  117. package/dist/tools/index.d.ts +8 -0
  118. package/dist/tools/index.d.ts.map +1 -1
  119. package/dist/tools/index.js +16 -0
  120. package/dist/tools/index.js.map +1 -1
  121. package/dist/tools/monitor.d.ts +284 -0
  122. package/dist/tools/monitor.d.ts.map +1 -0
  123. package/dist/tools/monitor.js +921 -0
  124. package/dist/tools/monitor.js.map +1 -0
  125. package/dist/tools/quantization.d.ts +235 -0
  126. package/dist/tools/quantization.d.ts.map +1 -0
  127. package/dist/tools/quantization.js +830 -0
  128. package/dist/tools/quantization.js.map +1 -0
  129. package/dist/utils/hub.d.ts +162 -0
  130. package/dist/utils/hub.d.ts.map +1 -0
  131. package/dist/utils/hub.js +311 -0
  132. package/dist/utils/hub.js.map +1 -0
  133. package/dist/utils/index.d.ts +3 -1
  134. package/dist/utils/index.d.ts.map +1 -1
  135. package/dist/utils/index.js +5 -1
  136. package/dist/utils/index.js.map +1 -1
  137. package/dist/utils/model-loader.d.ts.map +1 -1
  138. package/dist/utils/model-loader.js +106 -30
  139. package/dist/utils/model-loader.js.map +1 -1
  140. package/dist/utils/offline.d.ts +147 -0
  141. package/dist/utils/offline.d.ts.map +1 -0
  142. package/dist/utils/offline.js +405 -0
  143. package/dist/utils/offline.js.map +1 -0
  144. package/dist/utils/preprocessor.d.ts +82 -6
  145. package/dist/utils/preprocessor.d.ts.map +1 -1
  146. package/dist/utils/preprocessor.js +278 -21
  147. package/dist/utils/preprocessor.js.map +1 -1
  148. package/dist/utils/tokenizer.d.ts +197 -72
  149. package/dist/utils/tokenizer.d.ts.map +1 -1
  150. package/dist/utils/tokenizer.js +558 -274
  151. package/dist/utils/tokenizer.js.map +1 -1
  152. package/package.json +26 -11
@@ -0,0 +1,766 @@
1
+ /**
2
+ * edgeFlow.js - Text Generation Pipeline
3
+ *
4
+ * Autoregressive text generation with streaming support.
5
+ * Supports GPT-2, LLaMA, Mistral, and other causal LM models.
6
+ * Includes chat/conversation support with message history.
7
+ */
8
+ import { BasePipeline } from './base.js';
9
+ import { Tokenizer } from '../utils/tokenizer.js';
10
+ import { EdgeFlowTensor, softmax } from '../core/tensor.js';
11
+ import { runInferenceNamed, loadModelFromBuffer } from '../core/runtime.js';
12
+ // ============================================================================
13
+ // Default Model URLs (TinyLlama - quantized for browser)
14
+ // ============================================================================
15
// Both default artifacts are served from the same Hugging Face repository
// revision, so the shared prefix is spelled out once.
const TINYLLAMA_REPO_ROOT = 'https://huggingface.co/Xenova/TinyLlama-1.1B-Chat-v1.0/resolve/main';

/**
 * URLs used when the caller never invokes `setModelUrls()`.
 * `model` points at the ONNX weights, `tokenizer` at the tokenizer JSON.
 */
const DEFAULT_LLM_MODELS = {
  model: `${TINYLLAMA_REPO_ROOT}/onnx/model_q4f16.onnx`,
  tokenizer: `${TINYLLAMA_REPO_ROOT}/tokenizer.json`,
};
19
+ // ============================================================================
20
+ // Text Generation Pipeline
21
+ // ============================================================================
22
/**
 * TextGenerationPipeline - autoregressive text generation with optional
 * streaming and chat-history support.
 *
 * Every decoding step re-runs the model on the full token sequence with an
 * empty `past_key_values` cache (no KV-cache reuse between steps).
 *
 * @example
 * ```typescript
 * const generator = await pipeline('text-generation', 'Xenova/gpt2');
 *
 * // Simple generation
 * const result = await generator.run('Once upon a time');
 * console.log(result.generatedText);
 *
 * // Streaming generation
 * for await (const event of generator.stream('Hello, ')) {
 *   process.stdout.write(event.token);
 * }
 * ```
 */
export class TextGenerationPipeline extends BasePipeline {
  /** Tokenizer used to encode prompts and decode sampled ids. */
  tokenizer = null;
  /** Token id that terminates generation. */
  eosTokenId = 50256; // GPT-2 default
  /** Model handle returned by `loadModelFromBuffer`. */
  llmModel = null;
  /** True once `loadModel()` has completed successfully. */
  modelsLoaded = false;
  // Custom model URLs
  modelUrl;
  tokenizerUrl;
  /** Messages accumulated by `chat()` / `chatStream()`. */
  conversationHistory = [];
  /** Template used by `applyChatTemplate()` when none is passed explicitly. */
  chatTemplateType = 'chatml';
  /**
   * Shape of the empty `past_key_values.*` inputs sent with every forward
   * pass. Defaults match TinyLlama (22 layers, 4 KV heads, head_dim 64).
   */
  kvCacheConfig = { numLayers: 22, numKVHeads: 4, headDim: 64 };
  /** In-flight `loadModel()` promise, shared by concurrent callers. */
  loadPromise = null;

  /**
   * @param {object} [config] Pipeline configuration forwarded to BasePipeline;
   *   defaults to `{ task: 'text-generation', model: 'default' }`.
   */
  constructor(config) {
    super(config ?? {
      task: 'text-generation',
      model: 'default',
    });
    this.modelUrl = DEFAULT_LLM_MODELS.model;
    this.tokenizerUrl = DEFAULT_LLM_MODELS.tokenizer;
  }

  /**
   * Check if model is loaded
   * @returns {boolean}
   */
  get isModelLoaded() {
    return this.modelsLoaded;
  }

  /**
   * Set custom model URLs. Takes effect on the next `loadModel()` that
   * actually loads; it does not invalidate an already loaded model.
   * @param {string} model URL of the ONNX model.
   * @param {string} tokenizer URL of the tokenizer JSON.
   */
  setModelUrls(model, tokenizer) {
    this.modelUrl = model;
    this.tokenizerUrl = tokenizer;
  }

  /**
   * Override the dimensions of the empty KV-cache inputs sent to the model.
   * Omitted keys keep their current value; `numLayers: 0` sends no
   * `past_key_values.*` inputs at all.
   * @param {{numLayers?: number, numKVHeads?: number, headDim?: number}} [config]
   */
  setKVCacheConfig({ numLayers, numKVHeads, headDim } = {}) {
    const current = this.kvCacheConfig;
    this.kvCacheConfig = {
      numLayers: numLayers ?? current.numLayers,
      numKVHeads: numKVHeads ?? current.numKVHeads,
      headDim: headDim ?? current.headDim,
    };
  }

  /**
   * Load model and tokenizer with progress callback.
   *
   * Concurrent calls share a single in-flight load so the (large) model is
   * downloaded only once; only the first caller's `onProgress` is invoked.
   * @param {(event: {stage: string, loaded: number, total: number, progress: number}) => void} [onProgress]
   * @returns {Promise<void>}
   * @throws {Error} When the tokenizer or the model cannot be fetched/loaded.
   */
  async loadModel(onProgress) {
    if (this.modelsLoaded) {
      return;
    }
    if (!this.loadPromise) {
      this.loadPromise = this.performLoad(onProgress).finally(() => {
        // Allow a retry after a failure.
        this.loadPromise = null;
      });
    }
    await this.loadPromise;
  }

  /**
   * Fetch the tokenizer, then the model, reporting progress for each stage.
   * @private
   */
  async performLoad(onProgress) {
    // Load tokenizer first (small, fast)
    onProgress?.({ stage: 'tokenizer', loaded: 0, total: 100, progress: 0 });
    try {
      const tokenizerResponse = await fetch(this.tokenizerUrl);
      if (!tokenizerResponse.ok) {
        throw new Error(`Failed to fetch tokenizer: ${tokenizerResponse.status}`);
      }
      const tokenizerJson = await tokenizerResponse.json();
      this.tokenizer = await Tokenizer.fromJSON(tokenizerJson);
      const specialIds = this.tokenizer.getSpecialTokenIds();
      this.eosTokenId = specialIds.eosTokenId ?? specialIds.sepTokenId ?? 2; // TinyLlama uses 2 as EOS
      onProgress?.({ stage: 'tokenizer', loaded: 100, total: 100, progress: 100 });
    } catch (error) {
      // Keep the original error reachable for debugging.
      throw new Error(`Failed to load tokenizer: ${error}`, { cause: error });
    }
    // Load model with progress tracking
    onProgress?.({ stage: 'model', loaded: 0, total: 100, progress: 0 });
    const modelData = await this.fetchModelWithProgress(this.modelUrl, (loaded, total) => {
      // content-length may describe the compressed size, so clamp to 100.
      const progress = total > 0 ? Math.min(100, Math.round((loaded / total) * 100)) : 0;
      onProgress?.({ stage: 'model', loaded, total, progress });
    });
    this.llmModel = await loadModelFromBuffer(modelData, {
      runtime: 'wasm', // Uses ONNXRuntime which auto-detects WebGPU internally
    });
    this.model = this.llmModel;
    this.modelsLoaded = true;
  }

  /**
   * Fetch model with progress tracking.
   * @param {string} url
   * @param {(loaded: number, total: number) => void} onProgress Called after
   *   every received chunk; `total` falls back to `loaded` when the server
   *   sends no content-length.
   * @returns {Promise<ArrayBuffer>}
   * @throws {Error} On a non-2xx response.
   */
  async fetchModelWithProgress(url, onProgress) {
    const response = await fetch(url);
    if (!response.ok) {
      throw new Error(`Failed to fetch model: ${response.status} ${response.statusText}`);
    }
    const contentLength = response.headers.get('content-length');
    const total = contentLength ? Number.parseInt(contentLength, 10) : 0;
    if (!response.body) {
      // Fallback if no streaming support
      const whole = await response.arrayBuffer();
      onProgress(whole.byteLength, whole.byteLength);
      return whole;
    }
    const reader = response.body.getReader();
    const chunks = [];
    let loaded = 0;
    for (;;) {
      const { done, value } = await reader.read();
      if (done) {
        break;
      }
      chunks.push(value);
      loaded += value.length;
      onProgress(loaded, total || loaded);
    }
    // Combine chunks into ArrayBuffer
    const combined = new Uint8Array(loaded);
    let offset = 0;
    for (const chunk of chunks) {
      combined.set(chunk, offset);
      offset += chunk.length;
    }
    return combined.buffer;
  }

  /**
   * Initialize pipeline (override to skip default model loading)
   */
  async initialize() {
    if (this.isReady) {
      return;
    }
    // Don't call super.initialize() - we handle model loading separately
    this.isReady = true;
  }

  /**
   * Set tokenizer and derive the EOS id from its special tokens.
   * @param {object} tokenizer
   */
  setTokenizer(tokenizer) {
    this.tokenizer = tokenizer;
    const specialIds = tokenizer.getSpecialTokenIds();
    this.eosTokenId = specialIds.eosTokenId ?? specialIds.sepTokenId ?? 50256;
  }

  /**
   * Preprocess - not used for text generation (handled in generateSingle).
   * Encodes the (first) input string into an int64 id tensor.
   */
  async preprocess(input) {
    const text = Array.isArray(input) ? input[0] ?? '' : input;
    if (!this.tokenizer) {
      // Return dummy tensor if no tokenizer
      return [new EdgeFlowTensor(new Float32Array([0]), [1], 'float32')];
    }
    const encoded = this.tokenizer.encode(text, {
      addSpecialTokens: false,
      padding: 'do_not_pad',
    });
    return [new EdgeFlowTensor(BigInt64Array.from(encoded.inputIds.map(id => BigInt(id))), [1, encoded.inputIds.length], 'int64')];
  }

  /**
   * Postprocess - not used for text generation (handled in generateSingle).
   * Always returns an empty result.
   */
  async postprocess(_outputs, _options) {
    return {
      generatedText: '',
      tokenIds: [],
      numTokens: 0,
      processingTime: 0,
    };
  }

  /**
   * Generate text (non-streaming).
   * @param {string|string[]} prompt
   * @param {object} [options] Generation options (see `generateSingle`).
   * @returns {Promise<object|object[]>} One result for a string prompt, an
   *   array of results (same order) for an array of prompts.
   */
  async run(prompt, options) {
    await this.initialize();
    const prompts = Array.isArray(prompt) ? prompt : [prompt];
    const results = await Promise.all(prompts.map(p => this.generateSingle(p, options ?? {})));
    return Array.isArray(prompt) ? results : results[0];
  }

  /**
   * Encode a prompt without special tokens, padding or truncation.
   * @private
   * @returns {number[]} A fresh, mutable copy of the token ids.
   */
  encodePrompt(prompt) {
    const encoded = this.tokenizer.encode(prompt, {
      addSpecialTokens: false,
      padding: 'do_not_pad',
      truncation: false,
    });
    return [...encoded.inputIds];
  }

  /**
   * Return the first stop sequence that `text` ends with, or `undefined`.
   * Empty stop sequences are ignored: every string "ends with" '' and
   * slicing by a zero length would wipe the text.
   * @private
   */
  findStopSequence(text, stopSequences) {
    return stopSequences.find(stop => stop.length > 0 && text.endsWith(stop));
  }

  /**
   * Generate text with streaming (async generator).
   *
   * Yields one event per generated token. The event that hits EOS or a stop
   * sequence has `done: true`; a matched stop sequence is stripped from
   * `generatedText` (the event's `token` is left as decoded).
   * @param {string} prompt
   * @param {object} [options]
   * @throws {Error} If no tokenizer is set or the model is not loaded.
   */
  async *stream(prompt, options = {}) {
    const startTime = performance.now();
    if (!this.tokenizer) {
      throw new Error('Tokenizer not set. Call setTokenizer() first.');
    }
    await this.initialize();
    const { maxNewTokens = 50, maxLength = 512, temperature = 1.0, topK = 0, topP = 1.0, repetitionPenalty = 1.0, stopSequences = [], doSample = true } = options;
    const inputIds = this.encodePrompt(prompt);
    let generatedText = '';
    // Generation loop
    for (let step = 0; step < maxNewTokens; step++) {
      // Check max length
      if (inputIds.length >= maxLength) {
        break;
      }
      // Run model forward pass
      const nextTokenId = await this.generateNextToken(inputIds, temperature, topK, topP, repetitionPenalty, doSample);
      // Check for EOS
      if (nextTokenId === this.eosTokenId) {
        yield {
          token: '',
          tokenId: nextTokenId,
          generatedText,
          done: true,
        };
        break;
      }
      // Decode token
      const token = this.tokenizer.decode([nextTokenId], true);
      inputIds.push(nextTokenId);
      generatedText += token;
      // Call token callback
      if (options.onToken) {
        options.onToken(token, nextTokenId);
      }
      // Check stop sequences
      const matchedStop = this.findStopSequence(generatedText, stopSequences);
      const shouldStop = matchedStop !== undefined;
      if (shouldStop) {
        generatedText = generatedText.slice(0, -matchedStop.length);
      }
      yield {
        token,
        tokenId: nextTokenId,
        generatedText,
        done: shouldStop,
      };
      if (shouldStop) {
        break;
      }
    }
    const endTime = performance.now();
    console.log(`Generation completed in ${(endTime - startTime).toFixed(2)}ms`);
  }

  /**
   * Generate a single sequence (non-streaming).
   *
   * A matched stop sequence is stripped from `generatedText`, matching
   * `stream()`; `tokenIds` / `numTokens` still include the tokens that
   * produced it.
   * @param {string} prompt
   * @param {object} options
   * @returns {Promise<{generatedText: string, fullText: (string|undefined), tokenIds: number[], numTokens: number, processingTime: number}>}
   * @throws {Error} If no tokenizer is set or the model is not loaded.
   */
  async generateSingle(prompt, options) {
    const startTime = performance.now();
    if (!this.tokenizer) {
      throw new Error('Tokenizer not set. Call setTokenizer() first.');
    }
    const { maxNewTokens = 50, maxLength = 512, temperature = 1.0, topK = 0, topP = 1.0, repetitionPenalty = 1.0, stopSequences = [], doSample = true, returnFullText = false } = options;
    const inputIds = this.encodePrompt(prompt);
    const generatedIds = [];
    // Set only when a stop sequence ended generation.
    let textBeforeStop;
    // Generation loop
    for (let step = 0; step < maxNewTokens; step++) {
      // Check max length
      if (inputIds.length >= maxLength) {
        break;
      }
      // Run model forward pass
      const nextTokenId = await this.generateNextToken(inputIds, temperature, topK, topP, repetitionPenalty, doSample);
      // Check for EOS
      if (nextTokenId === this.eosTokenId) {
        break;
      }
      // Add to sequence
      generatedIds.push(nextTokenId);
      inputIds.push(nextTokenId);
      // Call token callback
      if (options.onToken) {
        options.onToken(this.tokenizer.decode([nextTokenId], true), nextTokenId);
      }
      // Only pay for re-decoding the whole sequence when there is something to match.
      if (stopSequences.length > 0) {
        const currentText = this.tokenizer.decode(generatedIds, true);
        const matchedStop = this.findStopSequence(currentText, stopSequences);
        if (matchedStop !== undefined) {
          textBeforeStop = currentText.slice(0, -matchedStop.length);
          break;
        }
      }
    }
    // Decode generated text
    const generatedText = textBeforeStop ?? this.tokenizer.decode(generatedIds, true);
    const endTime = performance.now();
    return {
      generatedText,
      fullText: returnFullText ? prompt + generatedText : undefined,
      tokenIds: generatedIds,
      numTokens: generatedIds.length,
      processingTime: endTime - startTime,
    };
  }

  /**
   * Generate next token using the model.
   *
   * Runs one forward pass over `inputIds` (batch size 1), takes the logits of
   * the last position, applies the repetition penalty and temperature, then
   * samples or takes the argmax.
   * @param {number[]} inputIds Full sequence so far.
   * @param {number} temperature Values <= 0 (or NaN) select greedy decoding.
   * @param {number} topK 0 disables top-k filtering.
   * @param {number} topP 1.0 disables nucleus filtering.
   * @param {number} repetitionPenalty 1.0 disables the penalty.
   * @param {boolean} doSample false selects greedy decoding.
   * @returns {Promise<number>} The chosen token id.
   * @throws {Error} If the model is not loaded or returns no outputs.
   */
  async generateNextToken(inputIds, temperature, topK, topP, repetitionPenalty, doSample) {
    if (!this.model) {
      throw new Error('Model not loaded');
    }
    const seqLen = inputIds.length;
    // Prepare named inputs
    const inputs = new Map();
    // input_ids: [1, seq_len]
    inputs.set('input_ids', new EdgeFlowTensor(BigInt64Array.from(inputIds, id => BigInt(id)), [1, seqLen], 'int64'));
    // attention_mask: [1, seq_len]
    inputs.set('attention_mask', new EdgeFlowTensor(new BigInt64Array(seqLen).fill(1n), [1, seqLen], 'int64'));
    // position_ids: [1, seq_len] - sequential positions from 0 to seq_len-1
    inputs.set('position_ids', new EdgeFlowTensor(BigInt64Array.from({ length: seqLen }, (_, i) => BigInt(i)), [1, seqLen], 'int64'));
    // No cache is carried between steps, so every layer gets an empty
    // past_key_values.{i}.{key,value}: [batch, num_kv_heads, 0, head_dim]
    const { numLayers, numKVHeads, headDim } = this.kvCacheConfig;
    for (let layer = 0; layer < numLayers; layer++) {
      for (const slot of ['key', 'value']) {
        inputs.set(`past_key_values.${layer}.${slot}`, new EdgeFlowTensor(new Float32Array(0), [1, numKVHeads, 0, headDim], 'float32'));
      }
    }
    // Run inference with named inputs
    const outputs = await runInferenceNamed(this.model, inputs);
    if (!outputs || outputs.length === 0) {
      throw new Error('Model returned no outputs');
    }
    const logits = outputs[0];
    const logitsData = logits.toFloat32Array();
    const vocabSize = logits.shape[logits.shape.length - 1] ?? 50257;
    // The last position is the trailing vocabSize entries. For a
    // [1, seq_len, vocab] output this equals (seq_len - 1) * vocab, and it
    // also works when the model only returns the final position.
    const offset = Math.max(0, logitsData.length - vocabSize);
    const lastPositionLogits = new Float32Array(vocabSize);
    for (let i = 0; i < vocabSize; i++) {
      lastPositionLogits[i] = logitsData[offset + i] ?? 0;
    }
    // Apply repetition penalty once per distinct token; iterating the raw
    // sequence would compound the penalty for every repeated occurrence.
    if (repetitionPenalty !== 1.0) {
      for (const prevId of new Set(inputIds)) {
        if (prevId < vocabSize) {
          const score = lastPositionLogits[prevId] ?? 0;
          lastPositionLogits[prevId] = score > 0
            ? score / repetitionPenalty
            : score * repetitionPenalty;
        }
      }
    }
    // A non-positive temperature cannot be used as a divisor; treat it as
    // "deterministic" and fall back to greedy decoding.
    const useSampling = doSample && temperature > 0;
    if (!useSampling) {
      // argmax over logits equals argmax over softmax(logits).
      return this.greedy(lastPositionLogits);
    }
    // Apply temperature
    if (temperature !== 1.0) {
      for (let i = 0; i < vocabSize; i++) {
        lastPositionLogits[i] = (lastPositionLogits[i] ?? 0) / temperature;
      }
    }
    // Convert to probabilities
    const logitsTensor = new EdgeFlowTensor(lastPositionLogits, [vocabSize], 'float32');
    const probs = softmax(logitsTensor).toFloat32Array();
    return this.sample(probs, topK, topP);
  }

  /**
   * Greedy decoding (argmax). Ties resolve to the lowest index.
   * @param {ArrayLike<number>} probs Scores or probabilities.
   * @returns {number} Index of the largest entry (0 for empty input).
   */
  greedy(probs) {
    let maxIdx = 0;
    let maxProb = probs[0] ?? 0;
    for (let i = 1; i < probs.length; i++) {
      const candidate = probs[i] ?? 0;
      if (candidate > maxProb) {
        maxProb = candidate;
        maxIdx = i;
      }
    }
    return maxIdx;
  }

  /**
   * Sample from probability distribution with top-k/top-p filtering.
   * @param {ArrayLike<number>} probs Probabilities over the vocabulary.
   * @param {number} topK Keep only the K most likely tokens (0 = off).
   * @param {number} topP Keep the smallest prefix whose mass reaches P (1 = off).
   * @returns {number} Sampled token index.
   */
  sample(probs, topK, topP) {
    // Indices ordered from most to least likely
    const ranked = Array.from({ length: probs.length }, (_, i) => i);
    ranked.sort((a, b) => (probs[b] ?? 0) - (probs[a] ?? 0));
    // Apply top-k filtering
    let candidates = ranked;
    if (topK > 0 && topK < probs.length) {
      candidates = ranked.slice(0, topK);
    }
    // Apply top-p (nucleus) filtering
    if (topP < 1.0) {
      const nucleus = [];
      let mass = 0;
      for (const idx of candidates) {
        nucleus.push(idx);
        mass += probs[idx] ?? 0;
        if (mass >= topP) {
          break;
        }
      }
      candidates = nucleus;
    }
    // Renormalize by drawing the threshold from the surviving mass
    let totalProb = 0;
    for (const idx of candidates) {
      totalProb += probs[idx] ?? 0;
    }
    const threshold = Math.random() * totalProb;
    let cumulative = 0;
    for (const idx of candidates) {
      cumulative += probs[idx] ?? 0;
      if (cumulative >= threshold) {
        return idx;
      }
    }
    // Fallback (e.g. NaN probabilities)
    return candidates[0] ?? 0;
  }

  // ==========================================================================
  // Chat / Conversation Support
  // ==========================================================================

  /**
   * Set the chat template type
   * @param {string} templateType One of 'chatml', 'llama2', 'llama3',
   *   'mistral', 'phi3', 'alpaca', 'vicuna', 'custom'.
   */
  setChatTemplate(templateType) {
    this.chatTemplateType = templateType;
  }

  /**
   * Apply chat template to messages.
   * Unknown template types fall back to ChatML.
   * @param {Array<{role: string, content: string}>} messages
   * @param {{templateType?: string, customTemplate?: object}} [options]
   * @returns {string} The prompt, ending where the assistant reply starts.
   */
  applyChatTemplate(messages, options) {
    const templateType = options?.templateType ?? this.chatTemplateType;
    switch (templateType) {
      case 'llama2':
        return this.applyLlama2Template(messages);
      case 'llama3':
        return this.applyLlama3Template(messages);
      case 'mistral':
        return this.applyMistralTemplate(messages);
      case 'phi3':
        return this.applyPhi3Template(messages);
      case 'alpaca':
        return this.applyAlpacaTemplate(messages);
      case 'vicuna':
        return this.applyVicunaTemplate(messages);
      case 'custom':
        return this.applyCustomTemplate(messages, options?.customTemplate ?? {});
      case 'chatml':
      default:
        return this.applyChatMLTemplate(messages);
    }
  }

  /**
   * ChatML template (used by many models including Qwen, Yi)
   */
  applyChatMLTemplate(messages) {
    let prompt = '';
    for (const msg of messages) {
      prompt += `<|im_start|>${msg.role}\n${msg.content}<|im_end|>\n`;
    }
    prompt += '<|im_start|>assistant\n';
    return prompt;
  }

  /**
   * Llama 2 template. A system message is wrapped in <<SYS>> tags inside the
   * next user turn.
   */
  applyLlama2Template(messages) {
    let prompt = '';
    let systemMsg = '';
    for (const msg of messages) {
      if (msg.role === 'system') {
        systemMsg = msg.content;
      }
      else if (msg.role === 'user') {
        if (systemMsg) {
          prompt += `<s>[INST] <<SYS>>\n${systemMsg}\n<</SYS>>\n\n${msg.content} [/INST]`;
          systemMsg = '';
        }
        else {
          prompt += `<s>[INST] ${msg.content} [/INST]`;
        }
      }
      else if (msg.role === 'assistant') {
        prompt += ` ${msg.content} </s>`;
      }
    }
    return prompt;
  }

  /**
   * Llama 3 template
   */
  applyLlama3Template(messages) {
    let prompt = '<|begin_of_text|>';
    for (const msg of messages) {
      prompt += `<|start_header_id|>${msg.role}<|end_header_id|>\n\n${msg.content}<|eot_id|>`;
    }
    prompt += '<|start_header_id|>assistant<|end_header_id|>\n\n';
    return prompt;
  }

  /**
   * Mistral template.
   *
   * The format has no system role, so system text is folded into the next
   * user turn (separated by a blank line) instead of opening an [INST] block
   * of its own, which would leave a nested, unclosed [INST].
   */
  applyMistralTemplate(messages) {
    let prompt = '<s>';
    let pendingSystem = '';
    for (const msg of messages) {
      if (msg.role === 'system') {
        pendingSystem = pendingSystem ? `${pendingSystem}\n${msg.content}` : msg.content;
      }
      else if (msg.role === 'user') {
        const body = pendingSystem ? `${pendingSystem}\n\n${msg.content}` : msg.content;
        prompt += `[INST] ${body} [/INST]`;
        pendingSystem = '';
      }
      else if (msg.role === 'assistant') {
        prompt += ` ${msg.content}</s>`;
      }
    }
    return prompt;
  }

  /**
   * Phi-3 template
   */
  applyPhi3Template(messages) {
    let prompt = '';
    for (const msg of messages) {
      prompt += `<|${msg.role}|>\n${msg.content}<|end|>\n`;
    }
    prompt += '<|assistant|>\n';
    return prompt;
  }

  /**
   * Alpaca template. Single-turn: only the last system message (instruction)
   * and the last user message (input) are used.
   */
  applyAlpacaTemplate(messages) {
    let instruction = '';
    let input = '';
    for (const msg of messages) {
      if (msg.role === 'system') {
        instruction = msg.content;
      }
      else if (msg.role === 'user') {
        input = msg.content;
      }
    }
    let prompt = '';
    if (instruction) {
      prompt = `### Instruction:\n${instruction}\n\n`;
    }
    if (input) {
      prompt += `### Input:\n${input}\n\n`;
    }
    prompt += '### Response:\n';
    return prompt;
  }

  /**
   * Vicuna template
   */
  applyVicunaTemplate(messages) {
    let prompt = '';
    for (const msg of messages) {
      if (msg.role === 'system') {
        prompt += `${msg.content}\n\n`;
      }
      else if (msg.role === 'user') {
        prompt += `USER: ${msg.content}\n`;
      }
      else if (msg.role === 'assistant') {
        prompt += `ASSISTANT: ${msg.content}\n`;
      }
    }
    prompt += 'ASSISTANT:';
    return prompt;
  }

  /**
   * Custom template built from per-role prefixes/suffixes and a separator
   * inserted between consecutive messages.
   */
  applyCustomTemplate(messages, template) {
    const { systemPrefix = '', systemSuffix = '\n', userPrefix = 'User: ', userSuffix = '\n', assistantPrefix = 'Assistant: ', assistantSuffix = '\n', separator = '' } = template;
    let prompt = '';
    messages.forEach((msg, index) => {
      if (index > 0) {
        prompt += separator;
      }
      switch (msg.role) {
        case 'system':
          prompt += `${systemPrefix}${msg.content}${systemSuffix}`;
          break;
        case 'user':
          prompt += `${userPrefix}${msg.content}${userSuffix}`;
          break;
        case 'assistant':
          prompt += `${assistantPrefix}${msg.content}${assistantSuffix}`;
          break;
      }
    });
    prompt += assistantPrefix;
    return prompt;
  }

  /**
   * Stop sequences used for chat generation: the caller's own, followed by
   * the end-of-turn markers of the supported templates.
   * @private
   */
  buildChatStopSequences(options) {
    return [
      ...(options?.stopSequences ?? []),
      '<|im_end|>',
      '<|end|>',
      '<|eot_id|>',
      '</s>',
      '\n\nUser:',
      '\n\nHuman:',
    ];
  }

  /**
   * Append the user turn (and, if requested and absent, a leading system
   * turn) to the history.
   * @private
   * @returns {boolean} Whether a system message was inserted by this call.
   */
  beginUserTurn(userMessage, options) {
    const history = this.conversationHistory;
    const insertedSystem = Boolean(options?.systemPrompt) && history[0]?.role !== 'system';
    if (insertedSystem) {
      history.unshift({
        role: 'system',
        content: options.systemPrompt,
      });
    }
    history.push({
      role: 'user',
      content: userMessage,
    });
    return insertedSystem;
  }

  /**
   * Undo `beginUserTurn` after a failed generation so the history never ends
   * with an unanswered user message.
   * @private
   */
  rollbackUserTurn(insertedSystem) {
    const history = this.conversationHistory;
    if (history[history.length - 1]?.role === 'user') {
      history.pop();
    }
    if (insertedSystem && history[0]?.role === 'system') {
      history.shift();
    }
  }

  /**
   * Chat with the model
   *
   * The user message and the trimmed assistant reply are appended to the
   * conversation history. If generation fails the history is restored to its
   * previous state and the error is rethrown.
   *
   * @example
   * ```typescript
   * const generator = await pipeline('text-generation', 'model');
   *
   * // Single turn
   * const response = await generator.chat('Hello, how are you?');
   *
   * // Multi-turn with history
   * const response1 = await generator.chat('What is AI?');
   * const response2 = await generator.chat('Can you give an example?');
   *
   * // With system prompt
   * const response = await generator.chat('Hello', {
   *   systemPrompt: 'You are a helpful assistant.',
   * });
   * ```
   */
  async chat(userMessage, options) {
    const insertedSystem = this.beginUserTurn(userMessage, options);
    let result;
    try {
      // Apply chat template
      const prompt = this.applyChatTemplate(this.conversationHistory, options);
      // Generate response
      result = await this.run(prompt, {
        ...options,
        stopSequences: this.buildChatStopSequences(options),
      });
    } catch (error) {
      this.rollbackUserTurn(insertedSystem);
      throw error;
    }
    // Add assistant response to history
    const response = Array.isArray(result) ? result[0] : result;
    this.conversationHistory.push({
      role: 'assistant',
      content: response.generatedText.trim(),
    });
    return response;
  }

  /**
   * Stream chat response
   *
   * Yields the events of `stream()`. When the stream finishes, or the
   * consumer stops iterating early, the text produced so far is stored as
   * the assistant turn; if generation throws, the user turn is rolled back.
   */
  async *chatStream(userMessage, options) {
    const insertedSystem = this.beginUserTurn(userMessage, options);
    let fullResponse = '';
    let failed = false;
    try {
      // Apply chat template
      const prompt = this.applyChatTemplate(this.conversationHistory, options);
      // Stream response
      const events = this.stream(prompt, {
        ...options,
        stopSequences: this.buildChatStopSequences(options),
      });
      for await (const event of events) {
        fullResponse = event.generatedText;
        yield event;
      }
    } catch (error) {
      failed = true;
      this.rollbackUserTurn(insertedSystem);
      throw error;
    } finally {
      // Also runs when the consumer breaks out of the loop early.
      if (!failed) {
        this.conversationHistory.push({
          role: 'assistant',
          content: fullResponse.trim(),
        });
      }
    }
  }

  /**
   * Get conversation history
   * @returns {Array<{role: string, content: string}>} A shallow copy.
   */
  getConversationHistory() {
    return [...this.conversationHistory];
  }

  /**
   * Set conversation history (the array is copied, the messages are not).
   */
  setConversationHistory(messages) {
    this.conversationHistory = [...messages];
  }

  /**
   * Clear conversation history
   */
  clearConversation() {
    this.conversationHistory = [];
  }

  /**
   * Remove last exchange (user message + assistant response)
   */
  undoLastExchange() {
    const history = this.conversationHistory;
    // Remove assistant message
    if (history[history.length - 1]?.role === 'assistant') {
      history.pop();
    }
    // Remove user message
    if (history[history.length - 1]?.role === 'user') {
      history.pop();
    }
  }
}
757
+ // ============================================================================
758
+ // Factory Functions
759
+ // ============================================================================
760
/**
 * Factory wrapper around the TextGenerationPipeline constructor.
 *
 * @param {object} [config] Passed through unchanged to the constructor.
 * @returns {TextGenerationPipeline} A freshly constructed pipeline.
 */
export function createTextGenerationPipeline(config) {
  const pipeline = new TextGenerationPipeline(config);
  return pipeline;
}
766
+ //# sourceMappingURL=text-generation.js.map