@mlx-node/core 0.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (3) hide show
  1. package/index.cjs +766 -0
  2. package/index.d.cts +2728 -0
  3. package/package.json +36 -0
package/index.d.cts ADDED
@@ -0,0 +1,2728 @@
1
+ /* auto-generated by NAPI-RS */
2
+ /* eslint-disable */
3
+ /**
4
+ * Result from batch text generation
5
+ *
6
+ * Contains results for N prompts × G completions per prompt.
7
+ * Results are stored flat in arrays of length N*G, where:
8
+ * - First G elements are completions for prompt 0
9
+ * - Next G elements are completions for prompt 1
10
+ * - etc.
11
+ */
12
+ export declare class BatchGenerationResult {
13
+ /** Get all generated token arrays (N*G arrays) */
14
+ get tokens(): Array<MxArray>;
15
+ /** Get all log probability arrays (N*G arrays) */
16
+ get logprobs(): Array<MxArray>;
17
+ /** Get all decoded texts (N*G strings) */
18
+ get texts(): Array<string>;
19
+ /** Get finish reasons grouped by prompt (N arrays of G finish reasons) */
20
+ get finishReasons(): Array<Array<string>>;
21
+ /** Get token counts grouped by prompt (N arrays of G counts) */
22
+ get tokenCounts(): Array<Array<number>>;
23
+ /** Get number of prompts */
24
+ get numPrompts(): number;
25
+ /** Get group size (completions per prompt) */
26
+ get groupSize(): number;
27
+ }
28
+
29
+ /**
30
+ * Result from the high-level `chat()` API
31
+ *
32
+ * Contains structured responses with:
33
+ * - Tool calls parsed as native JavaScript objects
34
+ * - Thinking/reasoning extracted from `<think>` tags
35
+ * - Clean text with all special tags stripped
36
+ *
37
+ * ## Example
38
+ * ```typescript
39
+ * const result = await model.chat(messages, { tools });
40
+ * console.log(result.text); // Clean response
41
+ * console.log(result.thinking); // Chain-of-thought (if any)
42
+ * console.log(result.toolCalls); // Parsed tool calls
43
+ * ```
44
+ */
45
+ export declare class ChatResult {
46
+ /** Get the cleaned text (tool_call and think tags removed) */
47
+ get text(): string;
48
+ /** Get the extracted tool calls */
49
+ get toolCalls(): Array<ToolCallResult>;
50
+ /**
51
+ * Get the extracted thinking/reasoning content
52
+ *
53
+ * Returns the content from within `<think>...</think>` tags, or null if
54
+ * no thinking tags were present in the response.
55
+ *
56
+ * This is useful for:
57
+ * - Debugging model reasoning
58
+ * - Displaying chain-of-thought to users (optional)
59
+ * - Analyzing model decision-making
60
+ */
61
+ get thinking(): string | null;
62
+ /** Get the generated tokens */
63
+ get tokens(): MxArray;
64
+ /** Get the log probabilities */
65
+ get logprobs(): MxArray;
66
+ /** Get the finish reason ("stop", "length", "tool_calls", or "repetition") */
67
+ get finishReason(): 'stop' | 'length' | 'tool_calls' | 'repetition';
68
+ /** Get the number of tokens generated */
69
+ get numTokens(): number;
70
+ /** Get the raw text before tool call stripping (for debugging) */
71
+ get rawText(): string;
72
+ }
73
+
74
+ /** Result from text generation with detailed metadata */
75
+ export declare class GenerationResult {
76
+ /** Get the decoded text */
77
+ get text(): string;
78
+ /** Get the generated tokens */
79
+ get tokens(): MxArray;
80
+ /** Get the log probabilities */
81
+ get logprobs(): MxArray;
82
+ /** Get the finish reason ("eos", "length", or "repetition") */
83
+ get finishReason(): 'eos' | 'length' | 'repetition';
84
+ /** Get the number of tokens generated */
85
+ get numTokens(): number;
86
+ }
87
+
88
+ /**
89
+ * GRPO Training Engine
90
+ *
91
+ * Complete training engine that runs entirely in Rust.
92
+ */
93
+ export declare class GrpoTrainingEngine {
94
+ /**
95
+ * Create a new training engine from an existing model
96
+ *
97
+ * # Arguments
98
+ * * `model` - The Qwen3 model to train (will be cloned internally)
99
+ * * `config` - Engine configuration
100
+ */
101
+ constructor(model: Qwen3Model, config: GrpoEngineConfig);
102
+ /** Register a built-in reward function */
103
+ registerBuiltinReward(config: BuiltinRewardConfig): void;
104
+ /**
105
+ * Run a training step with provided rewards
106
+ *
107
+ * This method performs the complete training cycle:
108
+ * 1. Generate completions for each prompt (G times per prompt)
109
+ * 2. Use provided rewards to compute advantages
110
+ * 3. Compute GRPO loss and gradients
111
+ * 4. Apply gradients (respecting accumulation steps)
112
+ *
113
+ * # Arguments
114
+ * * `prompts` - Array of chat conversations to use as prompts
115
+ * * `rewards` - Reward values for each completion (num_prompts * group_size)
116
+ *
117
+ * # Returns
118
+ * * Training step metrics
119
+ */
120
+ trainStep(prompts: Array<Array<ChatMessage>>, rewards: Array<number>): Promise<EngineStepMetrics>;
121
+ /**
122
+ * Generate completions without training
123
+ *
124
+ * Use this to generate completions for scoring by external reward functions.
125
+ * Returns completion texts along with the internal token data needed for training.
126
+ */
127
+ generateBatch(prompts: Array<Array<ChatMessage>>): Promise<Array<string>>;
128
+ /**
129
+ * Generate completions with all data needed for training
130
+ *
131
+ * Returns completion texts, tokens, log probabilities, and lengths.
132
+ * Use this when you need to score completions externally and then train.
133
+ */
134
+ generateBatchForTraining(prompts: Array<Array<ChatMessage>>): Promise<GenerateBatchResult>;
135
+ /**
136
+ * Run a training step with pre-generated completions
137
+ *
138
+ * This method performs training using pre-generated completions,
139
+ * eliminating the double-generation issue.
140
+ *
141
+ * # Arguments
142
+ * * `prompts` - Array of chat conversations to use as prompts
143
+ * * `rewards` - Reward values for each completion (num_prompts * group_size)
144
+ * * `generation_result` - Pre-generated completion data from generate_batch_for_training
145
+ *
146
+ * # Returns
147
+ * * Training step metrics
148
+ */
149
+ trainStepWithGenerations(
150
+ prompts: Array<Array<ChatMessage>>,
151
+ rewards: Array<number>,
152
+ generationResult: GenerateBatchResult,
153
+ ): Promise<EngineStepMetrics>;
154
+ /**
155
+ * Unified training step with JS reward callback and optional output recording
156
+ *
157
+ * Same as `train_step_auto` but optionally captures the full RewardOutput data
158
+ * for persistence to an output store database.
159
+ *
160
+ * # Arguments
161
+ * * `prompts` - Array of chat conversations to use as prompts
162
+ * * `reward_fn` - JavaScript function to compute rewards
163
+ * * `record_outputs` - If true, return the serialized RewardOutput JSON
164
+ *
165
+ * # Returns
166
+ * * Training step result including metrics, completions, rewards, and optionally outputs_json
167
+ */
168
+ trainStepAuto(
169
+ prompts: ChatMessage[][],
170
+ rewardFn: (err: Error | null, outputsJson: string) => Promise<number[]>,
171
+ recordOutputs: boolean,
172
+ ): Promise<TrainStepResultWithOutputs>;
173
+ /**
174
+ * Score completions using registered built-in rewards
175
+ *
176
+ * # Arguments
177
+ * * `prompts` - Prompt texts (expanded to match completions)
178
+ * * `completions` - Completion texts to score
179
+ */
180
+ scoreCompletions(prompts: Array<string>, completions: Array<string>): Array<number>;
181
+ /** Get current training step */
182
+ get step(): number;
183
+ /** Get current epoch */
184
+ get epoch(): number;
185
+ /** Start a new epoch */
186
+ startEpoch(): void;
187
+ /** End the current epoch and get metrics */
188
+ endEpoch(epochTimeSecs: number): EngineEpochMetrics;
189
+ /** Reset the engine for a fresh training run */
190
+ reset(): void;
191
+ /** Check if reward registry has any rewards registered */
192
+ get hasBuiltinRewards(): boolean;
193
+ /** Get names of registered reward functions */
194
+ get rewardNames(): Array<string>;
195
+ /** Get current micro-step within gradient accumulation */
196
+ get microStep(): number;
197
+ /**
198
+ * Check if an emergency checkpoint should be saved
199
+ * This flag is set when consecutive NaN gradients reach the threshold
200
+ */
201
+ get needsEmergencySave(): boolean;
202
+ /** Get current NaN gradient count */
203
+ get nanGradientCount(): number;
204
+ /** Clear the emergency save flag (call after saving emergency checkpoint) */
205
+ clearEmergencySaveFlag(): void;
206
+ }
207
+ export type GRPOTrainingEngine = GrpoTrainingEngine;
208
+
209
+ export declare class MxArray {
210
+ static fromInt32(data: Int32Array, shape: BigInt64Array): MxArray;
211
+ static fromInt64(data: BigInt64Array, shape: BigInt64Array): MxArray;
212
+ static fromUint32(data: Uint32Array, shape: BigInt64Array): MxArray;
213
+ static fromFloat32(data: Float32Array, shape: BigInt64Array): MxArray;
214
+ static zeros(shape: BigInt64Array, dtype?: DType | undefined | null): MxArray;
215
+ static scalarFloat(value: number): MxArray;
216
+ static scalarInt(value: number): MxArray;
217
+ static ones(shape: BigInt64Array, dtype?: DType | undefined | null): MxArray;
218
+ static full(shape: BigInt64Array, fillValue: number | MxArray, dtype?: DType | undefined | null): MxArray;
219
+ static linspace(
220
+ start: number,
221
+ stop: number,
222
+ num?: number | undefined | null,
223
+ dtype?: DType | undefined | null,
224
+ ): MxArray;
225
+ static eye(
226
+ n: number,
227
+ m?: number | undefined | null,
228
+ k?: number | undefined | null,
229
+ dtype?: DType | undefined | null,
230
+ ): MxArray;
231
+ static arange(
232
+ start: number,
233
+ stop: number,
234
+ step?: number | undefined | null,
235
+ dtype?: DType | undefined | null,
236
+ ): MxArray;
237
+ reshape(shape: BigInt64Array): MxArray;
238
+ astype(dtype: DType): MxArray;
239
+ /**
240
+ * Create a copy of this array with a new handle.
241
+ * This is useful for parameter loading to avoid handle aliasing issues.
242
+ */
243
+ copy(): MxArray;
244
+ logSoftmax(axis: number): MxArray;
245
+ exp(): MxArray;
246
+ log(): MxArray;
247
+ sum(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
248
+ mean(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
249
+ clip(minimum?: number | undefined | null, maximum?: number | undefined | null): MxArray;
250
+ minimum(other: MxArray): MxArray;
251
+ maximum(other: MxArray): MxArray;
252
+ add(other: MxArray): MxArray;
253
+ sub(other: MxArray): MxArray;
254
+ mul(other: MxArray): MxArray;
255
+ div(other: MxArray): MxArray;
256
+ addScalar(value: number): MxArray;
257
+ mulScalar(value: number): MxArray;
258
+ subScalar(value: number): MxArray;
259
+ divScalar(value: number): MxArray;
260
+ matmul(other: MxArray): MxArray;
261
+ /**
262
+ * Fused matrix multiply-add: D = beta * C + alpha * (self @ B)
263
+ * where self is A. More efficient than separate matmul and add operations.
264
+ * Default: alpha=1.0, beta=1.0, giving D = C + (self @ B)
265
+ */
266
+ addmm(c: MxArray, b: MxArray, alpha?: number | undefined | null, beta?: number | undefined | null): MxArray;
267
+ transpose(axes?: Int32Array | undefined | null): MxArray;
268
+ take(indices: MxArray, axis: number): MxArray;
269
+ takeAlongAxis(indices: MxArray, axis: number): MxArray;
270
+ /**
271
+ * Put values into array at specified indices along an axis
272
+ * Equivalent to: result = array.copy(); result[..., indices] = values
273
+ * This matches MLX's put_along_axis for efficient in-place-style updates
274
+ */
275
+ putAlongAxis(indices: MxArray, values: MxArray, axis: number): MxArray;
276
+ slice(starts: BigInt64Array, stops: BigInt64Array): MxArray;
277
+ /**
278
+ * Concatenate two arrays along an axis
279
+ * Optimized for the common binary concatenation case
280
+ */
281
+ static concatenate(a: MxArray, b: MxArray, axis: number): MxArray;
282
+ /**
283
+ * Concatenate multiple arrays along an axis
284
+ * For concatenating 3 or more arrays
285
+ */
286
+ static concatenateMany(arrays: Array<MxArray>, axis?: number | undefined | null): MxArray;
287
+ sort(axis?: number | undefined | null): MxArray;
288
+ argsort(axis?: number | undefined | null): MxArray;
289
+ partition(kth: number, axis?: number | undefined | null): MxArray;
290
+ argpartition(kth: number, axis?: number | undefined | null): MxArray;
291
+ eval(): void;
292
+ evalAsync(): Promise<undefined>;
293
+ size(): bigint;
294
+ ndim(): number;
295
+ shape(): BigInt64Array;
296
+ /**
297
+ * Get a single dimension from the array shape without copying the entire shape
298
+ * This is more efficient when you only need one dimension
299
+ *
300
+ * Note: axis is u32 because NAPI doesn't support usize, but internally converted to usize
301
+ */
302
+ shapeAt(axis: number): number;
303
+ /**
304
+ * Get batch and sequence length for 2D arrays (common pattern in transformers)
305
+ * More efficient than calling shape() and extracting dimensions
306
+ */
307
+ getBatchSeqLen(): Array<number>;
308
+ /**
309
+ * Get batch, sequence length, and hidden size for 3D arrays (common pattern in transformers)
310
+ * More efficient than calling shape() and extracting dimensions
311
+ */
312
+ getBatchSeqHidden(): Array<number>;
313
+ dtype(): DType;
314
+ /**
315
+ * Copy entire array from GPU to CPU as Float32Array
316
+ *
317
+ * ⚠️ **PERFORMANCE WARNING**: This triggers a FULL GPU→CPU memory transfer!
318
+ *
319
+ * **Performance impact**:
320
+ * - Forces evaluation of lazy operations
321
+ * - Copies entire array from GPU to CPU memory
322
+ * - Can be extremely slow for large arrays
323
+ *
324
+ * **Use sparingly**:
325
+ * - Prefer `item_float32()` for scalars
326
+ * - Prefer `item_at_float32(index)` for single elements
327
+ * - Only use when you truly need all array data on CPU
328
+ *
329
+ * **Acceptable use cases**:
330
+ * - Test validation and assertions
331
+ * - CPU-only operations (e.g., sorting for quantiles)
332
+ * - Final output extraction
333
+ */
334
+ toFloat32(): Float32Array;
335
+ /**
336
+ * Copy entire array from GPU to CPU as Int32Array
337
+ *
338
+ * ⚠️ **PERFORMANCE WARNING**: This triggers a FULL GPU→CPU memory transfer!
339
+ *
340
+ * See `to_float32()` documentation for performance implications and alternatives.
341
+ * Prefer `item_int32()` for scalars.
342
+ */
343
+ toInt32(): Int32Array;
344
+ /**
345
+ * Copy entire array from GPU to CPU as Uint32Array
346
+ *
347
+ * ⚠️ **PERFORMANCE WARNING**: This triggers a FULL GPU→CPU memory transfer!
348
+ *
349
+ * See `to_float32()` documentation for performance implications and alternatives.
350
+ */
351
+ toUint32(): Uint32Array;
352
+ static stack(arrays: Array<MxArray>, axis?: number | undefined | null): MxArray;
353
+ static randomUniform(shape: BigInt64Array, low: number, high: number, dtype?: DType | undefined | null): MxArray;
354
+ static randomNormal(shape: BigInt64Array, mean: number, std: number, dtype?: DType | undefined | null): MxArray;
355
+ static randomBernoulli(shape: BigInt64Array, prob: number): MxArray;
356
+ static randint(shape: BigInt64Array, low: number, high: number): MxArray;
357
+ /**
358
+ * Sample from categorical distribution
359
+ * Takes logits and returns sampled indices
360
+ */
361
+ categorical(axis?: number | undefined | null): MxArray;
362
+ equal(other: MxArray): MxArray;
363
+ notEqual(other: MxArray): MxArray;
364
+ less(other: MxArray): MxArray;
365
+ lessEqual(other: MxArray): MxArray;
366
+ greater(other: MxArray): MxArray;
367
+ greaterEqual(other: MxArray): MxArray;
368
+ logicalAnd(other: MxArray): MxArray;
369
+ logicalOr(other: MxArray): MxArray;
370
+ logicalNot(): MxArray;
371
+ where(x: MxArray, y: MxArray): MxArray;
372
+ argmax(axis: number, keepdims?: boolean | undefined | null): MxArray;
373
+ argmin(axis: number, keepdims?: boolean | undefined | null): MxArray;
374
+ max(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
375
+ min(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
376
+ prod(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
377
+ var(
378
+ axes?: Int32Array | undefined | null,
379
+ keepdims?: boolean | undefined | null,
380
+ ddof?: number | undefined | null,
381
+ ): MxArray;
382
+ std(
383
+ axes?: Int32Array | undefined | null,
384
+ keepdims?: boolean | undefined | null,
385
+ ddof?: number | undefined | null,
386
+ ): MxArray;
387
+ logsumexp(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
388
+ cumsum(axis: number): MxArray;
389
+ cumprod(axis: number): MxArray;
390
+ pad(padWidth: Int32Array, constantValue: number): MxArray;
391
+ roll(shift: number, axis: number): MxArray;
392
+ split(indicesOrSections: number, axis?: number | undefined | null): Array<MxArray>;
393
+ tile(reps: Int32Array): MxArray;
394
+ repeat(repeats: number, axis: number): MxArray;
395
+ squeeze(axes?: Int32Array | undefined | null): MxArray;
396
+ expandDims(axis: number): MxArray;
397
+ broadcastTo(shape: BigInt64Array): MxArray;
398
+ abs(): MxArray;
399
+ negative(): MxArray;
400
+ sign(): MxArray;
401
+ sqrt(): MxArray;
402
+ square(): MxArray;
403
+ power(other: MxArray): MxArray;
404
+ sin(): MxArray;
405
+ cos(): MxArray;
406
+ tan(): MxArray;
407
+ sinh(): MxArray;
408
+ cosh(): MxArray;
409
+ tanh(): MxArray;
410
+ floor(): MxArray;
411
+ ceil(): MxArray;
412
+ round(): MxArray;
413
+ floorDivide(other: MxArray): MxArray;
414
+ remainder(other: MxArray): MxArray;
415
+ reciprocal(): MxArray;
416
+ arcsin(): MxArray;
417
+ arccos(): MxArray;
418
+ arctan(): MxArray;
419
+ log10(): MxArray;
420
+ log2(): MxArray;
421
+ log1p(): MxArray;
422
+ /**
423
+ * Element-wise check for NaN values
424
+ *
425
+ * Returns a boolean array where True indicates the element is NaN.
426
+ * This is a GPU-native operation that avoids CPU data transfer.
427
+ */
428
+ isnan(): MxArray;
429
+ /**
430
+ * Element-wise check for Inf values
431
+ *
432
+ * Returns a boolean array where True indicates the element is +Inf or -Inf.
433
+ * This is a GPU-native operation that avoids CPU data transfer.
434
+ */
435
+ isinf(): MxArray;
436
+ /**
437
+ * Element-wise check for finite values
438
+ *
439
+ * Returns a boolean array where True indicates the element is finite (not NaN and not Inf).
440
+ * This is a GPU-native operation that avoids CPU data transfer.
441
+ */
442
+ isfinite(): MxArray;
443
+ }
444
+
445
+ /** NAPI-exported reward registry wrapper */
446
+ export declare class NativeRewardRegistry {
447
+ /** Create a new reward registry */
448
+ constructor();
449
+ /** Register a built-in reward function */
450
+ register(config: BuiltinRewardConfig): void;
451
+ /** Score a single completion */
452
+ score(prompt: string, completion: string): number;
453
+ /** Score a batch of completions */
454
+ scoreBatch(prompts: Array<string>, completions: Array<string>): Array<number>;
455
+ /** Check if registry is empty */
456
+ get isEmpty(): boolean;
457
+ /** Get registered reward names */
458
+ get names(): Array<string>;
459
+ /** Set whether to normalize scores */
460
+ setNormalize(normalize: boolean): void;
461
+ }
462
+
463
+ /**
464
+ * OutputStore - Persistence layer for training outputs
465
+ *
466
+ * Stores all model outputs during GRPO training for debugging and research.
467
+ * Supports local SQLite files.
468
+ */
469
+ export declare class OutputStore {
470
+ /** Create a new output store with local SQLite file */
471
+ static local(path: string): Promise<OutputStore>;
472
+ /** Create from config object */
473
+ static fromConfig(config: OutputStoreConfig): Promise<OutputStore>;
474
+ /** Start a new training run */
475
+ startRun(modelName: string, modelPath: string | undefined | null, config: string): Promise<string>;
476
+ /** Start a new training run with a name */
477
+ startRunWithName(
478
+ name: string | undefined | null,
479
+ modelName: string,
480
+ modelPath: string | undefined | null,
481
+ config: string,
482
+ ): Promise<string>;
483
+ /** End the current training run */
484
+ endRun(status: string): Promise<void>;
485
+ /** Get current run ID */
486
+ currentRunId(): Promise<string | null>;
487
+ /** Find a run by name */
488
+ findRunByName(name: string): Promise<TrainingRunRecord | null>;
489
+ /** Resume an existing run (sets status to running and makes it current) */
490
+ resumeRun(runId: string): Promise<void>;
491
+ /** Delete all steps after a given step number (for resume cleanup) */
492
+ deleteStepsAfter(runId: string, afterStep: number): Promise<number>;
493
+ /**
494
+ * Delete all records after a given step (for checkpoint resume)
495
+ *
496
+ * Cascades through: training_steps → generations → tool_calls, and logs.
497
+ * Use this when resuming from checkpoint to ensure clean database state.
498
+ */
499
+ deleteAllAfterStep(runId: string, afterStep: number): Promise<CleanupStats>;
500
+ /**
501
+ * Get recent step metrics for TUI sparkline restoration
502
+ *
503
+ * Returns metrics ordered by step (oldest first) for easy insertion into VecDeque.
504
+ */
505
+ getRecentStepMetrics(runId: string, limit: number): Promise<Array<StepMetricSummary>>;
506
+ /**
507
+ * Get aggregate statistics for a training run
508
+ *
509
+ * Returns pre-computed aggregates for restoring TUI state on resume.
510
+ */
511
+ getRunAggregates(runId: string): Promise<RunAggregates>;
512
+ /**
513
+ * Get recent generations for sample panel restoration
514
+ *
515
+ * Returns generations ordered by step DESC, reward DESC (most recent high-reward first).
516
+ */
517
+ getRecentGenerations(runId: string, limit: number): Promise<Array<GenerationRecord>>;
518
+ /** Get store configuration */
519
+ get config(): OutputStoreConfig;
520
+ /** Record from RewardOutput JSON (direct integration with training engine) */
521
+ recordStepFromOutputs(
522
+ step: number,
523
+ metrics: EngineStepMetrics,
524
+ outputsJson: string,
525
+ rewards: Array<number>,
526
+ groupSize: number,
527
+ ): Promise<number>;
528
+ /**
529
+ * Record a complete training step with all generations and tool calls
530
+ *
531
+ * Lower-level API for direct control over step recording.
532
+ */
533
+ recordStep(
534
+ step: StepRecord,
535
+ generations: Array<GenerationRecord>,
536
+ toolCalls: Array<Array<ToolCallRecord>>,
537
+ ): Promise<number>;
538
+ /** Flush any pending writes */
539
+ flush(): Promise<void>;
540
+ /** List all training runs */
541
+ listRuns(limit?: number | undefined | null, status?: string | undefined | null): Promise<Array<TrainingRunRecord>>;
542
+ /** Get a specific run */
543
+ getRun(runId: string): Promise<TrainingRunRecord | null>;
544
+ /** Get step summaries for a run */
545
+ getStepSummaries(
546
+ runId: string,
547
+ startStep?: number | undefined | null,
548
+ endStep?: number | undefined | null,
549
+ ): Promise<Array<StepSummary>>;
550
+ /** Get all generations for a step */
551
+ getGenerations(runId: string, step: number): Promise<Array<GenerationWithToolCalls>>;
552
+ /** Get top/bottom generations by reward */
553
+ getGenerationsByReward(
554
+ runId: string,
555
+ topN?: number | undefined | null,
556
+ bottomN?: number | undefined | null,
557
+ stepRange?: Array<number> | undefined | null,
558
+ ): Promise<Array<GenerationWithToolCalls>>;
559
+ /** Get generations with specific finish reason */
560
+ getGenerationsByFinishReason(
561
+ runId: string,
562
+ finishReason: string,
563
+ limit?: number | undefined | null,
564
+ ): Promise<Array<GenerationWithToolCalls>>;
565
+ /** Get generations containing tool calls */
566
+ getGenerationsWithToolCalls(
567
+ runId: string,
568
+ toolName?: string | undefined | null,
569
+ status?: string | undefined | null,
570
+ limit?: number | undefined | null,
571
+ ): Promise<Array<GenerationWithToolCalls>>;
572
+ /** Search generations by text content */
573
+ searchGenerations(
574
+ runId: string,
575
+ query: string,
576
+ searchIn?: string | undefined | null,
577
+ limit?: number | undefined | null,
578
+ ): Promise<Array<GenerationWithToolCalls>>;
579
+ /** Get reward distribution statistics */
580
+ getRewardStats(runId: string, stepRange?: Array<number> | undefined | null): Promise<RewardStats>;
581
+ /** Export to JSONL file */
582
+ exportJsonl(runId: string, outputPath: string, includeToolCalls?: boolean | undefined | null): Promise<number>;
583
+ /** Execute raw SQL query (for advanced users) */
584
+ queryRaw(sql: string): Promise<string>;
585
+ }
586
+
587
+ /**
588
+ * Qwen3 Model with automatic differentiation support
589
+ *
590
+ * Uses interior mutability (RwLock) for layers, final_norm, and lm_head
591
+ * to allow gradient application without deep cloning the model.
592
+ * This eliminates the previous ~4GB memory overhead from clone_for_session().
593
+ */
594
+ export declare class Qwen3Model {
595
+ /** Create a new Qwen3 model with the given configuration */
596
+ constructor(config: Qwen3Config);
597
+ /**
598
+ * Forward pass through the model
599
+ *
600
+ * # Arguments
601
+ * * `input_ids` - Token IDs, shape: [batch_size, seq_len]
602
+ *
603
+ * # Returns
604
+ * * Logits, shape: [batch_size, seq_len, vocab_size]
605
+ */
606
+ forward(inputIds: MxArray): MxArray;
607
+ /**
608
+ * Initialize KV caches for incremental generation
609
+ *
610
+ * Creates one KV cache per transformer layer. Call this before starting generation.
611
+ */
612
+ initKvCaches(): void;
613
+ /**
614
+ * Reset all KV caches
615
+ *
616
+ * Clears cached key-value states. Call this between different generation sequences.
617
+ */
618
+ resetKvCaches(): void;
619
+ /** Check if paged attention is enabled for this model */
620
+ hasPagedAttention(): boolean;
621
+ /**
622
+ * Get paged attention memory statistics (if enabled)
623
+ *
624
+ * Returns memory usage statistics for the paged KV cache.
625
+ */
626
+ pagedCacheStats(): PagedCacheStats | null;
627
+ /**
628
+ * Get scheduler statistics (if paged attention is enabled)
629
+ *
630
+ * Returns the number of waiting, running, and completed sequences.
631
+ */
632
+ schedulerStats(): SchedulerStatsNapi | null;
633
+ /**
634
+ * Forward pass with KV caching for incremental generation
635
+ *
636
+ * # Arguments
637
+ * * `input_ids` - Token IDs, shape: [batch_size, seq_len]
638
+ * * `use_cache` - Whether to use KV caching (must call init_kv_caches() first)
639
+ *
640
+ * # Returns
641
+ * * Logits, shape: [batch_size, seq_len, vocab_size]
642
+ */
643
+ forwardWithCache(inputIds: MxArray, useCache: boolean): MxArray;
644
+ /**
645
+ * Forward pass with paged attention for memory-efficient inference.
646
+ *
647
+ * This method uses block-based KV cache management via Metal kernels for:
648
+ * - Variable-length sequences with efficient memory usage
649
+ * - Continuous batching with dynamic batch composition
650
+ * - Long context support beyond GPU memory limits
651
+ *
652
+ * # Arguments
653
+ * * `input_ids` - Token IDs, shape: [num_seqs, 1] for decode
654
+ * * `slot_mapping` - Slot indices for cache updates, shape: [num_seqs]
655
+ * * `seq_ids` - Sequence IDs in the batch (for looking up block tables/context lens)
656
+ * * `positions` - Token positions for RoPE, shape: [num_seqs] (per-sequence positions)
657
+ *
658
+ * # Returns
659
+ * * Logits, shape: [num_seqs, 1, vocab_size] for decode
660
+ */
661
+ forwardPaged(inputIds: MxArray, slotMapping: MxArray, seqIds: Array<number>, positions: MxArray): MxArray;
662
+ /**
663
+ * Prefill a sequence using standard attention and write K/V to paged cache.
664
+ *
665
+ * This method should be called before `step_paged_generation()` for each
666
+ * new prompt. It runs the full forward pass using standard attention
667
+ * (which is faster for long sequences), then writes the K/V cache to
668
+ * the paged cache for subsequent decode steps.
669
+ *
670
+ * # Arguments
671
+ * * `prompt_tokens` - Token IDs for the prompt (as u32 array)
672
+ * * `seq_id` - Sequence ID (obtained from scheduler)
673
+ *
674
+ * # Returns
675
+ * * Logits for the last token, shape: [1, vocab_size]
676
+ */
677
+ prefillPaged(promptTokens: Array<number>, seqId: number): MxArray;
678
+ /**
679
+ * Add a request to the paged attention scheduler.
680
+ *
681
+ * The scheduler queues requests and allocates blocks for KV cache.
682
+ * Use `step_paged_generation()` to process the scheduled batch.
683
+ *
684
+ * Note: The actual sequence ID is assigned during scheduling, not when the
685
+ * request is added. Use the `request_id` to track your requests through
686
+ * the generation process.
687
+ *
688
+ * # Arguments
689
+ * * `request_id` - Unique identifier for the request (returned in outputs)
690
+ * * `prompt_tokens` - Token IDs for the prompt
691
+ * * `max_new_tokens` - Maximum new tokens to generate
692
+ * * `priority` - Optional priority (higher = scheduled first)
693
+ *
694
+ * # Returns
695
+ * * Number of pending requests in the queue
696
+ */
697
+ addPagedRequest(
698
+ requestId: string,
699
+ promptTokens: Array<number>,
700
+ maxNewTokens: number,
701
+ priority?: number | undefined | null,
702
+ ): number;
703
+ /**
704
+ * Schedule and execute one step of paged generation.
705
+ *
706
+ * This method:
707
+ * 1. Schedules the next batch of sequences
708
+ * 2. Runs forward pass with paged attention
709
+ * 3. Samples next tokens
710
+ * 4. Returns the generated tokens for each sequence
711
+ *
712
+ * # Arguments
713
+ * * `config` - Generation configuration (temperature, top_k, etc.)
714
+ *
715
+ * # Returns
716
+ * * `PagedGenerationStep` with token outputs for each sequence
717
+ */
718
+ stepPagedGeneration(config?: GenerationConfig | undefined | null): PagedGenerationStep | null;
719
+ /**
720
+ * Get completed sequences from the scheduler.
721
+ *
722
+ * Call this after `step_paged_generation()` returns outputs with `is_finished: true`.
723
+ */
724
+ getCompletedSequences(): Array<PagedCompletedSequence>;
725
+ /** Check if the scheduler has pending work. */
726
+ hasPagedWork(): boolean;
727
+ /** Get model configuration */
728
+ getConfig(): Qwen3Config;
729
+ /**
730
+ * Generate tokens using speculative decoding with a draft model.
731
+ *
732
+ * Speculative decoding uses a smaller draft model to generate tokens speculatively,
733
+ * then verifies them with the target model in a single forward pass. This can achieve
734
+ * 2-3x speedup when the draft model has high acceptance rate.
735
+ *
736
+ * # Algorithm
737
+ * 1. Draft model generates N tokens speculatively (cheap forward passes)
738
+ * 2. Target model (self) verifies all N tokens in one forward pass
739
+ * 3. Accept/reject using rejection sampling
740
+ * 4. On rejection, resample from adjusted distribution
741
+ * 5. Rewind caches and continue
742
+ *
743
+ * # Arguments
744
+ * * `draft_model` - Smaller model for speculative generation (should share tokenizer)
745
+ * * `input_ids` - Input token IDs [1, seq_len]
746
+ * * `config` - Generation configuration (includes num_draft_tokens)
747
+ *
748
+ * # Returns
749
+ * GenerationResult with tokens, logprobs, and speculative stats in finish_reason
750
+ *
751
+ * # Example (TypeScript)
752
+ * ```typescript
753
+ * const targetModel = await ModelLoader.loadPretrained('qwen3-7b');
754
+ * const draftModel = await ModelLoader.loadPretrained('qwen3-0.5b');
755
+ *
756
+ * const result = targetModel.generateSpeculativeSync(draftModel, inputIds, {
757
+ * numDraftTokens: 5,
758
+ * maxNewTokens: 100,
759
+ * temperature: 0.7,
760
+ * });
761
+ * ```
762
+ */
763
+ generateSpeculativeSync(
764
+ draftModel: Qwen3Model,
765
+ inputIds: MxArray,
766
+ config?: GenerationConfig | undefined | null,
767
+ ): GenerationResult;
768
+ /** Count total number of parameters in the model */
769
+ numParameters(): number;
770
+ /**
771
+ * Get all model parameters as a dictionary mapping names to arrays
772
+ *
773
+ * This matches the TypeScript API for compatibility
774
+ */
775
+ getParameters(): Record<string, MxArray>;
776
+ /** Load parameters from a dictionary */
777
+ loadParameters(params: Record<string, MxArray>): void;
778
+ /**
779
+ * Compute forward pass and loss (for evaluation)
780
+ *
781
+ * # Arguments
782
+ * * `input_ids` - Input token IDs, shape: [batch_size, seq_len]
783
+ * * `labels` - Target token IDs, shape: [batch_size, seq_len]
784
+ *
785
+ * # Returns
786
+ * * Scalar loss value
787
+ */
788
+ computeLoss(inputIds: MxArray, labels: MxArray): MxArray;
789
+ /**
790
+ * Compute loss and gradients using a hybrid approach
791
+ *
792
+ * This implementation computes gradients for the output layers and uses
793
+ * numerical approximations for other parameters. This is sufficient to
794
+ * demonstrate that training works while we build out full MLX autograd integration.
795
+ *
796
+ * # Arguments
797
+ * * `input_ids` - Input token IDs, shape: [batch_size, seq_len]
798
+ * * `labels` - Target token IDs, shape: [batch_size, seq_len]
799
+ *
800
+ * # Returns
801
+ * * A tuple of (loss, gradients_dict) where gradients_dict maps parameter names to gradient arrays
802
+ *
803
+ * # Phase 6A Status
804
+ * Current implementation computes:
805
+ * - ✅ Exact gradients for LM head (output layer)
806
+ * - ⚠️ Numerical approximations for other layers
807
+ *
808
+ * Future: Full MLX autograd will compute exact gradients for all 250+ parameters
809
+ */
810
+ computeLossAndGradients(inputIds: MxArray, labels: MxArray): [MxArray, Record<string, MxArray>];
811
+ /**
812
+ * Complete GRPO training step using MLX Autograd (RECOMMENDED)
813
+ *
814
+ * This method uses automatic differentiation to compute gradients, eliminating
815
+ * the need for manual backward pass implementation. This is the preferred approach.
816
+ *
817
+ * # Arguments
818
+ * * `prompt_tokens` - Prompt token sequences [batch_size, seq_len] (1D arrays)
819
+ * * `completion_tokens` - Completion sequences [batch*G, completion_len] (1D arrays)
820
+ * * `completion_logprobs` - Logprobs from generation [batch*G, completion_len] (1D arrays)
821
+ * * `rewards` - Reward scores for each completion [batch*G]
822
+ * * `group_size` - Number of completions per prompt (G)
823
+ * * `config` - GRPO loss configuration
824
+ * * `learning_rate` - Learning rate for parameter updates
825
+ *
826
+ * # Returns
827
+ * * Tuple of (loss_value, metrics_dict)
828
+ */
829
+ trainStepGrpoAutograd(
830
+ promptTokens: Array<MxArray>,
831
+ completionTokens: Array<MxArray>,
832
+ completionLogprobs: Array<MxArray>,
833
+ rewards: Float64Array,
834
+ groupSize: number,
835
+ config: GrpoLossConfig,
836
+ learningRate: number,
837
+ ): [number, Record<string, number>];
838
+ /**
839
+ * Compute gradients only without applying them (for gradient accumulation)
840
+ *
841
+ * This method computes GRPO loss and gradients but does NOT update parameters.
842
+ * Used for gradient accumulation where gradients are summed across multiple
843
+ * micro-batches before applying them.
844
+ *
845
+ * # Arguments
846
+ * * `prompt_tokens` - Prompt token sequences [batch_size, seq_len] (1D arrays)
847
+ * * `completion_tokens` - Completion sequences [batch*G, completion_len] (1D arrays)
848
+ * * `completion_logprobs` - Logprobs from generation [batch*G, completion_len] (1D arrays)
849
+ * * `rewards` - Reward scores for each completion [batch*G]
850
+ * * `group_size` - Number of completions per prompt (G)
851
+ * * `config` - GRPO loss configuration
852
+ *
853
+ * # Returns
854
+ * * Tuple of (loss_value, gradients_dict, metrics_dict)
855
+ */
856
+ computeGradientsOnlyGrpoAutograd(
857
+ promptTokens: Array<MxArray>,
858
+ completionTokens: Array<MxArray>,
859
+ completionLogprobs: Array<MxArray>,
860
+ rewards: Float64Array,
861
+ groupSize: number,
862
+ config: GrpoLossConfig,
863
+ ): [number, Record<string, MxArray>, Record<string, number>];
864
+ /**
865
+ * Accumulate gradients into existing gradient dictionary
866
+ *
867
+ * This is a helper method for gradient accumulation. It adds new_gradients
868
+ * to accumulated_gradients element-wise.
869
+ *
870
+ * # Arguments
871
+ * * `accumulated_gradients` - Existing accumulated gradients (will be modified in-place conceptually, but returns new dict)
872
+ * * `new_gradients` - New gradients to add
873
+ *
874
+ * # Returns
875
+ * * Updated gradient dictionary with accumulated values
876
+ */
877
+ static accumulateGradients(
878
+ accumulatedGradients: Record<string, MxArray>,
879
+ newGradients: Record<string, MxArray>,
880
+ ): Record<string, MxArray>;
881
+ /**
882
+ * Complete GRPO training step using manual gradients (Legacy)
883
+ *
884
+ * This method performs a full GRPO training iteration:
885
+ * 1. Takes completions (already generated) with their logprobs and rewards
886
+ * 2. Computes advantages
887
+ * 3. Computes GRPO loss and gradients
888
+ * 4. Updates model parameters
889
+ *
890
+ * NOTE: Use train_step_grpo_autograd instead for automatic differentiation.
891
+ *
892
+ * # Arguments
893
+ * * `prompt_tokens` - Prompt token sequences [batch_size, seq_len] (1D arrays)
894
+ * * `completion_tokens` - Completion sequences [batch*G, completion_len] (1D arrays)
895
+ * * `completion_logprobs` - Logprobs from generation [batch*G, completion_len] (1D arrays)
896
+ * * `rewards` - Reward scores for each completion [batch*G]
897
+ * * `group_size` - Number of completions per prompt (G)
898
+ * * `config` - GRPO loss configuration
899
+ * * `learning_rate` - Learning rate for parameter updates
900
+ *
901
+ * # Returns
902
+ * * Tuple of (loss_value, metrics_dict)
903
+ */
904
+ trainStepGrpo(
905
+ promptTokens: Array<MxArray>,
906
+ completionTokens: Array<MxArray>,
907
+ completionLogprobs: Array<MxArray>,
908
+ rewards: Float64Array,
909
+ groupSize: number,
910
+ config: GrpoLossConfig,
911
+ learningRate: number,
912
+ ): [number, Record<string, number>];
913
+ /**
914
+ * Apply gradients to model parameters
915
+ *
916
+ * # Arguments
917
+ * * `gradients` - Dictionary mapping parameter names to gradient arrays
918
+ * * `learning_rate` - Learning rate for gradient descent
919
+ *
920
+ * This performs a simple SGD update: param = param - lr * grad
921
+ * Only updates parameters that have gradients; others remain unchanged.
922
+ *
923
+ * IMPORTANT: This function preserves the original dtype of parameters.
924
+ * The learning rate scalar is cast to match param dtype to prevent
925
+ * promotion to float32 during arithmetic operations.
926
+ */
927
+ applyGradients(gradients: Record<string, MxArray>, learningRate: number): void;
928
+ /**
929
+ * Text-to-text generation with integrated tokenization
930
+ *
931
+ * This is a high-level API that handles chat template formatting, tokenization,
932
+ * generation, and decoding internally. It takes chat messages, applies the ChatML
933
+ * template, generates tokens, and decodes them back to text.
934
+ *
935
+ * # Arguments
936
+ * * `messages` - Array of chat messages with role and content
937
+ * * `config` - Generation configuration
938
+ *
939
+ * # Returns
940
+ * * GenerationResult with text, tokens, logprobs, finish reason, and token count
941
+ *
942
+ * # Example
943
+ * ```typescript
944
+ * const model = await Qwen3Model.loadPretrained("path/to/model");
945
+ * const messages = [
946
+ * { role: "user", content: "What is 2+2?" }
947
+ * ];
948
+ * const result = await model.generate(messages, {
949
+ * maxNewTokens: 50,
950
+ * temperature: 0.8,
951
+ * topP: 0.95,
952
+ * });
953
+ * console.log(result.text); // Decoded text output
954
+ * console.log(result.tokens); // Token IDs (for GRPO)
955
+ * console.log(result.logprobs); // Log probabilities (for GRPO)
956
+ * ```
957
+ */
958
+ generate(messages: Array<ChatMessage>, config?: GenerationConfig | undefined | null): Promise<GenerationResult>;
959
+ /**
960
+ * High-level chat API with structured response parsing
961
+ *
962
+ * The primary API for conversational AI. Handles:
963
+ * - Chat message formatting with Jinja2 templates
964
+ * - Tool/function calling with structured output
965
+ * - Thinking extraction from `<think>` tags
966
+ * - Clean response text with all special tags stripped
967
+ *
968
+ * ## `chat()` vs `generate()`
969
+ *
970
+ * | Feature | `chat()` | `generate()` |
971
+ * |---------|----------|--------------|
972
+ * | **Purpose** | Conversational AI with tools | Raw text generation |
973
+ * | **Input** | Chat messages | Token IDs (MxArray) |
974
+ * | **Tool Support** | Built-in parsing | None |
975
+ * | **Thinking** | Extracts `<think>` content | Raw text only |
976
+ * | **Output** | Structured `ChatResult` | Basic `GenerationResult` |
977
+ * | **Use Case** | Chat apps, agents, assistants | Training, low-level control |
978
+ *
979
+ * ## When to use `chat()`
980
+ * - Building conversational applications
981
+ * - Need tool/function calling
982
+ * - Want structured responses with thinking separated
983
+ * - Working with chat message format
984
+ *
985
+ * ## When to use `generate()`
986
+ * - Training and fine-tuning (need raw logprobs)
987
+ * - Custom tokenization pipeline
988
+ * - Low-level generation control
989
+ * - Non-chat use cases
990
+ *
991
+ * # Arguments
992
+ * * `messages` - Array of chat messages (user/assistant/system roles)
993
+ * * `config` - Chat configuration including optional tools and generation params
994
+ *
995
+ * # Returns
996
+ * * `ChatResult` containing:
997
+ * - `text`: Clean response (tool_call and think tags stripped)
998
+ * - `thinking`: Extracted chain-of-thought reasoning (or null)
999
+ * - `toolCalls`: Parsed tool calls with native JS object arguments
1000
+ * - `finishReason`: "stop" | "length" | "tool_calls"
1001
+ * - `rawText`: Original text before processing (for debugging)
1002
+ *
1003
+ * # Example
1004
+ * ```typescript
1005
+ * // Simple chat
1006
+ * const result = await model.chat(messages);
1007
+ * console.log(result.text);
1008
+ *
1009
+ * // With tools
1010
+ * const result = await model.chat(messages, {
1011
+ * tools: [{ type: 'function', function: { name: 'get_weather' } }],
1012
+ * maxNewTokens: 2048,
1013
+ * temperature: 0.7,
1014
+ * });
1015
+ *
1016
+ * // Handle tool calls
1017
+ * for (const call of result.toolCalls) {
1018
+ * if (call.status === 'ok') {
1019
+ * console.log(call.name, call.arguments); // Arguments is a JS object!
1020
+ * }
1021
+ * }
1022
+ *
1023
+ * // Access thinking (chain-of-thought)
1024
+ * if (result.thinking) {
1025
+ * console.log('Model reasoning:', result.thinking);
1026
+ * }
1027
+ * ```
1028
+ */
1029
+ chat(messages: Array<ChatMessage>, config?: ChatConfig | undefined | null): Promise<ChatResult>;
1030
+ /**
1031
+ * Generate multiple completions for multiple prompts in batch
1032
+ *
1033
+ * This is an optimized method for GRPO training that generates G completions
1034
+ * for each of N prompts. It performs all tokenization, generation, and decoding
1035
+ * in 3 blocking tasks instead of N*(1+2G) tasks.
1036
+ *
1037
+ * # Arguments
1038
+ * * `prompts` - Array of N prompt message arrays
1039
+ * * `group_size` - Number of completions (G) to generate per prompt
1040
+ * * `config` - Generation configuration (sampling params, etc.)
1041
+ *
1042
+ * # Returns
1043
+ * * BatchGenerationResult containing N*G completions with:
1044
+ * - tokens: Flat array of N*G token arrays
1045
+ * - logprobs: Flat array of N*G logprob arrays
1046
+ * - texts: Flat array of N*G decoded texts
1047
+ * - finish_reasons: N arrays of G finish reasons
1048
+ * - token_counts: N arrays of G token counts
1049
+ *
1050
+ * # Performance
1051
+ * For N=10 prompts, G=8 completions:
1052
+ * - Old approach: N*(1 tokenize + G generate + G decode) = 10*(1+8+8) = 170 blocking tasks
1053
+ * - New approach: 1 tokenize + N*G generate + 1 decode = 1+80+1 = 82 blocking tasks (2.1x reduction)
1054
+ *
1055
+ * # Example
1056
+ * ```typescript
1057
+ * const result = await model.generateBatch(
1058
+ * [messages1, messages2, ...], // N prompts
1059
+ * 8, // G completions per prompt
1060
+ * config
1061
+ * );
1062
+ * ```
1063
+ */
1064
+ generateBatch(
1065
+ prompts: Array<Array<ChatMessage>>,
1066
+ groupSize: number,
1067
+ config?: GenerationConfig | undefined | null,
1068
+ ): Promise<BatchGenerationResult>;
1069
+ /**
1070
+ * Decode token IDs to text using the internal tokenizer
1071
+ *
1072
+ * Helper method for decoding generated tokens. The model must have been loaded
1073
+ * via load_pretrained() to have a tokenizer available.
1074
+ *
1075
+ * # Arguments
1076
+ * * `token_ids` - Token IDs to decode as Uint32Array
1077
+ * * `skip_special_tokens` - Whether to skip special tokens (default: true)
1078
+ *
1079
+ * # Returns
1080
+ * * Decoded text string
1081
+ */
1082
+ decode(tokenIds: Uint32Array, skipSpecialTokens?: boolean | undefined | null): Promise<string>;
1083
+ /**
1084
+ * Apply chat template and encode to token IDs
1085
+ *
1086
+ * Formats messages using ChatML format (or Jinja2 template with tools) and encodes to tokens.
1087
+ * The model must have been loaded via load_pretrained() to have a tokenizer available.
1088
+ *
1089
+ * # Arguments
1090
+ * * `messages` - Array of chat messages
1091
+ * * `add_generation_prompt` - Whether to add generation prompt (default: true)
1092
+ * * `tools` - Optional array of tool definitions for function calling
1093
+ * * `enable_thinking` - Optional flag to enable thinking mode (<think> tags)
1094
+ *
1095
+ * # Returns
1096
+ * * Encoded token IDs as Uint32Array
1097
+ */
1098
+ applyChatTemplate(
1099
+ messages: Array<ChatMessage>,
1100
+ addGenerationPrompt?: boolean | undefined | null,
1101
+ tools?: Array<ToolDefinition> | undefined | null,
1102
+ enableThinking?: boolean | undefined | null,
1103
+ ): Promise<Uint32Array>;
1104
+ /**
1105
+ * Load a pretrained model from disk
1106
+ *
1107
+ * This loads a model from a directory containing:
1108
+ * - config.json: Model configuration
1109
+ * - weights.mlx (optional): MLX format weights with data arrays
1110
+ * - weights.safetensors (optional): SafeTensors format (not yet supported)
1111
+ *
1112
+ * # Arguments
1113
+ * * `model_path` - Path to the model directory
1114
+ *
1115
+ * # Returns
1116
+ * * A fully initialized Qwen3Model with loaded weights
1117
+ */
1118
+ static loadPretrained(modelPath: string): Promise<Qwen3Model>;
1119
+ /**
1120
+ * Save model configuration and weights to disk
1121
+ *
1122
+ * This saves:
1123
+ * - config.json: Model configuration
1124
+ * - weights.safetensors: Full model weights in SafeTensors format
1125
+ * - weights.mlx: Parameter metadata (for reference)
1126
+ *
1127
+ * # Arguments
1128
+ * * `save_path` - Directory to save the model
1129
+ */
1130
+ saveModel(savePath: string): Promise<undefined>;
1131
+ /**
1132
+ * Validate that a set of parameters has all required weights with correct shapes
1133
+ *
1134
+ * This is useful for validating parameters before loading them into a model,
1135
+ * or for checking that saved weights are valid before training.
1136
+ *
1137
+ * # Arguments
1138
+ * * `params` - HashMap of parameter names to MxArray values
1139
+ *
1140
+ * # Returns
1141
+ * * Ok(()) if all validations pass
1142
+ * * Err with descriptive message if validation fails
1143
+ */
1144
+ validateParameters(params: Record<string, MxArray>): void;
1145
+ }
1146
+
1147
+ /** Qwen3 Tokenizer class with NAPI bindings */
1148
+ export declare class Qwen3Tokenizer {
1149
+ /**
1150
+ * Load tokenizer from tokenizer.json file
1151
+ *
1152
+ * # Arguments
1153
+ * `tokenizer_path` - Path to a tokenizer.json file, e.g. "../.cache/assets/tokenizers/qwen3_tokenizer.json" (required — the parameter has no default)
1154
+ *
1155
+ * # Example
1156
+ * ```typescript
1157
+ * const tokenizer = await Qwen3Tokenizer.fromPretrained('./qwen3_tokenizer.json');
1158
+ * const tokens = await tokenizer.encode("Hello, world!");
1159
+ * ```
1160
+ */
1161
+ static fromPretrained(tokenizerPath: string): Promise<Qwen3Tokenizer>;
1162
+ /**
1163
+ * Encode text to token IDs
1164
+ *
1165
+ * # Arguments
1166
+ * * `text` - Text to encode
1167
+ * * `add_special_tokens` - Whether to add special tokens (default: true)
1168
+ *
1169
+ * # Returns
1170
+ * Array of token IDs as Uint32Array
1171
+ *
1172
+ * # Example
1173
+ * ```typescript
1174
+ * const tokens = await tokenizer.encode("Hello, world!");
1175
+ * console.log(tokens); // Uint32Array [9906, 11, 1879, 0] (IDs are illustrative)
1176
+ * ```
1177
+ */
1178
+ encode(text: string, addSpecialTokens?: boolean | undefined | null): Promise<Uint32Array>;
1179
+ /**
1180
+ * Encode multiple texts in batch
1181
+ *
1182
+ * # Arguments
1183
+ * * `texts` - Array of texts to encode
1184
+ * * `add_special_tokens` - Whether to add special tokens (default: true)
1185
+ *
1186
+ * # Returns
1187
+ * Array of Uint32Arrays, one for each text
1188
+ */
1189
+ encodeBatch(texts: Array<string>, addSpecialTokens?: boolean | undefined | null): Promise<Array<Uint32Array>>;
1190
+ /**
1191
+ * Decode token IDs to text
1192
+ *
1193
+ * # Arguments
1194
+ * * `token_ids` - Token IDs to decode
1195
+ * * `skip_special_tokens` - Whether to skip special tokens (default: true)
1196
+ *
1197
+ * # Returns
1198
+ * Decoded text string
1199
+ *
1200
+ * # Example
1201
+ * ```typescript
1202
+ * const text = await tokenizer.decode(new Uint32Array([9906, 11, 1879, 0]));
1203
+ * console.log(text); // "Hello, world!"
1204
+ * ```
1205
+ */
1206
+ decode(tokenIds: Uint32Array, skipSpecialTokens?: boolean | undefined | null): Promise<string>;
1207
+ /**
1208
+ * Decode multiple token sequences in batch
1209
+ *
1210
+ * # Arguments
1211
+ * * `token_ids_batch` - Array of token ID arrays to decode
1212
+ * * `skip_special_tokens` - Whether to skip special tokens (default: true)
1213
+ *
1214
+ * # Returns
1215
+ * Array of decoded text strings
1216
+ */
1217
+ decodeBatch(
1218
+ tokenIdsBatch: Array<Uint32Array>,
1219
+ skipSpecialTokens?: boolean | undefined | null,
1220
+ ): Promise<Array<string>>;
1221
+ /**
1222
+ * Apply chat template to messages and encode
1223
+ *
1224
+ * Supports both simple ChatML format and full Jinja2 template rendering with tools.
1225
+ * When tools are provided or a chat template exists, uses Jinja2 rendering.
1226
+ * Otherwise falls back to simple ChatML format.
1227
+ *
1228
+ * # Arguments
1229
+ * * `messages` - Array of chat messages
1230
+ * * `add_generation_prompt` - Whether to add assistant prompt at end (default: true)
1231
+ * * `tools` - Optional array of tool definitions for function calling
1232
+ * * `enable_thinking` - Optional flag to enable thinking mode (<think> tags)
1233
+ *
1234
+ * # Returns
1235
+ * Encoded token IDs ready for model input
1236
+ *
1237
+ * # Example
1238
+ * ```typescript
1239
+ * const messages = [
1240
+ * { role: "system", content: "You are a helpful assistant." },
1241
+ * { role: "user", content: "What is 2+2?" }
1242
+ * ];
1243
+ * const tokens = await tokenizer.applyChatTemplate(messages, true);
1244
+ *
1245
+ * // With tools
1246
+ * const tools = [{
1247
+ * type: "function",
1248
+ * function: { name: "get_weather", description: "Get weather info" }
1249
+ * }];
1250
+ * const tokensWithTools = await tokenizer.applyChatTemplate(messages, true, tools);
1251
+ * ```
1252
+ */
1253
+ applyChatTemplate(
1254
+ messages: Array<ChatMessage>,
1255
+ addGenerationPrompt?: boolean | undefined | null,
1256
+ tools?: Array<ToolDefinition> | undefined | null,
1257
+ enableThinking?: boolean | undefined | null,
1258
+ ): Promise<Uint32Array>;
1259
+ /** Get vocabulary size */
1260
+ vocabSize(): number;
1261
+ /** Get PAD token ID */
1262
+ getPadTokenId(): number;
1263
+ /** Get EOS token ID */
1264
+ getEosTokenId(): number;
1265
+ /** Get BOS token ID (if exists) */
1266
+ getBosTokenId(): number | null;
1267
+ /** Convert token ID to string */
1268
+ idToToken(id: number): string | null;
1269
+ /** Convert token string to ID */
1270
+ tokenToId(token: string): number | null;
1271
+ /** Get the special token for IM_START */
1272
+ getImStartToken(): string;
1273
+ /** Get the special token for IM_END */
1274
+ getImEndToken(): string;
1275
+ /** Get the special token for ENDOFTEXT (used as PAD) */
1276
+ getEndoftextToken(): string;
1277
+ }
1278
+
1279
+ /** SFT Training Engine */
1280
+ export declare class SftTrainingEngine {
1281
+ /** Create a new SFT training engine */
1282
+ constructor(model: Qwen3Model, config: SftEngineConfig);
1283
+ /** Run a single training step; resolves with this step's SftStepMetrics */
1284
+ trainStep(inputIds: MxArray, labels: MxArray): Promise<SftStepMetrics>;
1285
+ /** Get current step number */
1286
+ getStep(): number;
1287
+ /** Get current epoch */
1288
+ getEpoch(): number;
1289
+ /**
1290
+ * Flush any accumulated gradients at epoch end
1291
+ *
1292
+ * When stepsPerEpoch % gradient_accumulation_steps != 0, there may be
1293
+ * leftover gradients from the final micro-batches. This method applies
1294
+ * them with proper averaging, matching TRL behavior. NOTE(review): presumably returns true when leftover gradients existed and were applied — confirm against the native implementation.
1295
+ */
1296
+ flushGradients(): boolean;
1297
+ /**
1298
+ * Compute the resume position given current state and dataset info
1299
+ *
1300
+ * This centralizes all resume logic in Rust for correctness.
1301
+ * Uses i64 math internally to avoid overflow on long runs.
1302
+ */
1303
+ computeResumePosition(stepsPerEpoch: number): ResumePosition;
1304
+ /** Check if emergency save is needed */
1305
+ needsEmergencySave(): boolean;
1306
+ /** Clear emergency save flag */
1307
+ clearEmergencySave(): void;
1308
+ /**
1309
+ * Signal start of a new epoch
1310
+ *
1311
+ * Takes the epoch number directly from TypeScript to ensure synchronization.
1312
+ * The epoch is 0-indexed to match the TypeScript training loop.
1313
+ */
1314
+ startEpoch(epoch: number): void;
1315
+ /** End current epoch and return metrics */
1316
+ endEpoch(epochTimeSecs: number): SftEpochMetrics;
1317
+ /** Reset training state (for new training run) */
1318
+ reset(): void;
1319
+ /** Restore training state (for resuming from checkpoint) */
1320
+ restoreState(step: number, epoch: number): void;
1321
+ /** Get the underlying model for checkpointing */
1322
+ getModel(): Qwen3Model;
1323
+ }
1324
+
1325
+ /**
1326
+ * A tensor that tracks gradients for automatic differentiation
1327
+ *
1328
+ * This is a wrapper around MxArray that provides:
1329
+ * - Gradient tracking
1330
+ * - Automatic gradient accumulation
1331
+ * - Integration with manual backward passes
1332
+ */
1333
+ export declare class Tensor {
1334
+ /** Create a tensor from float32 data */
1335
+ static fromFloat32(data: Float32Array, shape: BigInt64Array, requiresGrad?: boolean | undefined | null): Tensor;
1336
+ /** Create a tensor from int32 data */
1337
+ static fromInt32(data: Int32Array, shape: BigInt64Array, requiresGrad?: boolean | undefined | null): Tensor;
1338
+ /** Get the shape of the underlying data */
1339
+ dataShape(): BigInt64Array;
1340
+ /** Get the shape of the gradient (if it exists) */
1341
+ gradShape(): BigInt64Array | null;
1342
+ /** Check if gradient exists */
1343
+ hasGrad(): boolean;
1344
+ /** Check if this tensor requires gradients */
1345
+ get requiresGrad(): boolean;
1346
+ /** Set whether this tensor requires gradients */
1347
+ set requiresGrad(requiresGrad: boolean);
1348
+ /** Zero out the gradient */
1349
+ zeroGrad(): void;
1350
+ /**
1351
+ * Accumulate gradient
1352
+ *
1353
+ * If gradient already exists, add to it. Otherwise, set it.
1354
+ * Note: This takes ownership of the gradient array.
1355
+ */
1356
+ accumulateGrad(grad: MxArray): void;
1357
+ /** Get the shape of the tensor */
1358
+ shape(): BigInt64Array;
1359
+ /** Convert data to Float32 array */
1360
+ toFloat32(): Float32Array;
1361
+ /** Convert gradient to Float32 array (if it exists) */
1362
+ gradToFloat32(): Float32Array | null;
1363
+ /** Convert data to Int32 array */
1364
+ toInt32(): Int32Array;
1365
+ /**
1366
+ * Detach this tensor from the computation graph
1367
+ *
1368
+ * Returns a new tensor with the same data but no gradient tracking
1369
+ */
1370
+ detach(): Tensor;
1371
+ /** Create a tensor of zeros */
1372
+ static zeros(
1373
+ shape: BigInt64Array,
1374
+ dtype?: DType | undefined | null,
1375
+ requiresGrad?: boolean | undefined | null,
1376
+ ): Tensor;
1377
+ /** Create a tensor of ones */
1378
+ static ones(
1379
+ shape: BigInt64Array,
1380
+ dtype?: DType | undefined | null,
1381
+ requiresGrad?: boolean | undefined | null,
1382
+ ): Tensor;
1383
+ /** Evaluate the underlying array (NOTE(review): presumably forces any pending lazy MLX computation — confirm) */
1384
+ eval(): void;
1385
+ }
1386
+
1387
+ /** Result from VLM chat */
1388
+ export declare class VlmChatResult {
1389
+ /** Get the response text */
1390
+ get text(): string;
1391
+ /** Get the generated tokens */
1392
+ get tokens(): MxArray;
1393
+ /** Get the log probabilities */
1394
+ get logprobs(): MxArray;
1395
+ /** Get the finish reason ('stop', 'length', or 'repetition') */
1396
+ get finishReason(): 'stop' | 'length' | 'repetition';
1397
+ /** Get the number of tokens generated */
1398
+ get numTokens(): number;
1399
+ }
1400
+ export type VLMChatResult = VlmChatResult;
1401
+
1402
+ /**
1403
+ * Vision-Language Model
1404
+ *
1405
+ * A generic VLM for OCR and document understanding tasks.
1406
+ * Currently supports PaddleOCR-VL architecture (vision encoder + ERNIE language model).
1407
+ */
1408
+ export declare class VLModel {
1409
+ /** Create a new PaddleOCR-VL model */
1410
+ constructor(config: ModelConfig);
1411
+ /** Set the tokenizer */
1412
+ setTokenizer(tokenizer: Qwen3Tokenizer): void;
1413
+ /** Check if tokenizer is available */
1414
+ get hasTokenizer(): boolean;
1415
+ /**
1416
+ * Chat with the VLM model
1417
+ *
1418
+ * High-level API for conversational interaction with images.
1419
+ *
1420
+ * # Arguments
1421
+ * * `messages` - Chat messages (role + content)
1422
+ * * `config` - Chat configuration (including image_paths for automatic processing)
1423
+ *
1424
+ * # Returns
1425
+ * * VLMChatResult with generated text
1426
+ *
1427
+ * # Example
1428
+ * ```typescript
1429
+ * const result = model.chat(
1430
+ * [{ role: 'user', content: 'Describe this image.' }],
1431
+ * { imagePaths: ['./photo.jpg'], maxNewTokens: 256 }
1432
+ * );
1433
+ * ```
1434
+ */
1435
+ chat(messages: Array<VlmChatMessage>, config?: VlmChatConfig | undefined | null): VlmChatResult;
1436
+ /**
1437
+ * Simple OCR: extract text from an image file
1438
+ *
1439
+ * Convenience method that processes an image and extracts all text.
1440
+ *
1441
+ * # Arguments
1442
+ * * `image_path` - Path to the image file
1443
+ * * `prompt` - Optional custom prompt (default: "Extract all text from this image.")
1444
+ *
1445
+ * # Returns
1446
+ * * Extracted text as a string (this call is synchronous — it returns string, not a Promise)
1447
+ *
1448
+ * # Example
1449
+ * ```typescript
1450
+ * const text = model.ocr('./receipt.jpg');
1451
+ * console.log(text);
1452
+ * ```
1453
+ */
1454
+ ocr(imagePath: string, prompt?: string | undefined | null): string;
1455
+ /**
1456
+ * Get input embeddings with vision features merged
1457
+ *
1458
+ * # Arguments
1459
+ * * `input_ids` - Token IDs [batch, seq_len]
1460
+ * * `pixel_values` - Optional image patches [batch, seq, channels, patch_h, patch_w]
1461
+ * * `image_grid_thw` - Optional grid dimensions [num_images, 3]
1462
+ *
1463
+ * # Returns
1464
+ * * Input embeddings with vision features inserted at image token positions
1465
+ */
1466
+ getInputEmbeddings(
1467
+ inputIds: MxArray,
1468
+ pixelValues?: MxArray | undefined | null,
1469
+ imageGridThw?: MxArray | undefined | null,
1470
+ ): MxArray;
1471
+ /**
1472
+ * Forward pass
1473
+ *
1474
+ * # Arguments
1475
+ * * `input_ids` - Token IDs [batch, seq_len]
1476
+ * * `pixel_values` - Optional image patches
1477
+ * * `image_grid_thw` - Optional grid dimensions
1478
+ * * `mask` - Optional attention mask
1479
+ *
1480
+ * # Returns
1481
+ * * Logits [batch, seq_len, vocab_size]
1482
+ */
1483
+ forward(
1484
+ inputIds: MxArray,
1485
+ pixelValues?: MxArray | undefined | null,
1486
+ imageGridThw?: MxArray | undefined | null,
1487
+ mask?: MxArray | undefined | null,
1488
+ ): MxArray;
1489
+ /**
1490
+ * Generate text tokens given input tokens and optional image
1491
+ *
1492
+ * Uses KV caching for efficient generation - each step only processes the
1493
+ * new token(s) while reusing cached key-value states from previous tokens.
1494
+ * Vision features are computed once at the start and cached.
1495
+ *
1496
+ * # Arguments
1497
+ * * `input_ids` - Input token IDs [1, seq_len]
1498
+ * * `pixel_values` - Optional image patches [1, num_patches, C, H, W]
1499
+ * * `image_grid_thw` - Optional grid dimensions [1, 3]
1500
+ * * `config` - Generation configuration
1501
+ *
1502
+ * # Returns
1503
+ * * GenerationResult with tokens, logprobs, and finish reason
1504
+ */
1505
+ generate(
1506
+ inputIds: MxArray,
1507
+ pixelValues?: MxArray | undefined | null,
1508
+ imageGridThw?: MxArray | undefined | null,
1509
+ config?: GenerationConfig | undefined | null,
1510
+ ): GenerationResult;
1511
+ /** Get model configuration */
1512
+ get config(): ModelConfig;
1513
+ /** Check if model is fully initialized */
1514
+ get isInitialized(): boolean;
1515
+ /**
1516
+ * Load a VLM from disk
1517
+ *
1518
+ * Loads a model from a directory containing:
1519
+ * - config.json: Model configuration
1520
+ * - model.safetensors or model-*.safetensors: Model weights in SafeTensors format
1521
+ *
1522
+ * # Arguments
1523
+ * * `model_path` - Path to the model directory
1524
+ *
1525
+ * # Returns
1526
+ * * A fully initialized VLModel with loaded weights
1527
+ *
1528
+ * # Example
1529
+ * ```typescript
1530
+ * import { VLModel } from '@mlx-node/vlm';
1531
+ * const model = await VLModel.load('./models/paddleocr-vl');
1532
+ * const result = model.chat(messages, { imagePaths: ['./image.jpg'] });
1533
+ * ```
1534
+ */
1535
+ static load(modelPath: string): Promise<VLModel>;
1536
+ /**
1537
+ * Load model configuration from disk without loading weights
1538
+ *
1539
+ * This is useful for inspecting model configuration before loading the full model.
1540
+ *
1541
+ * # Arguments
1542
+ * * `model_path` - Path to the model directory containing config.json
1543
+ *
1544
+ * # Returns
1545
+ * * ModelConfig with vision and text configuration
1546
+ *
1547
+ * # Example
1548
+ * ```typescript
1549
+ * import { VLModel } from '@mlx-node/vlm';
1550
+ * const config = await VLModel.loadConfig('./models/paddleocr-vl');
1551
+ * console.log(config.visionConfig.hiddenSize);
1552
+ * ```
1553
+ */
1554
+ static loadConfig(modelPath: string): Promise<ModelConfig>;
1555
+ }
1556
+
1557
+ /**
1558
+ * Build RewardOutput array from generation results.
1559
+ *
1560
+ * Parses tool calls and thinking from completions, creating structured outputs
1561
+ * aligned with the ChatResult structure.
1562
+ *
1563
+ * # Arguments
1564
+ * * `prompts` - Array of prompt texts (one per unique prompt, will be expanded by group_size)
1565
+ * * `completions` - Array of completion texts (prompts.len() * group_size total)
1566
+ * * `token_counts` - Array of token counts for each completion
1567
+ * * `finish_reasons` - Array of finish reasons from generation ("eos", "length", "stop", "repetition")
1568
+ * * `group_size` - Number of completions per prompt
1569
+ *
1570
+ * # Returns
1571
+ * Array of RewardOutput objects with structured completion data
1572
+ *
1573
+ * # Example
1574
+ * ```typescript
1575
+ * import { buildRewardOutputs } from '@mlx-node/core';
1576
+ *
1577
+ * const outputs = buildRewardOutputs(
1578
+ * ['What is 2+2?'], // prompts
1579
+ * ['<think>Let me calculate</think>\n\n4', '4'], // completions (group_size=2)
1580
+ *
1581
+ *
1582
+ * [10, 5], // token counts
1583
+ * ['eos', 'length'], // finish reasons
1584
+ * 2 // group_size
1585
+ * );
1586
+ *
1587
+ * outputs[0].completion.thinking; // "Let me calculate"
1588
+ * outputs[0].completion.text; // "4"
1589
+ * outputs[0].completion.finishReason; // "eos"
1590
+ * ```
1591
+ */
1592
+ export declare function buildRewardOutputs(
1593
+ prompts: Array<string>,
1594
+ completions: Array<string>,
1595
+ tokenCounts: Array<number>,
1596
+ finishReasons: Array<string>,
1597
+ groupSize: number,
1598
+ ): Array<RewardOutput>;
1599
+
1600
+ /** Configuration for built-in rewards */
1601
+ export interface BuiltinRewardConfig {
1602
+ /** Type of reward function */
1603
+ rewardType: BuiltinRewardType;
1604
+ /** Weight for this reward (default 1.0) */
1605
+ weight?: number;
1606
+ /** Allowed tool names (for ToolUse) */
1607
+ allowedTools?: Array<string>;
1608
+ /** Required tags (for XmlFormat) */
1609
+ requiredTags?: Array<string>;
1610
+ /** Minimum length (for Length) */
1611
+ minLength?: number;
1612
+ /** Maximum length (for Length) */
1613
+ maxLength?: number;
1614
+ /** Use character count vs word count (for Length) */
1615
+ useChars?: boolean;
1616
+ /** Required JSON fields (for JsonSchema) */
1617
+ requiredFields?: Array<string>;
1618
+ /** Whether tool call is required (for ToolUse) */
1619
+ required?: boolean;
1620
+ }
1621
+
1622
+ /** Built-in reward function types */
1623
+ export declare const enum BuiltinRewardType {
1624
+ /** Tool use validation */
1625
+ ToolUse = 'ToolUse',
1626
+ /** XML format validation */
1627
+ XmlFormat = 'XmlFormat',
1628
+ /** Length-based scoring */
1629
+ Length = 'Length',
1630
+ /** JSON schema validation */
1631
+ JsonSchema = 'JsonSchema',
1632
+ }
1633
+
1634
+ /**
1635
+ * Configuration for the high-level `chat()` API
1636
+ *
1637
+ * Combines tool definitions with generation parameters in a single config object.
1638
+ * Tools are optional - when not provided, `chat()` works as a simple conversational API.
1639
+ *
1640
+ * ## Example
1641
+ * ```typescript
1642
+ * // Simple chat (no tools)
1643
+ * const result = await model.chat(messages);
1644
+ *
1645
+ * // With tools
1646
+ * const result = await model.chat(messages, {
1647
+ * tools: [weatherTool, searchTool],
1648
+ * maxNewTokens: 2048,
1649
+ * temperature: 0.7,
1650
+ * });
1651
+ * ```
1652
+ */
1653
+ export interface ChatConfig {
1654
+ /**
1655
+ * Tool definitions for function calling (optional)
1656
+ *
1657
+ * When provided, the model can invoke these tools during generation.
1658
+ * Tool calls are parsed and returned in `ChatResult.toolCalls`.
1659
+ */
1660
+ tools?: Array<ToolDefinition>;
1661
+ /** Maximum number of new tokens to generate (default: 2048 for chat) */
1662
+ maxNewTokens?: number;
1663
+ /** Sampling temperature (0 = greedy, higher = more random) (default: 0.7) */
1664
+ temperature?: number;
1665
+ /** Top-k sampling: keep only top k tokens (0 = disabled) (default: 0) */
1666
+ topK?: number;
1667
+ /** Top-p (nucleus) sampling: keep tokens with cumulative prob < p (default: 0.9) */
1668
+ topP?: number;
1669
+ /** Min-p sampling: keep tokens with prob > min_p * max_prob (default: 0.0) */
1670
+ minP?: number;
1671
+ /** Repetition penalty factor (1.0 = no penalty) (default: 1.0) */
1672
+ repetitionPenalty?: number;
1673
+ /** Number of recent tokens to consider for repetition penalty (default: 20) */
1674
+ repetitionContextSize?: number;
1675
+ /** Stop if same token repeats this many times consecutively (default: 16) */
1676
+ maxConsecutiveTokens?: number;
1677
+ /** Stop if an n-gram pattern repeats this many times (default: 8) */
1678
+ maxNgramRepeats?: number;
1679
+ /** N-gram size for repetition detection (default: 3) */
1680
+ ngramSize?: number;
1681
+ /** EOS token ID (generation stops when this is generated) */
1682
+ eosTokenId?: number;
1683
+ /** Whether to return log probabilities (default: true) */
1684
+ returnLogprobs?: boolean;
1685
+ }
1686
+
1687
+ /** Chat message with tool calling support */
1688
+ export interface ChatMessage {
1689
+ /** Role: "system", "user", "assistant", or "tool" */
1690
+ role: string;
1691
+ /** Message content */
1692
+ content: string;
1693
+ /** Tool calls made by the assistant (for assistant messages) */
1694
+ toolCalls?: Array<ToolCall>;
1695
+ /** Tool call ID this message is responding to (for tool messages) */
1696
+ toolCallId?: string;
1697
+ /** Reasoning content for thinking mode (used with <think> tags) */
1698
+ reasoningContent?: string;
1699
+ }
1700
+
1701
+ /** Chat message role */
1702
+ export declare const enum ChatRole {
1703
+ /** User message */
1704
+ User = 'User',
1705
+ /** Assistant response */
1706
+ Assistant = 'Assistant',
1707
+ /** System prompt */
1708
+ System = 'System',
1709
+ }
1710
+
1711
+ /** Statistics about cleanup operations (NAPI wrapper) */
1712
+ export interface CleanupStats {
1713
+ /** Number of training steps deleted */
1714
+ stepsDeleted: number;
1715
+ /** Number of generations deleted */
1716
+ generationsDeleted: number;
1717
+ /** Number of tool calls deleted */
1718
+ toolCallsDeleted: number;
1719
+ /** Number of logs deleted */
1720
+ logsDeleted: number;
1721
+ }
1722
+
1723
+ /**
1724
+ * Structured completion information aligned with ChatResult.
1725
+ * Contains pre-parsed tool calls, thinking, and clean text.
1726
+ */
1727
+ export interface CompletionInfo {
1728
+ /** Clean text with <tool_call> and <think> tags removed */
1729
+ text: string;
1730
+ /** Raw output before tag stripping (for debugging/XML parsing) */
1731
+ rawText: string;
1732
+ /** Parsed tool calls (arguments are already JS objects) */
1733
+ toolCalls: Array<ToolCallResult>;
1734
+ /** Extracted thinking/reasoning from <think> tags (null if none) */
1735
+ thinking?: string;
1736
+ /** Number of tokens generated */
1737
+ numTokens: number;
1738
+ /** Finish reason: "stop" | "length" | "tool_calls" */
1739
+ finishReason: string;
1740
+ }
1741
+
1742
+ export interface ConversionOptions {
1743
+ /** Input directory containing model files (config.json, model.safetensors) */
1744
+ inputDir: string;
1745
+ /** Output directory for converted model */
1746
+ outputDir: string;
1747
+ /** Target dtype for conversion (default: "float32") */
1748
+ dtype?: string;
1749
+ /** Whether to verbose logging (default: false) */
1750
+ verbose?: boolean;
1751
+ }
1752
+
1753
+ export interface ConversionResult {
1754
+ /** Number of tensors converted */
1755
+ numTensors: number;
1756
+ /** Total number of parameters */
1757
+ numParameters: number;
1758
+ /** Output model path */
1759
+ outputPath: string;
1760
+ /** List of converted tensor names */
1761
+ tensorNames: Array<string>;
1762
+ }
1763
+
1764
+ /**
1765
+ * Convert a HuggingFace SafeTensors model to MLX format
1766
+ *
1767
+ * This function:
1768
+ * 1. Loads SafeTensors model from input directory
1769
+ * 2. Converts all tensors to specified dtype (default: float32)
1770
+ * 3. Saves converted model to output directory
1771
+ * 4. Copies config.json and tokenizer files
1772
+ *
1773
+ * # Arguments
1774
+ * * `options` - Conversion options (input_dir, output_dir, dtype, verbose)
1775
+ *
1776
+ * # Returns
1777
+ * * ConversionResult with statistics about the conversion
1778
+ *
1779
+ * # Example
1780
+ * ```typescript
1781
+ * import { convertModel } from '../../index.cjs';
1782
+ *
1783
+ * const result = await convertModel({
1784
+ * inputDir: '.cache/models/qwen3-0.6b',
1785
+ * outputDir: '.cache/models/qwen3-0.6b-mlx',
1786
+ * dtype: 'float32',
1787
+ * verbose: true
1788
+ * });
1789
+ *
1790
+ * console.log(`Converted ${result.numTensors} tensors (${result.numParameters} parameters)`);
1791
+ * ```
1792
+ */
1793
+ export declare function convertModel(options: ConversionOptions): Promise<ConversionResult>;
1794
+
1795
+ export declare function convertParquetToJsonl(inputPath: string, outputPath: string): void;
1796
+
1797
+ /** Create a default PaddleOCR-VL 1.5 configuration (JS factory function) */
1798
+ export declare function createPaddleocrVlConfig(): ModelConfig;
1799
+
1800
+ /** Document element - either a table or paragraph */
1801
+ export interface DocumentElement {
1802
+ elementType: ElementType;
1803
+ /** Table data (only present if element_type is Table) */
1804
+ table?: Table;
1805
+ /** Paragraph data (only present if element_type is Paragraph) */
1806
+ paragraph?: Paragraph;
1807
+ }
1808
+
1809
+ export declare const enum DType {
1810
+ Float32 = 0,
1811
+ Int32 = 1,
1812
+ Float16 = 2,
1813
+ BFloat16 = 3,
1814
+ Uint32 = 4,
1815
+ }
1816
+
1817
+ /** Document element type */
1818
+ export declare const enum ElementType {
1819
+ Table = 'Table',
1820
+ Paragraph = 'Paragraph',
1821
+ }
1822
+
1823
+ /** Metrics from a training epoch */
1824
+ export interface EngineEpochMetrics {
1825
+ /** Epoch number */
1826
+ epoch: number;
1827
+ /** Average loss for the epoch */
1828
+ avgLoss: number;
1829
+ /** Average reward for the epoch */
1830
+ avgReward: number;
1831
+ /** Total steps in the epoch */
1832
+ totalSteps: number;
1833
+ /** Total tokens processed */
1834
+ totalTokens: number;
1835
+ /** Time for the epoch (seconds) */
1836
+ epochTimeSecs: number;
1837
+ }
1838
+
1839
+ /** Metrics from a single training step */
1840
+ export interface EngineStepMetrics {
1841
+ /** Current step number */
1842
+ step: number;
1843
+ /** GRPO loss value */
1844
+ loss: number;
1845
+ /** Mean reward across completions */
1846
+ meanReward: number;
1847
+ /** Standard deviation of rewards */
1848
+ stdReward: number;
1849
+ /** Mean advantage value */
1850
+ meanAdvantage: number;
1851
+ /** Standard deviation of advantages */
1852
+ stdAdvantage: number;
1853
+ /** Total tokens generated this step */
1854
+ totalTokens: number;
1855
+ /** Whether gradients were applied */
1856
+ gradientsApplied: boolean;
1857
+ /** Time for generation (ms) */
1858
+ generationTimeMs: number;
1859
+ /** Time for training (ms) */
1860
+ trainingTimeMs: number;
1861
+ /** Peak memory usage this step (MB) */
1862
+ peakMemoryMb: number;
1863
+ /** Active memory at end of step (MB) */
1864
+ activeMemoryMb: number;
1865
+ }
1866
+
1867
+ /** Format parsed document according to config */
1868
+ export declare function formatDocument(doc: ParsedDocument, config?: ParserConfig | undefined | null): string;
1869
+
1870
+ /** Function definition for tool calling */
1871
+ export interface FunctionDefinition {
1872
+ /** Name of the function */
1873
+ name: string;
1874
+ /** Description of what the function does */
1875
+ description?: string;
1876
+ /** Parameter schema */
1877
+ parameters?: FunctionParameters;
1878
+ }
1879
+
1880
+ /** Function parameters schema (JSON Schema subset) */
1881
+ export interface FunctionParameters {
1882
+ /** Type (usually "object") */
1883
+ type: string;
1884
+ /** JSON string of property definitions */
1885
+ properties?: string;
1886
+ /** List of required parameter names */
1887
+ required?: Array<string>;
1888
+ }
1889
+
1890
+ /** Result from generate_batch_for_training with all data needed for training */
1891
+ export interface GenerateBatchResult {
1892
+ /** Generated completion texts */
1893
+ completionTexts: Array<string>;
1894
+ /** Completion token IDs (flattened, concatenated) */
1895
+ completionTokens: Array<number>;
1896
+ /** Completion log probabilities (flattened, concatenated) */
1897
+ completionLogprobs: Array<number>;
1898
+ /** Lengths of each completion (for reconstruction) */
1899
+ completionLengths: Array<number>;
1900
+ /** Finish reasons for each completion ("eos", "length", or "repetition") */
1901
+ finishReasons: Array<string>;
1902
+ }
1903
+
1904
+ /** Configuration for text generation */
1905
+ export interface GenerationConfig {
1906
+ /** Maximum number of new tokens to generate (default: 100) */
1907
+ maxNewTokens?: number;
1908
+ /** Sampling temperature (0 = greedy, higher = more random) (default: 1.0) */
1909
+ temperature?: number;
1910
+ /** Top-k sampling: keep only top k tokens (0 = disabled) (default: 0) */
1911
+ topK?: number;
1912
+ /** Top-p (nucleus) sampling: keep tokens with cumulative prob < p (default: 1.0) */
1913
+ topP?: number;
1914
+ /** Min-p sampling: keep tokens with prob > min_p * max_prob (default: 0.0) */
1915
+ minP?: number;
1916
+ /** Repetition penalty factor (1.0 = no penalty, 1.1-1.5 typical) (default: 1.0) */
1917
+ repetitionPenalty?: number;
1918
+ /**
1919
+ * Number of recent tokens to consider for repetition penalty (default: 20)
1920
+ * Matches mlx-lm default. Larger values catch longer patterns but use more memory
1921
+ */
1922
+ repetitionContextSize?: number;
1923
+ /**
1924
+ * Stop if same token repeats this many times consecutively (default: 16)
1925
+ * Set to 0 to disable. Prevents OOM from degenerate repetitive generation.
1926
+ */
1927
+ maxConsecutiveTokens?: number;
1928
+ /**
1929
+ * Stop if an n-gram pattern repeats this many times (default: 8)
1930
+ * Set to 0 to disable. Detects patterns like "A B A B A B A B".
1931
+ */
1932
+ maxNgramRepeats?: number;
1933
+ /**
1934
+ * N-gram size for repetition detection (default: 3)
1935
+ * Used with max_ngram_repeats to detect repeating patterns.
1936
+ */
1937
+ ngramSize?: number;
1938
+ /** EOS token ID (generation stops when this is generated) */
1939
+ eosTokenId?: number;
1940
+ /** Whether to return log probabilities (always true for GRPO) */
1941
+ returnLogprobs?: boolean;
1942
+ /**
1943
+ * Prefill step size for chunked processing of long prompts (default: 2048)
1944
+ * When the prompt length exceeds this value, it will be processed in chunks
1945
+ * to improve memory efficiency and enable async pipelining.
1946
+ * Set to 0 to disable chunking and process the entire prompt at once.
1947
+ */
1948
+ prefillStepSize?: number;
1949
+ /**
1950
+ * KV cache quantization bits (default: 16 = no quantization)
1951
+ * - 16: Full precision (bfloat16/float16), no quantization
1952
+ * - 8: 8-bit quantization, ~2x memory savings, minimal quality loss
1953
+ * - 4: 4-bit quantization, ~4x memory savings, some quality degradation
1954
+ *
1955
+ * Quantized KV cache is useful for long sequences where memory becomes a bottleneck.
1956
+ * Note: Adds dequantization overhead per forward pass.
1957
+ */
1958
+ kvCacheBits?: number;
1959
+ /**
1960
+ * KV cache quantization group size (default: 64)
1961
+ * Number of elements per quantization group. Smaller groups = better accuracy
1962
+ * but more overhead from storing scales/biases.
1963
+ * Only used when kv_cache_bits is 4 or 8.
1964
+ */
1965
+ kvCacheGroupSize?: number;
1966
+ /**
1967
+ * Number of draft tokens to generate speculatively (default: 5)
1968
+ * Only used when a draft model is provided for speculative decoding.
1969
+ * Higher values can increase throughput but may reduce acceptance rate.
1970
+ */
1971
+ numDraftTokens?: number;
1972
+ }
1973
+
1974
+ /** A generation record (NAPI wrapper) */
1975
+ export interface GenerationRecord {
1976
+ batchIndex: number;
1977
+ groupIndex: number;
1978
+ prompt: string;
1979
+ expectedAnswer?: string;
1980
+ completionText: string;
1981
+ completionRaw: string;
1982
+ thinking?: string;
1983
+ numTokens: number;
1984
+ finishReason: string;
1985
+ reward: number;
1986
+ }
1987
+
1988
+ /** A generation with its associated tool calls (NAPI wrapper) */
1989
+ export interface GenerationWithToolCalls {
1990
+ generation: GenerationRecord;
1991
+ toolCalls: Array<ToolCallRecord>;
1992
+ }
1993
+
1994
+ /** Get expected weight keys for PaddleOCR-VL model */
1995
+ export declare function getExpectedWeightKeys(): Array<string>;
1996
+
1997
+ /** Configuration for the GRPO training engine */
1998
+ export interface GrpoEngineConfig {
1999
+ /** Learning rate (default: 1e-6) */
2000
+ learningRate?: number;
2001
+ /** Gradient accumulation steps (default: 1) */
2002
+ gradientAccumulationSteps?: number;
2003
+ /** Maximum gradient norm for clipping (default: 1.0) */
2004
+ gradientClipNorm?: number;
2005
+ /**
2006
+ * Maximum gradient value for element-wise clipping (default: 1.0)
2007
+ * This clamps individual gradient elements to [-value, value]
2008
+ */
2009
+ gradientClipValue?: number;
2010
+ /** Number of completions per prompt (default: 4) */
2011
+ groupSize?: number;
2012
+ /** PPO clipping epsilon (default: 0.2) */
2013
+ clipEpsilon?: number;
2014
+ /** KL divergence coefficient (default: 0.0) */
2015
+ klCoef?: number;
2016
+ /** Loss type: "grpo", "dapo", "dr_grpo", "bnpo" (default: "grpo") */
2017
+ lossType?: string;
2018
+ /**
2019
+ * Maximum completion length for both generation and training (default: 256)
2020
+ * Matches Python TRL's max_completion_length config.
2021
+ */
2022
+ maxCompletionLength?: number;
2023
+ /** Sampling temperature (default: 0.8) */
2024
+ temperature?: number;
2025
+ /** Top-p (nucleus) sampling (default: 0.95) */
2026
+ topP?: number;
2027
+ /** Top-k sampling (optional) */
2028
+ topK?: number;
2029
+ /** Repetition penalty (default: 1.1) */
2030
+ repetitionPenalty?: number;
2031
+ /**
2032
+ * Maximum allowed NaN gradient occurrences before stopping training (default: 100)
2033
+ * When exceeded, training will stop with an error to prevent model corruption.
2034
+ */
2035
+ maxNanGradients?: number;
2036
+ /**
2037
+ * Consecutive NaN gradients that trigger emergency checkpoint (default: 5)
2038
+ * When reached, the needs_emergency_save flag is set for the TypeScript layer.
2039
+ */
2040
+ emergencySaveThreshold?: number;
2041
+ /**
2042
+ * Enable detailed NaN/Inf detection with per-element counts (default: false)
2043
+ * When false (default), uses GPU-native has_nan_or_inf() which only transfers a single
2044
+ * boolean to CPU. When true, transfers the entire gradient tensor to CPU for detailed
2045
+ * per-element analysis - useful for debugging but has significant performance overhead
2046
+ * for large models (e.g., 2.4GB for Qwen3-0.6B).
2047
+ */
2048
+ verboseNanDetection?: boolean;
2049
+ /**
2050
+ * Enable thinking mode for Qwen3 models (default: true)
2051
+ * When false, adds empty <think></think> tags to disable model thinking.
2052
+ * This is useful for tool-use training where you want direct outputs.
2053
+ */
2054
+ enableThinking?: boolean;
2055
+ /**
2056
+ * Tool definitions for function calling
2057
+ * When provided, tools are included in the chat template so the model
2058
+ * can generate tool calls. This is essential for tool-use training.
2059
+ */
2060
+ tools?: Array<ToolDefinition>;
2061
+ /**
2062
+ * Batch chunk size for LM head computation (memory optimization).
2063
+ * When set, the LM head (hidden_states -> logits) is computed in chunks
2064
+ * of this size to reduce peak memory usage.
2065
+ * Default: None (no chunking, full batch at once)
2066
+ * Recommended: 2 for batch_size >= 4 with large vocabularies (e.g., 151936)
2067
+ * This reduces peak memory from ~1.2GB to ~300MB for Qwen3 (vocab=151936).
2068
+ */
2069
+ lmHeadChunkSize?: number;
2070
+ /**
2071
+ * Batch chunk size for transformer forward pass (memory optimization).
2072
+ * When set, the transformer layers process the batch in chunks of this size,
2073
+ * reducing peak memory from O(batch × heads × seq²) for attention.
2074
+ * Default: None (no chunking, full batch at once)
2075
+ * Recommended: 4 for batch_size >= 4 with groupSize >= 4
2076
+ * Memory savings: ~70-80% for batch=4, groupSize=4 (16 sequences → 4 at a time)
2077
+ */
2078
+ forwardChunkSize?: number;
2079
+ /**
2080
+ * Chunk size for vocabulary dimension in cross-entropy computation.
2081
+ * When computing logsumexp over large vocabularies (e.g., Qwen3's 151,936 tokens),
2082
+ * the computation is split into chunks of this size to reduce peak memory usage.
2083
+ * Default: 65536 (2^16)
2084
+ * Recommended: 65536 for Qwen3 (vocab=151936) splits into 3 chunks
2085
+ * Set to a larger value to reduce chunking overhead or smaller for tighter memory constraints.
2086
+ */
2087
+ vocabChunkSize?: number;
2088
+ /**
2089
+ * Enable true parallel batch generation (default: false).
2090
+ * When true, all N*G sequences are processed in parallel using batched FFI
2091
+ * with per-sequence RoPE offsets. This provides 2-4x speedup for GRPO training.
2092
+ * When false, uses the sequential generation (process one prompt at a time,
2093
+ * then expand KV cache for G completions).
2094
+ */
2095
+ useParallelBatchGeneration?: boolean;
2096
+ }
2097
+
2098
+ /** Configuration for GRPO loss computation */
2099
+ export interface GrpoLossConfig {
2100
+ /** Lower clipping bound (default: 0.2, means clip to [1-0.2, 1+epsilon_high]) */
2101
+ epsilonLow: number;
2102
+ /** Upper clipping bound (default: same as epsilon_low) */
2103
+ epsilonHigh?: number;
2104
+ /** KL divergence penalty coefficient (default: 0.0, no penalty) */
2105
+ beta: number;
2106
+ /** Loss aggregation type: "grpo", "bnpo", "dr_grpo", or "dapo" */
2107
+ lossType: string;
2108
+ /** Importance sampling level: "token" or "sequence" */
2109
+ importanceSamplingLevel: string;
2110
+ /**
2111
+ * Maximum completion length (legacy, no longer used by dr_grpo)
2112
+ * Kept for backwards compatibility but ignored in current implementation.
2113
+ */
2114
+ maxCompletionLength?: number;
2115
+ /** Total number of items in batch across all processes (needed for dapo) */
2116
+ numItemsInBatch?: number;
2117
+ /** Current gradient accumulation step (for loss scaling) */
2118
+ gradientAccumulationSteps: number;
2119
+ /**
2120
+ * Batch chunk size for LM head computation (memory optimization).
2121
+ * When set, the LM head (hidden_states -> logits) is computed in chunks
2122
+ * of this size to reduce peak memory usage.
2123
+ * Default: None (no chunking, full batch at once)
2124
+ * Recommended: 2 for batch_size >= 4 with large vocabularies (e.g., 151936)
2125
+ */
2126
+ lmHeadChunkSize?: number;
2127
+ /**
2128
+ * Batch chunk size for transformer forward pass (memory optimization).
2129
+ * When set, the transformer layers process the batch in chunks of this size,
2130
+ * reducing peak memory from O(batch × heads × seq²) for attention.
2131
+ * Default: None (no chunking, full batch at once)
2132
+ * Recommended: 4 for batch_size >= 4 with groupSize >= 4
2133
+ * Memory savings: ~70-80% for batch=4, groupSize=4 (16 sequences → 4 at a time)
2134
+ */
2135
+ forwardChunkSize?: number;
2136
+ /**
2137
+ * Chunk size for vocabulary dimension in cross-entropy computation.
2138
+ * When computing logsumexp over large vocabularies (e.g., Qwen3's 151,936 tokens),
2139
+ * the computation is split into chunks of this size to reduce peak memory usage.
2140
+ * Default: 65536 (2^16)
2141
+ * Recommended: 65536 for Qwen3 (vocab=151936) splits into 3 chunks
2142
+ */
2143
+ vocabChunkSize?: number;
2144
+ }
2145
+
2146
+ /** Full model configuration */
2147
+ export interface ModelConfig {
2148
+ visionConfig: VisionConfig;
2149
+ textConfig: TextConfig;
2150
+ modelType: string;
2151
+ ignoreIndex: number;
2152
+ imageTokenId: number;
2153
+ videoTokenId: number;
2154
+ visionStartTokenId: number;
2155
+ visionEndTokenId: number;
2156
+ eosTokenId: number;
2157
+ }
2158
+
2159
+ /** Output format options */
2160
+ export declare const enum OutputFormat {
2161
+ /** Raw output with minimal processing */
2162
+ Raw = 'Raw',
2163
+ /** Plain text with aligned columns */
2164
+ Plain = 'Plain',
2165
+ /** Markdown tables */
2166
+ Markdown = 'Markdown',
2167
+ /** HTML tables */
2168
+ Html = 'Html',
2169
+ }
2170
+
2171
+ /** Configuration for creating an OutputStore connection */
2172
+ export interface OutputStoreConfig {
2173
+ /** Local SQLite file path (e.g., "training_outputs.db") */
2174
+ localPath: string;
2175
+ }
2176
+
2177
+ /** Paged attention memory statistics (NAPI-compatible) */
2178
+ export interface PagedCacheStats {
2179
+ /** Total number of blocks in the pool */
2180
+ totalBlocks: number;
2181
+ /** Number of free blocks */
2182
+ freeBlocks: number;
2183
+ /** Number of allocated blocks */
2184
+ allocatedBlocks: number;
2185
+ /** Total memory in MB */
2186
+ totalMemoryMb: number;
2187
+ /** Used memory in MB */
2188
+ usedMemoryMb: number;
2189
+ /** Utilization percentage */
2190
+ utilizationPercent: number;
2191
+ }
2192
+
2193
+ /** A completed sequence from paged generation */
2194
+ export interface PagedCompletedSequence {
2195
+ /** Original request ID */
2196
+ requestId: string;
2197
+ /** All generated tokens (excluding prompt) */
2198
+ tokens: Array<number>;
2199
+ /** Reason for completion ("eos", "max_tokens", etc.) */
2200
+ finishReason: string;
2201
+ }
2202
+
2203
+ /** Result of a paged generation step */
2204
+ export interface PagedGenerationStep {
2205
+ /** Token outputs for each sequence in the batch */
2206
+ outputs: Array<PagedTokenOutput>;
2207
+ /** Number of sequences that were in prefill phase */
2208
+ numPrefill: number;
2209
+ /** Number of sequences that were in decode phase */
2210
+ numDecode: number;
2211
+ }
2212
+
2213
+ /** Output from a single token generation step in paged attention */
2214
+ export interface PagedTokenOutput {
2215
+ /** Sequence ID in the scheduler */
2216
+ seqId: number;
2217
+ /** Request ID for this sequence */
2218
+ requestId: string;
2219
+ /** Generated token ID */
2220
+ token: number;
2221
+ /** Log probability of the token (f64 for NAPI compatibility) */
2222
+ logprob: number;
2223
+ /** Whether this sequence has finished */
2224
+ isFinished: boolean;
2225
+ }
2226
+
2227
+ /** A text paragraph */
2228
+ export interface Paragraph {
2229
+ content: string;
2230
+ }
2231
+
2232
+ /** Parsed document structure */
2233
+ export interface ParsedDocument {
2234
+ elements: Array<DocumentElement>;
2235
+ }
2236
+
2237
+ /**
2238
+ * Parse and format PaddleOCR-VL response in one step
2239
+ *
2240
+ * Convenience function that parses the VLM output and formats it
2241
+ * according to the specified configuration.
2242
+ *
2243
+ * # Arguments
2244
+ * * `text` - Raw VLM output containing table tokens
2245
+ * * `config` - Optional parser configuration (format, trim_cells, etc.)
2246
+ *
2247
+ * # Returns
2248
+ * * Formatted string in the requested format (markdown, plain, html, raw)
2249
+ *
2250
+ * # Example
2251
+ * ```typescript
2252
+ * import { parsePaddleResponse } from '@mlx-node/core';
2253
+ *
2254
+ * // Parse and format as markdown (default)
2255
+ * const markdown = parsePaddleResponse(vlmResult.text);
2256
+ *
2257
+ * // Parse and format as HTML
2258
+ * const html = parsePaddleResponse(vlmResult.text, { format: 'html' });
2259
+ *
2260
+ * // Parse and format as plain text
2261
+ * const plain = parsePaddleResponse(vlmResult.text, { format: 'plain' });
2262
+ * ```
2263
+ */
2264
+ export declare function parsePaddleResponse(text: string, config?: ParserConfig | undefined | null): string;
2265
+
2266
+ /** Parser configuration */
2267
+ export interface ParserConfig {
2268
+ /** Output format (default: 'markdown') */
2269
+ format?: OutputFormat;
2270
+ /** Whether to trim whitespace from cells (default: true) */
2271
+ trimCells?: boolean;
2272
+ /** Whether to collapse empty rows (default: true) */
2273
+ collapseEmptyRows?: boolean;
2274
+ }
2275
+
2276
+ /**
2277
+ * Parse tool calls from text (NAPI export)
2278
+ *
2279
+ * Extracts tool calls from model-generated text and returns both the cleaned text
2280
+ * and the parsed tool calls.
2281
+ *
2282
+ * # Example
2283
+ * ```typescript
2284
+ * import { parseToolCallsFromText } from '@mlx-node/core';
2285
+ *
2286
+ * const result = parseToolCallsFromText('<tool_call>{"name": "search", "arguments": {"q": "test"}}</tool_call>');
2287
+ * console.log(result.text); // ""
2288
+ * console.log(result.toolCalls[0].name); // "search"
2289
+ * console.log(result.toolCalls[0].arguments.q); // "test"
2290
+ * ```
2291
+ */
2292
+ export declare function parseToolCallsFromText(text: string): ParseToolCallsResult;
2293
+
2294
+ /** Result of parsing tool calls from text */
2295
+ export interface ParseToolCallsResult {
2296
+ /** Cleaned text with tool_call tags removed */
2297
+ text: string;
2298
+ /** Parsed tool calls */
2299
+ toolCalls: Array<ToolCallResult>;
2300
+ }
2301
+
2302
+ /** Parse VLM output into structured document */
2303
+ export declare function parseVlmOutput(text: string): ParsedDocument;
2304
+
2305
+ /** Qwen3 model configuration */
2306
+ export interface Qwen3Config {
2307
+ vocabSize: number;
2308
+ hiddenSize: number;
2309
+ numLayers: number;
2310
+ numHeads: number;
2311
+ numKvHeads: number;
2312
+ intermediateSize: number;
2313
+ rmsNormEps: number;
2314
+ ropeTheta: number;
2315
+ maxPositionEmbeddings: number;
2316
+ headDim: number;
2317
+ useQkNorm: boolean;
2318
+ tieWordEmbeddings: boolean;
2319
+ padTokenId: number;
2320
+ eosTokenId: number;
2321
+ bosTokenId: number;
2322
+ /**
2323
+ * Enable paged attention for memory-efficient inference.
2324
+ * Default: false (use standard KVCache)
2325
+ */
2326
+ usePagedAttention?: boolean | undefined;
2327
+ /**
2328
+ * GPU memory budget for paged KV cache in megabytes.
2329
+ * Only used when use_paged_attention is true.
2330
+ * Default: 2048 (2GB)
2331
+ */
2332
+ pagedCacheMemoryMb?: number | undefined;
2333
+ /**
2334
+ * Block size for paged attention (tokens per block).
2335
+ * Only used when use_paged_attention is true.
2336
+ * Default: 16
2337
+ */
2338
+ pagedBlockSize?: number | undefined;
2339
+ /**
2340
+ * Use FP8 cache for 2x memory reduction (experimental).
2341
+ * Only used when use_paged_attention is true.
2342
+ * Default: false
2343
+ */
2344
+ useFp8Cache?: boolean | undefined;
2345
+ }
2346
+
2347
+ /** Result of resume position computation */
2348
+ export interface ResumePosition {
2349
+ /** Epoch to start from (0-indexed) */
2350
+ startEpoch: number;
2351
+ /** Batch index within epoch to start from */
2352
+ startBatchIdx: number;
2353
+ /** Whether we're at an epoch boundary */
2354
+ isEpochBoundary: boolean;
2355
+ }
2356
+
2357
+ /**
2358
+ * Reward function input for a single completion.
2359
+ * Provides all context needed to compute a reward score.
2360
+ */
2361
+ export interface RewardOutput {
2362
+ /** The input prompt text */
2363
+ prompt: string;
2364
+ /** Structured completion data aligned with ChatResult */
2365
+ completion: CompletionInfo;
2366
+ }
2367
+
2368
+ /** Reward distribution statistics (NAPI wrapper) */
2369
+ export interface RewardStats {
2370
+ count: number;
2371
+ mean: number;
2372
+ std: number;
2373
+ min: number;
2374
+ max: number;
2375
+ median: number;
2376
+ p25: number;
2377
+ p75: number;
2378
+ }
2379
+
2380
+ /** Aggregate statistics for a training run for resume state (NAPI wrapper) */
2381
+ export interface RunAggregates {
2382
+ /** Best (highest) reward seen */
2383
+ bestReward: number;
2384
+ /** Average reward */
2385
+ avgReward: number;
2386
+ /** Total reward count */
2387
+ rewardCount: number;
2388
+ /** Best (lowest) loss seen */
2389
+ bestLoss: number;
2390
+ /** Average loss */
2391
+ avgLoss: number;
2392
+ /** Total loss count */
2393
+ lossCount: number;
2394
+ /** Total tokens generated */
2395
+ totalTokens: number;
2396
+ /** Current step number */
2397
+ currentStep: number;
2398
+ /** Average generation time (milliseconds) */
2399
+ avgGenerationTimeMs: number;
2400
+ /** Average training time (milliseconds) */
2401
+ avgTrainingTimeMs: number;
2402
+ }
2403
+
2404
+ /**
2405
+ * Configuration for sampling strategies
2406
+ * ⚡ PERFORMANCE: Made Copy to avoid cloning on every token
2407
+ */
2408
+ export interface SamplingConfig {
2409
+ /** Temperature for softmax (default: 1.0). Lower = more deterministic */
2410
+ temperature?: number;
2411
+ /** Number of top tokens to keep (top-k sampling). 0 = disabled */
2412
+ topK?: number;
2413
+ /** Cumulative probability threshold (top-p/nucleus sampling). 1.0 = disabled */
2414
+ topP?: number;
2415
+ /** Minimum probability threshold relative to max (min-p sampling). 0 = disabled */
2416
+ minP?: number;
2417
+ }
2418
+
2419
+ /** Scheduler statistics (NAPI-compatible) */
2420
+ export interface SchedulerStatsNapi {
2421
+ /** Number of requests waiting to be scheduled */
2422
+ numWaiting: number;
2423
+ /** Number of sequences currently running */
2424
+ numRunning: number;
2425
+ /** Number of completed sequences */
2426
+ numCompleted: number;
2427
+ /** Number of sequences in prefill phase */
2428
+ numPrefill: number;
2429
+ /** Number of sequences in decode phase */
2430
+ numDecode: number;
2431
+ /** Total tokens across all running sequences */
2432
+ totalRunningTokens: number;
2433
+ }
2434
+
2435
+ /** Configuration for the SFT training engine */
2436
+ export interface SftEngineConfig {
2437
+ /** Learning rate (default: 2e-5) */
2438
+ learningRate?: number;
2439
+ /** Gradient accumulation steps (default: 1) */
2440
+ gradientAccumulationSteps?: number;
2441
+ /** Maximum gradient norm for clipping (default: 1.0) */
2442
+ gradientClipNorm?: number;
2443
+ /** Maximum gradient value for element-wise clipping (optional) */
2444
+ gradientClipValue?: number;
2445
+ /** Weight decay (L2 regularization) (default: 0.01) */
2446
+ weightDecay?: number;
2447
+ /** Label smoothing factor (default: 0.0) */
2448
+ labelSmoothing?: number;
2449
+ /** Steps between heavy cleanup (default: 25) */
2450
+ heavyCleanupInterval?: number;
2451
+ /** Maximum allowed NaN gradient occurrences (default: 100) */
2452
+ maxNanGradients?: number;
2453
+ /** Consecutive NaN gradients that trigger emergency checkpoint (default: 5) */
2454
+ emergencySaveThreshold?: number;
2455
+ /** Compute token accuracy (requires extra forward pass) (default: false) */
2456
+ computeAccuracy?: boolean;
2457
+ /**
2458
+ * Enable detailed NaN/Inf detection with per-element counts (default: false)
2459
+ * When false (default), uses GPU-native has_nan_or_inf() which only transfers a single
2460
+ * boolean to CPU. When true, transfers the entire gradient tensor to CPU for detailed
2461
+ * per-element analysis - useful for debugging but has significant performance overhead.
2462
+ */
2463
+ verboseNanDetection?: boolean;
2464
+ }
2465
+
2466
+ /** Metrics from a training epoch */
2467
+ export interface SftEpochMetrics {
2468
+ /** Epoch number */
2469
+ epoch: number;
2470
+ /** Average loss for the epoch */
2471
+ avgLoss: number;
2472
+ /** Total steps in the epoch */
2473
+ totalSteps: number;
2474
+ /** Total tokens processed */
2475
+ totalTokens: number;
2476
+ /** Time for the epoch (seconds) */
2477
+ epochTimeSecs: number;
2478
+ }
2479
+
2480
+ /** Metrics from a single training step */
2481
+ export interface SftStepMetrics {
2482
+ /** Current step number */
2483
+ step: number;
2484
+ /** Cross-entropy loss value */
2485
+ loss: number;
2486
+ /** Total tokens processed this step (non-ignored) */
2487
+ totalTokens: number;
2488
+ /** Token-level accuracy (if compute_accuracy enabled) */
2489
+ tokenAccuracy?: number;
2490
+ /** Whether gradients were applied (vs accumulated) */
2491
+ gradientsApplied: boolean;
2492
+ /** Time for training step (ms) */
2493
+ trainingTimeMs: number;
2494
+ }
2495
+
2496
+ /** Metrics from a single training step for sparkline restoration (NAPI wrapper) */
2497
+ export interface StepMetricSummary {
2498
+ /** Step number */
2499
+ step: number;
2500
+ /** Loss value */
2501
+ loss: number;
2502
+ /** Mean reward (GRPO) */
2503
+ meanReward: number;
2504
+ /** Mean advantage (GRPO) */
2505
+ meanAdvantage: number;
2506
+ /** Std advantage (GRPO) - indicates reward variance within groups */
2507
+ stdAdvantage: number;
2508
+ /** Perplexity (SFT, optional) */
2509
+ perplexity?: number;
2510
+ /** Token accuracy (SFT, optional) */
2511
+ tokenAccuracy?: number;
2512
+ /** Total tokens this step */
2513
+ totalTokens: number;
2514
+ /** Time for generation phase (milliseconds) */
2515
+ generationTimeMs?: number;
2516
+ /** Time for training phase (milliseconds) */
2517
+ trainingTimeMs?: number;
2518
+ }
2519
+
2520
+ /** A training step record (NAPI wrapper) */
2521
+ export interface StepRecord {
2522
+ runId: string;
2523
+ step: number;
2524
+ epoch?: number;
2525
+ loss: number;
2526
+ meanReward: number;
2527
+ stdReward: number;
2528
+ meanAdvantage?: number;
2529
+ stdAdvantage: number;
2530
+ totalTokens?: number;
2531
+ generationTimeMs?: number;
2532
+ trainingTimeMs?: number;
2533
+ gradientsApplied: boolean;
2534
+ }
2535
+
2536
+ /** Summary of a training step (NAPI wrapper) */
2537
+ export interface StepSummary {
2538
+ step: number;
2539
+ loss: number;
2540
+ meanReward: number;
2541
+ numGenerations: number;
2542
+ numToolCalls: number;
2543
+ eosCount: number;
2544
+ lengthCount: number;
2545
+ }
2546
+
2547
+ /** A table structure */
2548
+ export interface Table {
2549
+ rows: Array<TableRow>;
2550
+ }
2551
+
2552
+ /** A single cell in a table */
2553
+ export interface TableCell {
2554
+ content: string;
2555
+ isEmpty: boolean;
2556
+ }
2557
+
2558
+ /** A row in a table */
2559
+ export interface TableRow {
2560
+ cells: Array<TableCell>;
2561
+ }
2562
+
2563
+ /** Language model (text decoder) configuration */
2564
+ export interface TextConfig {
2565
+ modelType: string;
2566
+ hiddenSize: number;
2567
+ numHiddenLayers: number;
2568
+ intermediateSize: number;
2569
+ numAttentionHeads: number;
2570
+ rmsNormEps: number;
2571
+ vocabSize: number;
2572
+ numKeyValueHeads: number;
2573
+ maxPositionEmbeddings: number;
2574
+ ropeTheta: number;
2575
+ ropeTraditional: boolean;
2576
+ useBias: boolean;
2577
+ headDim: number;
2578
+ /**
2579
+ * Multimodal RoPE sections: [temporal, height, width]
2580
+ * These define how the head_dim is split for 3D position encoding
2581
+ */
2582
+ mropeSection: Array<number>;
2583
+ }
2584
+
2585
+ /** Tool call made by an assistant */
2586
+ export interface ToolCall {
2587
+ /** Optional unique identifier for the tool call */
2588
+ id?: string;
2589
+ /** Name of the tool/function to call */
2590
+ name: string;
2591
+ /** JSON string of arguments to pass to the tool */
2592
+ arguments: string;
2593
+ }
2594
+
2595
+ /** A tool call record (NAPI wrapper) */
2596
+ export interface ToolCallRecord {
2597
+ callIndex: number;
2598
+ status: string;
2599
+ toolName?: string;
2600
+ arguments?: string;
2601
+ rawContent: string;
2602
+ errorMessage?: string;
2603
+ }
2604
+
2605
+ /** Structured tool call with parsed arguments */
2606
+ export interface ToolCallResult {
2607
+ /** Unique identifier for this tool call (format: call_<uuid>) */
2608
+ id: string;
2609
+ /** Name of the tool/function to call */
2610
+ name: string;
2611
+ /**
2612
+ * Parsed arguments as native object (serde_json::Value -> JS object)
2613
+ *
2614
+ * When status is "ok", this contains the parsed arguments object.
2615
+ * When status is "parse_error", this contains the original unparsed string.
2616
+ * Otherwise, this is an empty object {}.
2617
+ */
2618
+ arguments: Record<string, unknown> | string;
2619
+ /**
2620
+ * Parsing status: "ok" | "invalid_json" | "missing_name" | "parse_error"
2621
+ *
2622
+ * - "ok": Successfully parsed tool call
2623
+ * - "invalid_json": The tool_call tag content was not valid JSON
2624
+ * - "missing_name": Valid JSON but no "name" field
2625
+ * - "parse_error": Valid JSON but the "arguments" string field couldn't be parsed as JSON
2626
+ */
2627
+ status: string;
2628
+ /** Error message if status != "ok" */
2629
+ error?: string;
2630
+ /**
2631
+ * Raw content from <tool_call> tag (preserved for debugging/persistence)
2632
+ * Defaults to empty string for backward compatibility with older JSON
2633
+ */
2634
+ rawContent: string;
2635
+ }
2636
+
2637
+ /** OpenAI-compatible tool definition */
2638
+ export interface ToolDefinition {
2639
+ /** Tool type (currently only "function" is supported) */
2640
+ type: string;
2641
+ /** Function definition */
2642
+ function: FunctionDefinition;
2643
+ }
2644
+
2645
+ /** A training run record (NAPI wrapper) */
2646
+ export interface TrainingRunRecord {
2647
+ id: string;
2648
+ name?: string;
2649
+ modelName: string;
2650
+ modelPath?: string;
2651
+ config: string;
2652
+ startedAt: number;
2653
+ endedAt?: number;
2654
+ totalSteps: number;
2655
+ status: string;
2656
+ }
2657
+
2658
+ /** Result from train_step_auto including metrics, completions, and rewards */
2659
+ export interface TrainStepResult {
2660
+ /** Training metrics */
2661
+ metrics: EngineStepMetrics;
2662
+ /** Generated completion texts (for TUI logging) */
2663
+ completions: Array<string>;
2664
+ /** Computed reward values (for TUI logging) */
2665
+ rewards: Array<number>;
2666
+ }
2667
+
2668
+ /** Result from train_step_auto_with_recording including optional full RewardOutput data */
2669
+ export interface TrainStepResultWithOutputs {
2670
+ /** Training metrics */
2671
+ metrics: EngineStepMetrics;
2672
+ /** Generated completion texts (for TUI logging) */
2673
+ completions: Array<string>;
2674
+ /** Computed reward values (for TUI logging) */
2675
+ rewards: Array<number>;
2676
+ /**
2677
+ * Full RewardOutput data as JSON (only populated when record_outputs is true)
2678
+ * This enables zero-copy persistence of training outputs
2679
+ */
2680
+ outputsJson?: string;
2681
+ /** Actual token counts for each completion (for accurate TUI display) */
2682
+ completionLengths: Array<number>;
2683
+ }
2684
+
2685
+ /** Vision encoder configuration */
2686
+ export interface VisionConfig {
2687
+ modelType: string;
2688
+ hiddenSize: number;
2689
+ intermediateSize: number;
2690
+ numHiddenLayers: number;
2691
+ numAttentionHeads: number;
2692
+ numChannels: number;
2693
+ imageSize: number;
2694
+ patchSize: number;
2695
+ hiddenAct: string;
2696
+ layerNormEps: number;
2697
+ attentionDropout: number;
2698
+ spatialMergeSize: number;
2699
+ }
2700
+
2701
+ /** Configuration for VLM chat */
2702
+ export interface VlmChatConfig {
2703
+ /**
2704
+ * Image paths to process (alternative to passing pre-processed images)
2705
+ * These will be automatically processed using the ImageProcessor
2706
+ */
2707
+ imagePaths?: Array<string>;
2708
+ /** Maximum number of new tokens to generate (default: 512) */
2709
+ maxNewTokens?: number;
2710
+ /** Sampling temperature (0 = greedy, higher = more random) (default: 0.0 for OCR) */
2711
+ temperature?: number;
2712
+ /** Top-k sampling (default: 0) */
2713
+ topK?: number;
2714
+ /** Top-p (nucleus) sampling (default: 1.0) */
2715
+ topP?: number;
2716
+ /** Repetition penalty (default: 1.5) */
2717
+ repetitionPenalty?: number;
2718
+ /** Whether to return log probabilities (default: false) */
2719
+ returnLogprobs?: boolean;
2720
+ }
2721
+
2722
+ /** A chat message with optional image */
2723
+ export interface VlmChatMessage {
2724
+ /** Role of the message sender */
2725
+ role: ChatRole;
2726
+ /** Text content of the message */
2727
+ content: string;
2728
+ }