@mlx-node/core 0.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/index.cjs +766 -0
- package/index.d.cts +2728 -0
- package/package.json +36 -0
package/index.d.cts
ADDED
|
@@ -0,0 +1,2728 @@
|
|
|
1
|
+
/* auto-generated by NAPI-RS */
|
|
2
|
+
/* eslint-disable */
|
|
3
|
+
/**
 * Result from batch text generation
 *
 * Contains results for N prompts × G completions per prompt.
 * Results are stored flat in arrays of length N*G, where:
 * - First G elements are completions for prompt 0
 * - Next G elements are completions for prompt 1
 * - etc.
 *
 * In other words, the completion at flat index `p * G + g` is the g-th
 * completion generated for prompt `p`.
 */
export declare class BatchGenerationResult {
  /** Get all generated token arrays (N*G arrays, flat prompt-major order) */
  get tokens(): Array<MxArray>;
  /** Get all log probability arrays (N*G arrays, parallel to `tokens`) */
  get logprobs(): Array<MxArray>;
  /** Get all decoded texts (N*G strings, parallel to `tokens`) */
  get texts(): Array<string>;
  /** Get finish reasons grouped by prompt (N arrays of G finish reasons) */
  get finishReasons(): Array<Array<string>>;
  /** Get token counts grouped by prompt (N arrays of G counts) */
  get tokenCounts(): Array<Array<number>>;
  /** Get number of prompts (N) */
  get numPrompts(): number;
  /** Get group size (G, completions per prompt) */
  get groupSize(): number;
}
|
|
28
|
+
|
|
29
|
+
/**
 * Result from the high-level `chat()` API
 *
 * Contains structured responses with:
 * - Tool calls parsed as native JavaScript objects
 * - Thinking/reasoning extracted from `<think>` tags
 * - Clean text with all special tags stripped
 *
 * ## Example
 * ```typescript
 * const result = await model.chat(messages, { tools });
 * console.log(result.text);      // Clean response
 * console.log(result.thinking);  // Chain-of-thought (if any)
 * console.log(result.toolCalls); // Parsed tool calls
 * ```
 */
export declare class ChatResult {
  /** Get the cleaned text (tool_call and think tags removed) */
  get text(): string;
  /** Get the extracted tool calls (empty array when the response contains none) */
  get toolCalls(): Array<ToolCallResult>;
  /**
   * Get the extracted thinking/reasoning content
   *
   * Returns the content from within `<think>...</think>` tags, or null if
   * no thinking tags were present in the response.
   *
   * This is useful for:
   * - Debugging model reasoning
   * - Displaying chain-of-thought to users (optional)
   * - Analyzing model decision-making
   */
  get thinking(): string | null;
  /** Get the generated tokens */
  get tokens(): MxArray;
  /** Get the log probabilities (parallel to `tokens`) */
  get logprobs(): MxArray;
  /** Get the finish reason ("stop", "length", "tool_calls", or "repetition") */
  get finishReason(): 'stop' | 'length' | 'tool_calls' | 'repetition';
  /** Get the number of tokens generated */
  get numTokens(): number;
  /** Get the raw text before tool call stripping (for debugging) */
  get rawText(): string;
}
|
|
73
|
+
|
|
74
|
+
/**
 * Result from text generation with detailed metadata
 *
 * Lower-level counterpart to {@link ChatResult}: no tool-call or
 * `<think>`-tag post-processing is applied to the decoded text.
 */
export declare class GenerationResult {
  /** Get the decoded text */
  get text(): string;
  /** Get the generated tokens */
  get tokens(): MxArray;
  /** Get the log probabilities (parallel to `tokens`) */
  get logprobs(): MxArray;
  /** Get the finish reason ("eos", "length", or "repetition") */
  get finishReason(): 'eos' | 'length' | 'repetition';
  /** Get the number of tokens generated */
  get numTokens(): number;
}
|
|
87
|
+
|
|
88
|
+
/**
 * GRPO Training Engine
 *
 * Complete training engine that runs entirely in Rust.
 */
export declare class GrpoTrainingEngine {
  /**
   * Create a new training engine from an existing model
   *
   * # Arguments
   * * `model` - The Qwen3 model to train (will be cloned internally)
   * * `config` - Engine configuration
   */
  constructor(model: Qwen3Model, config: GrpoEngineConfig);
  /** Register a built-in reward function */
  registerBuiltinReward(config: BuiltinRewardConfig): void;
  /**
   * Run a training step with provided rewards
   *
   * This method performs the complete training cycle:
   * 1. Generate completions for each prompt (G times per prompt)
   * 2. Use provided rewards to compute advantages
   * 3. Compute GRPO loss and gradients
   * 4. Apply gradients (respecting accumulation steps)
   *
   * # Arguments
   * * `prompts` - Array of chat conversations to use as prompts
   * * `rewards` - Reward values for each completion (numPrompts * groupSize)
   *
   * # Returns
   * * Training step metrics
   */
  trainStep(prompts: Array<Array<ChatMessage>>, rewards: Array<number>): Promise<EngineStepMetrics>;
  /**
   * Generate completions without training
   *
   * Use this to generate completions for scoring by external reward functions.
   * Returns only the completion texts; use `generateBatchForTraining` when the
   * token-level data needed for a later training step must be retained.
   */
  generateBatch(prompts: Array<Array<ChatMessage>>): Promise<Array<string>>;
  /**
   * Generate completions with all data needed for training
   *
   * Returns completion texts, tokens, log probabilities, and lengths.
   * Use this when you need to score completions externally and then train
   * via `trainStepWithGenerations`.
   */
  generateBatchForTraining(prompts: Array<Array<ChatMessage>>): Promise<GenerateBatchResult>;
  /**
   * Run a training step with pre-generated completions
   *
   * This method performs training using pre-generated completions,
   * eliminating the double-generation issue (generating once for scoring
   * and again for training).
   *
   * # Arguments
   * * `prompts` - Array of chat conversations to use as prompts
   * * `rewards` - Reward values for each completion (numPrompts * groupSize)
   * * `generationResult` - Pre-generated completion data from `generateBatchForTraining`
   *
   * # Returns
   * * Training step metrics
   */
  trainStepWithGenerations(
    prompts: Array<Array<ChatMessage>>,
    rewards: Array<number>,
    generationResult: GenerateBatchResult,
  ): Promise<EngineStepMetrics>;
  /**
   * Unified training step with JS reward callback and optional output recording
   *
   * Performs generate → score (via `rewardFn`) → train in a single call, and
   * optionally captures the full RewardOutput data for persistence to an
   * output store database.
   *
   * # Arguments
   * * `prompts` - Array of chat conversations to use as prompts
   * * `rewardFn` - JavaScript callback that computes rewards; Node-style
   *   error-first signature, resolving to one reward per completion
   * * `recordOutputs` - If true, return the serialized RewardOutput JSON
   *
   * # Returns
   * * Training step result including metrics, completions, rewards, and
   *   optionally the serialized outputs JSON
   */
  trainStepAuto(
    prompts: ChatMessage[][],
    rewardFn: (err: Error | null, outputsJson: string) => Promise<number[]>,
    recordOutputs: boolean,
  ): Promise<TrainStepResultWithOutputs>;
  /**
   * Score completions using registered built-in rewards
   *
   * # Arguments
   * * `prompts` - Prompt texts (expanded to match completions)
   * * `completions` - Completion texts to score
   */
  scoreCompletions(prompts: Array<string>, completions: Array<string>): Array<number>;
  /** Get current training step */
  get step(): number;
  /** Get current epoch */
  get epoch(): number;
  /** Start a new epoch */
  startEpoch(): void;
  /** End the current epoch and get metrics */
  endEpoch(epochTimeSecs: number): EngineEpochMetrics;
  /** Reset the engine for a fresh training run */
  reset(): void;
  /** Check if reward registry has any rewards registered */
  get hasBuiltinRewards(): boolean;
  /** Get names of registered reward functions */
  get rewardNames(): Array<string>;
  /** Get current micro-step within gradient accumulation */
  get microStep(): number;
  /**
   * Check if an emergency checkpoint should be saved
   * This flag is set when consecutive NaN gradients reach the threshold
   */
  get needsEmergencySave(): boolean;
  /** Get current NaN gradient count */
  get nanGradientCount(): number;
  /** Clear the emergency save flag (call after saving emergency checkpoint) */
  clearEmergencySaveFlag(): void;
}
|
|
207
|
+
/** Alias preserving the all-caps acronym spelling of {@link GrpoTrainingEngine}. */
export type GRPOTrainingEngine = GrpoTrainingEngine;
|
|
208
|
+
|
|
209
|
+
/**
 * Lazily-evaluated N-dimensional array backed by MLX.
 *
 * Operations build a computation graph and are evaluated on demand
 * (via `eval()`, `evalAsync()`, or implicitly by the `to*` copy methods).
 * Shapes are expressed as `BigInt64Array`; axis lists as `Int32Array`.
 */
export declare class MxArray {
  /** Construct from an Int32Array buffer with the given shape */
  static fromInt32(data: Int32Array, shape: BigInt64Array): MxArray;
  /** Construct from a BigInt64Array buffer with the given shape */
  static fromInt64(data: BigInt64Array, shape: BigInt64Array): MxArray;
  /** Construct from a Uint32Array buffer with the given shape */
  static fromUint32(data: Uint32Array, shape: BigInt64Array): MxArray;
  /** Construct from a Float32Array buffer with the given shape */
  static fromFloat32(data: Float32Array, shape: BigInt64Array): MxArray;
  /** Array of zeros with the given shape (and optional dtype) */
  static zeros(shape: BigInt64Array, dtype?: DType | undefined | null): MxArray;
  /** Scalar (0-dim) floating-point array */
  static scalarFloat(value: number): MxArray;
  /** Scalar (0-dim) integer array */
  static scalarInt(value: number): MxArray;
  /** Array of ones with the given shape (and optional dtype) */
  static ones(shape: BigInt64Array, dtype?: DType | undefined | null): MxArray;
  /** Array filled with `fillValue` (a scalar number or a broadcastable MxArray) */
  static full(shape: BigInt64Array, fillValue: number | MxArray, dtype?: DType | undefined | null): MxArray;
  /** Evenly spaced values over [start, stop] */
  static linspace(
    start: number,
    stop: number,
    num?: number | undefined | null,
    dtype?: DType | undefined | null,
  ): MxArray;
  /** Identity-like matrix of size n×m with ones on diagonal k */
  static eye(
    n: number,
    m?: number | undefined | null,
    k?: number | undefined | null,
    dtype?: DType | undefined | null,
  ): MxArray;
  /** Evenly spaced values in [start, stop) with the given step */
  static arange(
    start: number,
    stop: number,
    step?: number | undefined | null,
    dtype?: DType | undefined | null,
  ): MxArray;
  reshape(shape: BigInt64Array): MxArray;
  astype(dtype: DType): MxArray;
  /**
   * Create a copy of this array with a new handle.
   * This is useful for parameter loading to avoid handle aliasing issues.
   */
  copy(): MxArray;
  logSoftmax(axis: number): MxArray;
  exp(): MxArray;
  log(): MxArray;
  sum(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
  mean(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
  clip(minimum?: number | undefined | null, maximum?: number | undefined | null): MxArray;
  minimum(other: MxArray): MxArray;
  maximum(other: MxArray): MxArray;
  add(other: MxArray): MxArray;
  sub(other: MxArray): MxArray;
  mul(other: MxArray): MxArray;
  div(other: MxArray): MxArray;
  addScalar(value: number): MxArray;
  mulScalar(value: number): MxArray;
  subScalar(value: number): MxArray;
  divScalar(value: number): MxArray;
  matmul(other: MxArray): MxArray;
  /**
   * Fused matrix multiply-add: D = beta * C + alpha * (self @ B)
   * where self is A. More efficient than separate matmul and add operations.
   * Default: alpha=1.0, beta=1.0, giving D = C + (self @ B)
   */
  addmm(c: MxArray, b: MxArray, alpha?: number | undefined | null, beta?: number | undefined | null): MxArray;
  transpose(axes?: Int32Array | undefined | null): MxArray;
  take(indices: MxArray, axis: number): MxArray;
  takeAlongAxis(indices: MxArray, axis: number): MxArray;
  /**
   * Put values into array at specified indices along an axis
   * Equivalent to: result = array.copy(); result[..., indices] = values
   * This matches MLX's put_along_axis for efficient in-place-style updates
   */
  putAlongAxis(indices: MxArray, values: MxArray, axis: number): MxArray;
  slice(starts: BigInt64Array, stops: BigInt64Array): MxArray;
  /**
   * Concatenate two arrays along an axis
   * Optimized for the common binary concatenation case
   */
  static concatenate(a: MxArray, b: MxArray, axis: number): MxArray;
  /**
   * Concatenate multiple arrays along an axis
   * For concatenating 3 or more arrays
   */
  static concatenateMany(arrays: Array<MxArray>, axis?: number | undefined | null): MxArray;
  sort(axis?: number | undefined | null): MxArray;
  argsort(axis?: number | undefined | null): MxArray;
  partition(kth: number, axis?: number | undefined | null): MxArray;
  argpartition(kth: number, axis?: number | undefined | null): MxArray;
  /** Force synchronous evaluation of the lazy computation graph */
  eval(): void;
  /** Force asynchronous evaluation of the lazy computation graph */
  evalAsync(): Promise<undefined>;
  /** Total number of elements */
  size(): bigint;
  /** Number of dimensions */
  ndim(): number;
  /** Full shape as a BigInt64Array */
  shape(): BigInt64Array;
  /**
   * Get a single dimension from the array shape without copying the entire shape
   * This is more efficient when you only need one dimension
   *
   * Note: axis is u32 because NAPI doesn't support usize, but internally converted to usize
   */
  shapeAt(axis: number): number;
  /**
   * Get batch and sequence length for 2D arrays (common pattern in transformers)
   * More efficient than calling shape() and extracting dimensions
   */
  getBatchSeqLen(): Array<number>;
  /**
   * Get batch, sequence length, and hidden size for 3D arrays (common pattern in transformers)
   * More efficient than calling shape() and extracting dimensions
   */
  getBatchSeqHidden(): Array<number>;
  dtype(): DType;
  /**
   * Copy entire array from GPU to CPU as Float32Array
   *
   * ⚠️ **PERFORMANCE WARNING**: This triggers a FULL GPU→CPU memory transfer!
   *
   * **Performance impact**:
   * - Forces evaluation of lazy operations
   * - Copies entire array from GPU to CPU memory
   * - Can be extremely slow for large arrays
   *
   * **Use sparingly**:
   * - Prefer a scalar accessor (e.g. `itemFloat32()`) for scalars
   *   (NOTE(review): scalar/item accessors are referenced here but not
   *   declared in this .d.cts — confirm availability in this build)
   * - Prefer a single-element accessor (e.g. `itemAtFloat32(index)`) for single elements
   * - Only use when you truly need all array data on CPU
   *
   * **Acceptable use cases**:
   * - Test validation and assertions
   * - CPU-only operations (e.g., sorting for quantiles)
   * - Final output extraction
   */
  toFloat32(): Float32Array;
  /**
   * Copy entire array from GPU to CPU as Int32Array
   *
   * ⚠️ **PERFORMANCE WARNING**: This triggers a FULL GPU→CPU memory transfer!
   *
   * See `toFloat32()` documentation for performance implications and alternatives.
   * Prefer a scalar accessor (e.g. `itemInt32()`) for scalars.
   */
  toInt32(): Int32Array;
  /**
   * Copy entire array from GPU to CPU as Uint32Array
   *
   * ⚠️ **PERFORMANCE WARNING**: This triggers a FULL GPU→CPU memory transfer!
   *
   * See `toFloat32()` documentation for performance implications and alternatives.
   */
  toUint32(): Uint32Array;
  static stack(arrays: Array<MxArray>, axis?: number | undefined | null): MxArray;
  static randomUniform(shape: BigInt64Array, low: number, high: number, dtype?: DType | undefined | null): MxArray;
  static randomNormal(shape: BigInt64Array, mean: number, std: number, dtype?: DType | undefined | null): MxArray;
  static randomBernoulli(shape: BigInt64Array, prob: number): MxArray;
  static randint(shape: BigInt64Array, low: number, high: number): MxArray;
  /**
   * Sample from categorical distribution
   * Takes logits and returns sampled indices
   */
  categorical(axis?: number | undefined | null): MxArray;
  equal(other: MxArray): MxArray;
  notEqual(other: MxArray): MxArray;
  less(other: MxArray): MxArray;
  lessEqual(other: MxArray): MxArray;
  greater(other: MxArray): MxArray;
  greaterEqual(other: MxArray): MxArray;
  logicalAnd(other: MxArray): MxArray;
  logicalOr(other: MxArray): MxArray;
  logicalNot(): MxArray;
  /** Element-wise select: where self is true take from x, otherwise from y */
  where(x: MxArray, y: MxArray): MxArray;
  argmax(axis: number, keepdims?: boolean | undefined | null): MxArray;
  argmin(axis: number, keepdims?: boolean | undefined | null): MxArray;
  max(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
  min(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
  prod(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
  var(
    axes?: Int32Array | undefined | null,
    keepdims?: boolean | undefined | null,
    ddof?: number | undefined | null,
  ): MxArray;
  std(
    axes?: Int32Array | undefined | null,
    keepdims?: boolean | undefined | null,
    ddof?: number | undefined | null,
  ): MxArray;
  logsumexp(axes?: Int32Array | undefined | null, keepdims?: boolean | undefined | null): MxArray;
  cumsum(axis: number): MxArray;
  cumprod(axis: number): MxArray;
  pad(padWidth: Int32Array, constantValue: number): MxArray;
  roll(shift: number, axis: number): MxArray;
  split(indicesOrSections: number, axis?: number | undefined | null): Array<MxArray>;
  tile(reps: Int32Array): MxArray;
  repeat(repeats: number, axis: number): MxArray;
  squeeze(axes?: Int32Array | undefined | null): MxArray;
  expandDims(axis: number): MxArray;
  broadcastTo(shape: BigInt64Array): MxArray;
  abs(): MxArray;
  negative(): MxArray;
  sign(): MxArray;
  sqrt(): MxArray;
  square(): MxArray;
  power(other: MxArray): MxArray;
  sin(): MxArray;
  cos(): MxArray;
  tan(): MxArray;
  sinh(): MxArray;
  cosh(): MxArray;
  tanh(): MxArray;
  floor(): MxArray;
  ceil(): MxArray;
  round(): MxArray;
  floorDivide(other: MxArray): MxArray;
  remainder(other: MxArray): MxArray;
  reciprocal(): MxArray;
  arcsin(): MxArray;
  arccos(): MxArray;
  arctan(): MxArray;
  log10(): MxArray;
  log2(): MxArray;
  log1p(): MxArray;
  /**
   * Element-wise check for NaN values
   *
   * Returns a boolean array where True indicates the element is NaN.
   * This is a GPU-native operation that avoids CPU data transfer.
   */
  isnan(): MxArray;
  /**
   * Element-wise check for Inf values
   *
   * Returns a boolean array where True indicates the element is +Inf or -Inf.
   * This is a GPU-native operation that avoids CPU data transfer.
   */
  isinf(): MxArray;
  /**
   * Element-wise check for finite values
   *
   * Returns a boolean array where True indicates the element is finite (not NaN and not Inf).
   * This is a GPU-native operation that avoids CPU data transfer.
   */
  isfinite(): MxArray;
}
|
|
444
|
+
|
|
445
|
+
/**
 * NAPI-exported reward registry wrapper
 *
 * Standalone registry of built-in reward functions; usable independently of
 * {@link GrpoTrainingEngine} for ad-hoc scoring.
 */
export declare class NativeRewardRegistry {
  /** Create a new reward registry */
  constructor();
  /** Register a built-in reward function */
  register(config: BuiltinRewardConfig): void;
  /** Score a single completion */
  score(prompt: string, completion: string): number;
  /** Score a batch of completions (prompts and completions are parallel arrays) */
  scoreBatch(prompts: Array<string>, completions: Array<string>): Array<number>;
  /** Check if registry is empty */
  get isEmpty(): boolean;
  /** Get registered reward names */
  get names(): Array<string>;
  /** Set whether to normalize scores */
  setNormalize(normalize: boolean): void;
}
|
|
462
|
+
|
|
463
|
+
/**
 * OutputStore - Persistence layer for training outputs
 *
 * Stores all model outputs during GRPO training for debugging and research.
 * Supports local SQLite files.
 */
export declare class OutputStore {
  /** Create a new output store with local SQLite file */
  static local(path: string): Promise<OutputStore>;
  /** Create from config object */
  static fromConfig(config: OutputStoreConfig): Promise<OutputStore>;
  /** Start a new training run; resolves to the new run ID */
  startRun(modelName: string, modelPath: string | undefined | null, config: string): Promise<string>;
  /** Start a new training run with a name; resolves to the new run ID */
  startRunWithName(
    name: string | undefined | null,
    modelName: string,
    modelPath: string | undefined | null,
    config: string,
  ): Promise<string>;
  /** End the current training run */
  endRun(status: string): Promise<void>;
  /** Get current run ID (null when no run is active) */
  currentRunId(): Promise<string | null>;
  /** Find a run by name */
  findRunByName(name: string): Promise<TrainingRunRecord | null>;
  /** Resume an existing run (sets status to running and makes it current) */
  resumeRun(runId: string): Promise<void>;
  /** Delete all steps after a given step number (for resume cleanup); resolves to rows deleted */
  deleteStepsAfter(runId: string, afterStep: number): Promise<number>;
  /**
   * Delete all records after a given step (for checkpoint resume)
   *
   * Cascades through: training_steps → generations → tool_calls, and logs.
   * Use this when resuming from checkpoint to ensure clean database state.
   */
  deleteAllAfterStep(runId: string, afterStep: number): Promise<CleanupStats>;
  /**
   * Get recent step metrics for TUI sparkline restoration
   *
   * Returns metrics ordered by step (oldest first) for easy insertion into VecDeque.
   */
  getRecentStepMetrics(runId: string, limit: number): Promise<Array<StepMetricSummary>>;
  /**
   * Get aggregate statistics for a training run
   *
   * Returns pre-computed aggregates for restoring TUI state on resume.
   */
  getRunAggregates(runId: string): Promise<RunAggregates>;
  /**
   * Get recent generations for sample panel restoration
   *
   * Returns generations ordered by step DESC, reward DESC (most recent high-reward first).
   */
  getRecentGenerations(runId: string, limit: number): Promise<Array<GenerationRecord>>;
  /** Get store configuration */
  get config(): OutputStoreConfig;
  /** Record from RewardOutput JSON (direct integration with training engine) */
  recordStepFromOutputs(
    step: number,
    metrics: EngineStepMetrics,
    outputsJson: string,
    rewards: Array<number>,
    groupSize: number,
  ): Promise<number>;
  /**
   * Record a complete training step with all generations and tool calls
   *
   * Lower-level API for direct control over step recording.
   */
  recordStep(
    step: StepRecord,
    generations: Array<GenerationRecord>,
    toolCalls: Array<Array<ToolCallRecord>>,
  ): Promise<number>;
  /** Flush any pending writes */
  flush(): Promise<void>;
  /** List all training runs */
  listRuns(limit?: number | undefined | null, status?: string | undefined | null): Promise<Array<TrainingRunRecord>>;
  /** Get a specific run */
  getRun(runId: string): Promise<TrainingRunRecord | null>;
  /** Get step summaries for a run (optionally bounded by an inclusive step range) */
  getStepSummaries(
    runId: string,
    startStep?: number | undefined | null,
    endStep?: number | undefined | null,
  ): Promise<Array<StepSummary>>;
  /** Get all generations for a step */
  getGenerations(runId: string, step: number): Promise<Array<GenerationWithToolCalls>>;
  /** Get top/bottom generations by reward */
  getGenerationsByReward(
    runId: string,
    topN?: number | undefined | null,
    bottomN?: number | undefined | null,
    stepRange?: Array<number> | undefined | null,
  ): Promise<Array<GenerationWithToolCalls>>;
  /** Get generations with specific finish reason */
  getGenerationsByFinishReason(
    runId: string,
    finishReason: string,
    limit?: number | undefined | null,
  ): Promise<Array<GenerationWithToolCalls>>;
  /** Get generations containing tool calls */
  getGenerationsWithToolCalls(
    runId: string,
    toolName?: string | undefined | null,
    status?: string | undefined | null,
    limit?: number | undefined | null,
  ): Promise<Array<GenerationWithToolCalls>>;
  /** Search generations by text content */
  searchGenerations(
    runId: string,
    query: string,
    searchIn?: string | undefined | null,
    limit?: number | undefined | null,
  ): Promise<Array<GenerationWithToolCalls>>;
  /** Get reward distribution statistics */
  getRewardStats(runId: string, stepRange?: Array<number> | undefined | null): Promise<RewardStats>;
  /** Export to JSONL file; resolves to the number of records written */
  exportJsonl(runId: string, outputPath: string, includeToolCalls?: boolean | undefined | null): Promise<number>;
  /** Execute raw SQL query (for advanced users); resolves to a JSON string */
  queryRaw(sql: string): Promise<string>;
}
|
|
586
|
+
|
|
587
|
+
/**
|
|
588
|
+
* Qwen3 Model with automatic differentiation support
|
|
589
|
+
*
|
|
590
|
+
* Uses interior mutability (RwLock) for layers, final_norm, and lm_head
|
|
591
|
+
* to allow gradient application without deep cloning the model.
|
|
592
|
+
* This eliminates the previous ~4GB memory overhead from clone_for_session().
|
|
593
|
+
*/
|
|
594
|
+
export declare class Qwen3Model {
|
|
595
|
+
/** Create a new Qwen3 model with the given configuration */
|
|
596
|
+
constructor(config: Qwen3Config);
|
|
597
|
+
/**
|
|
598
|
+
* Forward pass through the model
|
|
599
|
+
*
|
|
600
|
+
* # Arguments
|
|
601
|
+
* * `input_ids` - Token IDs, shape: [batch_size, seq_len]
|
|
602
|
+
*
|
|
603
|
+
* # Returns
|
|
604
|
+
* * Logits, shape: [batch_size, seq_len, vocab_size]
|
|
605
|
+
*/
|
|
606
|
+
forward(inputIds: MxArray): MxArray;
|
|
607
|
+
/**
|
|
608
|
+
* Initialize KV caches for incremental generation
|
|
609
|
+
*
|
|
610
|
+
* Creates one KV cache per transformer layer. Call this before starting generation.
|
|
611
|
+
*/
|
|
612
|
+
initKvCaches(): void;
|
|
613
|
+
/**
|
|
614
|
+
* Reset all KV caches
|
|
615
|
+
*
|
|
616
|
+
* Clears cached key-value states. Call this between different generation sequences.
|
|
617
|
+
*/
|
|
618
|
+
resetKvCaches(): void;
|
|
619
|
+
/** Check if paged attention is enabled for this model */
|
|
620
|
+
hasPagedAttention(): boolean;
|
|
621
|
+
/**
|
|
622
|
+
* Get paged attention memory statistics (if enabled)
|
|
623
|
+
*
|
|
624
|
+
* Returns memory usage statistics for the paged KV cache.
|
|
625
|
+
*/
|
|
626
|
+
pagedCacheStats(): PagedCacheStats | null;
|
|
627
|
+
/**
|
|
628
|
+
* Get scheduler statistics (if paged attention is enabled)
|
|
629
|
+
*
|
|
630
|
+
* Returns the number of waiting, running, and completed sequences.
|
|
631
|
+
*/
|
|
632
|
+
schedulerStats(): SchedulerStatsNapi | null;
|
|
633
|
+
/**
|
|
634
|
+
* Forward pass with KV caching for incremental generation
|
|
635
|
+
*
|
|
636
|
+
* # Arguments
|
|
637
|
+
* * `input_ids` - Token IDs, shape: [batch_size, seq_len]
|
|
638
|
+
* * `use_cache` - Whether to use KV caching (must call init_kv_caches() first)
|
|
639
|
+
*
|
|
640
|
+
* # Returns
|
|
641
|
+
* * Logits, shape: [batch_size, seq_len, vocab_size]
|
|
642
|
+
*/
|
|
643
|
+
forwardWithCache(inputIds: MxArray, useCache: boolean): MxArray;
|
|
644
|
+
/**
|
|
645
|
+
* Forward pass with paged attention for memory-efficient inference.
|
|
646
|
+
*
|
|
647
|
+
* This method uses block-based KV cache management via Metal kernels for:
|
|
648
|
+
* - Variable-length sequences with efficient memory usage
|
|
649
|
+
* - Continuous batching with dynamic batch composition
|
|
650
|
+
* - Long context support beyond GPU memory limits
|
|
651
|
+
*
|
|
652
|
+
* # Arguments
|
|
653
|
+
* * `input_ids` - Token IDs, shape: [num_seqs, 1] for decode
|
|
654
|
+
* * `slot_mapping` - Slot indices for cache updates, shape: [num_seqs]
|
|
655
|
+
* * `seq_ids` - Sequence IDs in the batch (for looking up block tables/context lens)
|
|
656
|
+
* * `positions` - Token positions for RoPE, shape: [num_seqs] (per-sequence positions)
|
|
657
|
+
*
|
|
658
|
+
* # Returns
|
|
659
|
+
* * Logits, shape: [num_seqs, 1, vocab_size] for decode
|
|
660
|
+
*/
|
|
661
|
+
forwardPaged(inputIds: MxArray, slotMapping: MxArray, seqIds: Array<number>, positions: MxArray): MxArray;
|
|
662
|
+
/**
|
|
663
|
+
* Prefill a sequence using standard attention and write K/V to paged cache.
|
|
664
|
+
*
|
|
665
|
+
* This method should be called before `step_paged_generation()` for each
|
|
666
|
+
* new prompt. It runs the full forward pass using standard attention
|
|
667
|
+
* (which is faster for long sequences), then writes the K/V cache to
|
|
668
|
+
* the paged cache for subsequent decode steps.
|
|
669
|
+
*
|
|
670
|
+
* # Arguments
|
|
671
|
+
* * `prompt_tokens` - Token IDs for the prompt (as u32 array)
|
|
672
|
+
* * `seq_id` - Sequence ID (obtained from scheduler)
|
|
673
|
+
*
|
|
674
|
+
* # Returns
|
|
675
|
+
* * Logits for the last token, shape: [1, vocab_size]
|
|
676
|
+
*/
|
|
677
|
+
prefillPaged(promptTokens: Array<number>, seqId: number): MxArray;
|
|
678
|
+
/**
|
|
679
|
+
* Add a request to the paged attention scheduler.
|
|
680
|
+
*
|
|
681
|
+
* The scheduler queues requests and allocates blocks for KV cache.
|
|
682
|
+
* Use `step_paged_generation()` to process the scheduled batch.
|
|
683
|
+
*
|
|
684
|
+
* Note: The actual sequence ID is assigned during scheduling, not when the
|
|
685
|
+
* request is added. Use the `request_id` to track your requests through
|
|
686
|
+
* the generation process.
|
|
687
|
+
*
|
|
688
|
+
* # Arguments
|
|
689
|
+
* * `request_id` - Unique identifier for the request (returned in outputs)
|
|
690
|
+
* * `prompt_tokens` - Token IDs for the prompt
|
|
691
|
+
* * `max_new_tokens` - Maximum new tokens to generate
|
|
692
|
+
* * `priority` - Optional priority (higher = scheduled first)
|
|
693
|
+
*
|
|
694
|
+
* # Returns
|
|
695
|
+
* * Number of pending requests in the queue
|
|
696
|
+
*/
|
|
697
|
+
addPagedRequest(
|
|
698
|
+
requestId: string,
|
|
699
|
+
promptTokens: Array<number>,
|
|
700
|
+
maxNewTokens: number,
|
|
701
|
+
priority?: number | undefined | null,
|
|
702
|
+
): number;
|
|
703
|
+
/**
|
|
704
|
+
* Schedule and execute one step of paged generation.
|
|
705
|
+
*
|
|
706
|
+
* This method:
|
|
707
|
+
* 1. Schedules the next batch of sequences
|
|
708
|
+
* 2. Runs forward pass with paged attention
|
|
709
|
+
* 3. Samples next tokens
|
|
710
|
+
* 4. Returns the generated tokens for each sequence
|
|
711
|
+
*
|
|
712
|
+
* # Arguments
|
|
713
|
+
* * `config` - Generation configuration (temperature, top_k, etc.)
|
|
714
|
+
*
|
|
715
|
+
* # Returns
|
|
716
|
+
* * `PagedGenerationStep` with token outputs for each sequence
|
|
717
|
+
*/
|
|
718
|
+
stepPagedGeneration(config?: GenerationConfig | undefined | null): PagedGenerationStep | null;
|
|
719
|
+
/**
|
|
720
|
+
* Get completed sequences from the scheduler.
|
|
721
|
+
*
|
|
722
|
+
* Call this after `step_paged_generation()` returns outputs with `is_finished: true`.
|
|
723
|
+
*/
|
|
724
|
+
getCompletedSequences(): Array<PagedCompletedSequence>;
|
|
725
|
+
/** Check if the scheduler has pending work. */
|
|
726
|
+
hasPagedWork(): boolean;
|
|
727
|
+
/** Get model configuration */
|
|
728
|
+
getConfig(): Qwen3Config;
|
|
729
|
+
/**
|
|
730
|
+
* Generate tokens using speculative decoding with a draft model.
|
|
731
|
+
*
|
|
732
|
+
* Speculative decoding uses a smaller draft model to generate tokens speculatively,
|
|
733
|
+
* then verifies them with the target model in a single forward pass. This can achieve
|
|
734
|
+
* 2-3x speedup when the draft model has high acceptance rate.
|
|
735
|
+
*
|
|
736
|
+
* # Algorithm
|
|
737
|
+
* 1. Draft model generates N tokens speculatively (cheap forward passes)
|
|
738
|
+
* 2. Target model (self) verifies all N tokens in one forward pass
|
|
739
|
+
* 3. Accept/reject using rejection sampling
|
|
740
|
+
* 4. On rejection, resample from adjusted distribution
|
|
741
|
+
* 5. Rewind caches and continue
|
|
742
|
+
*
|
|
743
|
+
* # Arguments
|
|
744
|
+
* * `draft_model` - Smaller model for speculative generation (should share tokenizer)
|
|
745
|
+
* * `input_ids` - Input token IDs [1, seq_len]
|
|
746
|
+
* * `config` - Generation configuration (includes num_draft_tokens)
|
|
747
|
+
*
|
|
748
|
+
* # Returns
|
|
749
|
+
* GenerationResult with tokens, logprobs, and speculative stats in finish_reason
|
|
750
|
+
*
|
|
751
|
+
* # Example (TypeScript)
|
|
752
|
+
* ```typescript
|
|
753
|
+
* const targetModel = await ModelLoader.loadPretrained('qwen3-7b');
|
|
754
|
+
* const draftModel = await ModelLoader.loadPretrained('qwen3-0.5b');
|
|
755
|
+
*
|
|
756
|
+
* const result = targetModel.generateSpeculativeSync(draftModel, inputIds, {
|
|
757
|
+
* numDraftTokens: 5,
|
|
758
|
+
* maxNewTokens: 100,
|
|
759
|
+
* temperature: 0.7,
|
|
760
|
+
* });
|
|
761
|
+
* ```
|
|
762
|
+
*/
|
|
763
|
+
generateSpeculativeSync(
|
|
764
|
+
draftModel: Qwen3Model,
|
|
765
|
+
inputIds: MxArray,
|
|
766
|
+
config?: GenerationConfig | undefined | null,
|
|
767
|
+
): GenerationResult;
|
|
768
|
+
/** Count total number of parameters in the model */
|
|
769
|
+
numParameters(): number;
|
|
770
|
+
/**
|
|
771
|
+
* Get all model parameters as a dictionary mapping names to arrays
|
|
772
|
+
*
|
|
773
|
+
* This matches the TypeScript API for compatibility
|
|
774
|
+
*/
|
|
775
|
+
getParameters(): Record<string, MxArray>;
|
|
776
|
+
/** Load parameters from a dictionary */
|
|
777
|
+
loadParameters(params: Record<string, MxArray>): void;
|
|
778
|
+
/**
|
|
779
|
+
* Compute forward pass and loss (for evaluation)
|
|
780
|
+
*
|
|
781
|
+
* # Arguments
|
|
782
|
+
* * `input_ids` - Input token IDs, shape: [batch_size, seq_len]
|
|
783
|
+
* * `labels` - Target token IDs, shape: [batch_size, seq_len]
|
|
784
|
+
*
|
|
785
|
+
* # Returns
|
|
786
|
+
* * Scalar loss value
|
|
787
|
+
*/
|
|
788
|
+
computeLoss(inputIds: MxArray, labels: MxArray): MxArray;
|
|
789
|
+
/**
|
|
790
|
+
* Compute loss and gradients using a hybrid approach
|
|
791
|
+
*
|
|
792
|
+
* This implementation computes gradients for the output layers and uses
|
|
793
|
+
* numerical approximations for other parameters. This is sufficient to
|
|
794
|
+
* demonstrate that training works while we build out full MLX autograd integration.
|
|
795
|
+
*
|
|
796
|
+
* # Arguments
|
|
797
|
+
* * `input_ids` - Input token IDs, shape: [batch_size, seq_len]
|
|
798
|
+
* * `labels` - Target token IDs, shape: [batch_size, seq_len]
|
|
799
|
+
*
|
|
800
|
+
* # Returns
|
|
801
|
+
* * A tuple of (loss, gradients_dict) where gradients_dict maps parameter names to gradient arrays
|
|
802
|
+
*
|
|
803
|
+
* # Phase 6A Status
|
|
804
|
+
* Current implementation computes:
|
|
805
|
+
* - ✅ Exact gradients for LM head (output layer)
|
|
806
|
+
* - ⚠️ Numerical approximations for other layers
|
|
807
|
+
*
|
|
808
|
+
* Future: Full MLX autograd will compute exact gradients for all 250+ parameters
|
|
809
|
+
*/
|
|
810
|
+
computeLossAndGradients(inputIds: MxArray, labels: MxArray): [MxArray, Record<string, MxArray>];
|
|
811
|
+
/**
|
|
812
|
+
* Complete GRPO training step using MLX Autograd (RECOMMENDED)
|
|
813
|
+
*
|
|
814
|
+
* This method uses automatic differentiation to compute gradients, eliminating
|
|
815
|
+
* the need for manual backward pass implementation. This is the preferred approach.
|
|
816
|
+
*
|
|
817
|
+
* # Arguments
|
|
818
|
+
* * `prompt_tokens` - Prompt token sequences [batch_size, seq_len] (1D arrays)
|
|
819
|
+
* * `completion_tokens` - Completion sequences [batch*G, completion_len] (1D arrays)
|
|
820
|
+
* * `completion_logprobs` - Logprobs from generation [batch*G, completion_len] (1D arrays)
|
|
821
|
+
* * `rewards` - Reward scores for each completion [batch*G]
|
|
822
|
+
* * `group_size` - Number of completions per prompt (G)
|
|
823
|
+
* * `config` - GRPO loss configuration
|
|
824
|
+
* * `learning_rate` - Learning rate for parameter updates
|
|
825
|
+
*
|
|
826
|
+
* # Returns
|
|
827
|
+
* * Tuple of (loss_value, metrics_dict)
|
|
828
|
+
*/
|
|
829
|
+
trainStepGrpoAutograd(
|
|
830
|
+
promptTokens: Array<MxArray>,
|
|
831
|
+
completionTokens: Array<MxArray>,
|
|
832
|
+
completionLogprobs: Array<MxArray>,
|
|
833
|
+
rewards: Float64Array,
|
|
834
|
+
groupSize: number,
|
|
835
|
+
config: GrpoLossConfig,
|
|
836
|
+
learningRate: number,
|
|
837
|
+
): [number, Record<string, number>];
|
|
838
|
+
/**
|
|
839
|
+
* Compute gradients only without applying them (for gradient accumulation)
|
|
840
|
+
*
|
|
841
|
+
* This method computes GRPO loss and gradients but does NOT update parameters.
|
|
842
|
+
* Used for gradient accumulation where gradients are summed across multiple
|
|
843
|
+
* micro-batches before applying them.
|
|
844
|
+
*
|
|
845
|
+
* # Arguments
|
|
846
|
+
* * `prompt_tokens` - Prompt token sequences [batch_size, seq_len] (1D arrays)
|
|
847
|
+
* * `completion_tokens` - Completion sequences [batch*G, completion_len] (1D arrays)
|
|
848
|
+
* * `completion_logprobs` - Logprobs from generation [batch*G, completion_len] (1D arrays)
|
|
849
|
+
* * `rewards` - Reward scores for each completion [batch*G]
|
|
850
|
+
* * `group_size` - Number of completions per prompt (G)
|
|
851
|
+
* * `config` - GRPO loss configuration
|
|
852
|
+
*
|
|
853
|
+
* # Returns
|
|
854
|
+
* * Tuple of (loss_value, gradients_dict, metrics_dict)
|
|
855
|
+
*/
|
|
856
|
+
computeGradientsOnlyGrpoAutograd(
|
|
857
|
+
promptTokens: Array<MxArray>,
|
|
858
|
+
completionTokens: Array<MxArray>,
|
|
859
|
+
completionLogprobs: Array<MxArray>,
|
|
860
|
+
rewards: Float64Array,
|
|
861
|
+
groupSize: number,
|
|
862
|
+
config: GrpoLossConfig,
|
|
863
|
+
): [number, Record<string, MxArray>, Record<string, number>];
|
|
864
|
+
/**
|
|
865
|
+
* Accumulate gradients into existing gradient dictionary
|
|
866
|
+
*
|
|
867
|
+
* This is a helper method for gradient accumulation. It adds new_gradients
|
|
868
|
+
* to accumulated_gradients element-wise.
|
|
869
|
+
*
|
|
870
|
+
* # Arguments
|
|
871
|
+
* * `accumulated_gradients` - Existing accumulated gradients (will be modified in-place conceptually, but returns new dict)
|
|
872
|
+
* * `new_gradients` - New gradients to add
|
|
873
|
+
*
|
|
874
|
+
* # Returns
|
|
875
|
+
* * Updated gradient dictionary with accumulated values
|
|
876
|
+
*/
|
|
877
|
+
static accumulateGradients(
|
|
878
|
+
accumulatedGradients: Record<string, MxArray>,
|
|
879
|
+
newGradients: Record<string, MxArray>,
|
|
880
|
+
): Record<string, MxArray>;
|
|
881
|
+
/**
|
|
882
|
+
* Complete GRPO training step using manual gradients (Legacy)
|
|
883
|
+
*
|
|
884
|
+
* This method performs a full GRPO training iteration:
|
|
885
|
+
* 1. Takes completions (already generated) with their logprobs and rewards
|
|
886
|
+
* 2. Computes advantages
|
|
887
|
+
* 3. Computes GRPO loss and gradients
|
|
888
|
+
* 4. Updates model parameters
|
|
889
|
+
*
|
|
890
|
+
* NOTE: Use train_step_grpo_autograd instead for automatic differentiation.
|
|
891
|
+
*
|
|
892
|
+
* # Arguments
|
|
893
|
+
* * `prompt_tokens` - Prompt token sequences [batch_size, seq_len] (1D arrays)
|
|
894
|
+
* * `completion_tokens` - Completion sequences [batch*G, completion_len] (1D arrays)
|
|
895
|
+
* * `completion_logprobs` - Logprobs from generation [batch*G, completion_len] (1D arrays)
|
|
896
|
+
* * `rewards` - Reward scores for each completion [batch*G]
|
|
897
|
+
* * `group_size` - Number of completions per prompt (G)
|
|
898
|
+
* * `config` - GRPO loss configuration
|
|
899
|
+
* * `learning_rate` - Learning rate for parameter updates
|
|
900
|
+
*
|
|
901
|
+
* # Returns
|
|
902
|
+
* * Tuple of (loss_value, metrics_dict)
|
|
903
|
+
*/
|
|
904
|
+
trainStepGrpo(
|
|
905
|
+
promptTokens: Array<MxArray>,
|
|
906
|
+
completionTokens: Array<MxArray>,
|
|
907
|
+
completionLogprobs: Array<MxArray>,
|
|
908
|
+
rewards: Float64Array,
|
|
909
|
+
groupSize: number,
|
|
910
|
+
config: GrpoLossConfig,
|
|
911
|
+
learningRate: number,
|
|
912
|
+
): [number, Record<string, number>];
|
|
913
|
+
/**
|
|
914
|
+
* Apply gradients to model parameters
|
|
915
|
+
*
|
|
916
|
+
* # Arguments
|
|
917
|
+
* * `gradients` - Dictionary mapping parameter names to gradient arrays
|
|
918
|
+
* * `learning_rate` - Learning rate for gradient descent
|
|
919
|
+
*
|
|
920
|
+
* This performs a simple SGD update: param = param - lr * grad
|
|
921
|
+
* Only updates parameters that have gradients; others remain unchanged.
|
|
922
|
+
*
|
|
923
|
+
* IMPORTANT: This function preserves the original dtype of parameters.
|
|
924
|
+
* The learning rate scalar is cast to match param dtype to prevent
|
|
925
|
+
* promotion to float32 during arithmetic operations.
|
|
926
|
+
*/
|
|
927
|
+
applyGradients(gradients: Record<string, MxArray>, learningRate: number): void;
|
|
928
|
+
/**
|
|
929
|
+
* Text-to-text generation with integrated tokenization
|
|
930
|
+
*
|
|
931
|
+
* This is a high-level API that handles chat template formatting, tokenization,
|
|
932
|
+
* generation, and decoding internally. It takes chat messages, applies the ChatML
|
|
933
|
+
* template, generates tokens, and decodes them back to text.
|
|
934
|
+
*
|
|
935
|
+
* # Arguments
|
|
936
|
+
* * `messages` - Array of chat messages with role and content
|
|
937
|
+
* * `config` - Generation configuration
|
|
938
|
+
*
|
|
939
|
+
* # Returns
|
|
940
|
+
* * GenerationResult with text, tokens, logprobs, finish reason, and token count
|
|
941
|
+
*
|
|
942
|
+
* # Example
|
|
943
|
+
* ```typescript
|
|
944
|
+
* const model = await Qwen3Model.loadPretrained("path/to/model");
|
|
945
|
+
* const messages = [
|
|
946
|
+
* { role: "user", content: "What is 2+2?" }
|
|
947
|
+
* ];
|
|
948
|
+
* const result = await model.generate(messages, {
|
|
949
|
+
* maxNewTokens: 50,
|
|
950
|
+
* temperature: 0.8,
|
|
951
|
+
* topP: 0.95,
|
|
952
|
+
* });
|
|
953
|
+
* console.log(result.text); // Decoded text output
|
|
954
|
+
* console.log(result.tokens); // Token IDs (for GRPO)
|
|
955
|
+
* console.log(result.logprobs); // Log probabilities (for GRPO)
|
|
956
|
+
* ```
|
|
957
|
+
*/
|
|
958
|
+
generate(messages: Array<ChatMessage>, config?: GenerationConfig | undefined | null): Promise<GenerationResult>;
|
|
959
|
+
/**
|
|
960
|
+
* High-level chat API with structured response parsing
|
|
961
|
+
*
|
|
962
|
+
* The primary API for conversational AI. Handles:
|
|
963
|
+
* - Chat message formatting with Jinja2 templates
|
|
964
|
+
* - Tool/function calling with structured output
|
|
965
|
+
* - Thinking extraction from `<think>` tags
|
|
966
|
+
* - Clean response text with all special tags stripped
|
|
967
|
+
*
|
|
968
|
+
* ## `chat()` vs `generate()`
|
|
969
|
+
*
|
|
970
|
+
* | Feature | `chat()` | `generate()` |
|
|
971
|
+
* |---------|----------|--------------|
|
|
972
|
+
* | **Purpose** | Conversational AI with tools | Raw text generation |
|
|
973
|
+
* | **Input** | Chat messages | Token IDs (MxArray) |
|
|
974
|
+
* | **Tool Support** | Built-in parsing | None |
|
|
975
|
+
* | **Thinking** | Extracts `<think>` content | Raw text only |
|
|
976
|
+
* | **Output** | Structured `ChatResult` | Basic `GenerationResult` |
|
|
977
|
+
* | **Use Case** | Chat apps, agents, assistants | Training, low-level control |
|
|
978
|
+
*
|
|
979
|
+
* ## When to use `chat()`
|
|
980
|
+
* - Building conversational applications
|
|
981
|
+
* - Need tool/function calling
|
|
982
|
+
* - Want structured responses with thinking separated
|
|
983
|
+
* - Working with chat message format
|
|
984
|
+
*
|
|
985
|
+
* ## When to use `generate()`
|
|
986
|
+
* - Training and fine-tuning (need raw logprobs)
|
|
987
|
+
* - Custom tokenization pipeline
|
|
988
|
+
* - Low-level generation control
|
|
989
|
+
* - Non-chat use cases
|
|
990
|
+
*
|
|
991
|
+
* # Arguments
|
|
992
|
+
* * `messages` - Array of chat messages (user/assistant/system roles)
|
|
993
|
+
* * `config` - Chat configuration including optional tools and generation params
|
|
994
|
+
*
|
|
995
|
+
* # Returns
|
|
996
|
+
* * `ChatResult` containing:
|
|
997
|
+
* - `text`: Clean response (tool_call and think tags stripped)
|
|
998
|
+
* - `thinking`: Extracted chain-of-thought reasoning (or null)
|
|
999
|
+
* - `toolCalls`: Parsed tool calls with native JS object arguments
|
|
1000
|
+
* - `finishReason`: "stop" | "length" | "tool_calls"
|
|
1001
|
+
* - `rawText`: Original text before processing (for debugging)
|
|
1002
|
+
*
|
|
1003
|
+
* # Example
|
|
1004
|
+
* ```typescript
|
|
1005
|
+
* // Simple chat
|
|
1006
|
+
* const result = await model.chat(messages);
|
|
1007
|
+
* console.log(result.text);
|
|
1008
|
+
*
|
|
1009
|
+
* // With tools
|
|
1010
|
+
* const result = await model.chat(messages, {
|
|
1011
|
+
* tools: [{ type: 'function', function: { name: 'get_weather' } }],
|
|
1012
|
+
* maxNewTokens: 2048,
|
|
1013
|
+
* temperature: 0.7,
|
|
1014
|
+
* });
|
|
1015
|
+
*
|
|
1016
|
+
* // Handle tool calls
|
|
1017
|
+
* for (const call of result.toolCalls) {
|
|
1018
|
+
* if (call.status === 'ok') {
|
|
1019
|
+
* console.log(call.name, call.arguments); // Arguments is a JS object!
|
|
1020
|
+
* }
|
|
1021
|
+
* }
|
|
1022
|
+
*
|
|
1023
|
+
* // Access thinking (chain-of-thought)
|
|
1024
|
+
* if (result.thinking) {
|
|
1025
|
+
* console.log('Model reasoning:', result.thinking);
|
|
1026
|
+
* }
|
|
1027
|
+
* ```
|
|
1028
|
+
*/
|
|
1029
|
+
chat(messages: Array<ChatMessage>, config?: ChatConfig | undefined | null): Promise<ChatResult>;
|
|
1030
|
+
/**
|
|
1031
|
+
* Generate multiple completions for multiple prompts in batch
|
|
1032
|
+
*
|
|
1033
|
+
* This is an optimized method for GRPO training that generates G completions
|
|
1034
|
+
* for each of N prompts. It performs all tokenization, generation, and decoding
|
|
1035
|
+
* in 3 blocking tasks instead of N*(1+2G) tasks.
|
|
1036
|
+
*
|
|
1037
|
+
* # Arguments
|
|
1038
|
+
* * `prompts` - Array of N prompt message arrays
|
|
1039
|
+
* * `group_size` - Number of completions (G) to generate per prompt
|
|
1040
|
+
* * `config` - Generation configuration (sampling params, etc.)
|
|
1041
|
+
*
|
|
1042
|
+
* # Returns
|
|
1043
|
+
* * BatchGenerationResult containing N*G completions with:
|
|
1044
|
+
* - tokens: Flat array of N*G token arrays
|
|
1045
|
+
* - logprobs: Flat array of N*G logprob arrays
|
|
1046
|
+
* - texts: Flat array of N*G decoded texts
|
|
1047
|
+
* - finish_reasons: N arrays of G finish reasons
|
|
1048
|
+
* - token_counts: N arrays of G token counts
|
|
1049
|
+
*
|
|
1050
|
+
* # Performance
|
|
1051
|
+
* For N=10 prompts, G=8 completions:
|
|
1052
|
+
* - Old approach: N*(1 tokenize + G generate + G decode) = 10*(1+8+8) = 170 blocking tasks
|
|
1053
|
+
* - New approach: 1 tokenize + N*G generate + 1 decode = 1+80+1 = 82 blocking tasks (2.1x reduction)
|
|
1054
|
+
*
|
|
1055
|
+
* # Example
|
|
1056
|
+
* ```typescript
|
|
1057
|
+
* const result = await model.generateBatch(
|
|
1058
|
+
* [messages1, messages2, ...], // N prompts
|
|
1059
|
+
* 8, // G completions per prompt
|
|
1060
|
+
* config
|
|
1061
|
+
* );
|
|
1062
|
+
* ```
|
|
1063
|
+
*/
|
|
1064
|
+
generateBatch(
|
|
1065
|
+
prompts: Array<Array<ChatMessage>>,
|
|
1066
|
+
groupSize: number,
|
|
1067
|
+
config?: GenerationConfig | undefined | null,
|
|
1068
|
+
): Promise<BatchGenerationResult>;
|
|
1069
|
+
/**
|
|
1070
|
+
* Decode token IDs to text using the internal tokenizer
|
|
1071
|
+
*
|
|
1072
|
+
* Helper method for decoding generated tokens. The model must have been loaded
|
|
1073
|
+
* via load_pretrained() to have a tokenizer available.
|
|
1074
|
+
*
|
|
1075
|
+
* # Arguments
|
|
1076
|
+
* * `token_ids` - Token IDs to decode as Uint32Array
|
|
1077
|
+
* * `skip_special_tokens` - Whether to skip special tokens (default: true)
|
|
1078
|
+
*
|
|
1079
|
+
* # Returns
|
|
1080
|
+
* * Decoded text string
|
|
1081
|
+
*/
|
|
1082
|
+
decode(tokenIds: Uint32Array, skipSpecialTokens?: boolean | undefined | null): Promise<string>;
|
|
1083
|
+
/**
|
|
1084
|
+
* Apply chat template and encode to token IDs
|
|
1085
|
+
*
|
|
1086
|
+
* Formats messages using ChatML format (or Jinja2 template with tools) and encodes to tokens.
|
|
1087
|
+
* The model must have been loaded via load_pretrained() to have a tokenizer available.
|
|
1088
|
+
*
|
|
1089
|
+
* # Arguments
|
|
1090
|
+
* * `messages` - Array of chat messages
|
|
1091
|
+
* * `add_generation_prompt` - Whether to add generation prompt (default: true)
|
|
1092
|
+
* * `tools` - Optional array of tool definitions for function calling
|
|
1093
|
+
* * `enable_thinking` - Optional flag to enable thinking mode (<think> tags)
|
|
1094
|
+
*
|
|
1095
|
+
* # Returns
|
|
1096
|
+
* * Encoded token IDs as Uint32Array
|
|
1097
|
+
*/
|
|
1098
|
+
applyChatTemplate(
|
|
1099
|
+
messages: Array<ChatMessage>,
|
|
1100
|
+
addGenerationPrompt?: boolean | undefined | null,
|
|
1101
|
+
tools?: Array<ToolDefinition> | undefined | null,
|
|
1102
|
+
enableThinking?: boolean | undefined | null,
|
|
1103
|
+
): Promise<Uint32Array>;
|
|
1104
|
+
/**
|
|
1105
|
+
* Load a pretrained model from disk
|
|
1106
|
+
*
|
|
1107
|
+
* This loads a model from a directory containing:
|
|
1108
|
+
* - config.json: Model configuration
|
|
1109
|
+
* - weights.mlx (optional): MLX format weights with data arrays
|
|
1110
|
+
* - weights.safetensors (optional): SafeTensors format (not yet supported)
|
|
1111
|
+
*
|
|
1112
|
+
* # Arguments
|
|
1113
|
+
* * `model_path` - Path to the model directory
|
|
1114
|
+
*
|
|
1115
|
+
* # Returns
|
|
1116
|
+
* * A fully initialized Qwen3Model with loaded weights
|
|
1117
|
+
*/
|
|
1118
|
+
static loadPretrained(modelPath: string): Promise<Qwen3Model>;
|
|
1119
|
+
/**
|
|
1120
|
+
* Save model configuration and weights to disk
|
|
1121
|
+
*
|
|
1122
|
+
* This saves:
|
|
1123
|
+
* - config.json: Model configuration
|
|
1124
|
+
* - weights.safetensors: Full model weights in SafeTensors format
|
|
1125
|
+
* - weights.mlx: Parameter metadata (for reference)
|
|
1126
|
+
*
|
|
1127
|
+
* # Arguments
|
|
1128
|
+
* * `save_path` - Directory to save the model
|
|
1129
|
+
*/
|
|
1130
|
+
saveModel(savePath: string): Promise<undefined>;
|
|
1131
|
+
/**
|
|
1132
|
+
* Validate that a set of parameters has all required weights with correct shapes
|
|
1133
|
+
*
|
|
1134
|
+
* This is useful for validating parameters before loading them into a model,
|
|
1135
|
+
* or for checking that saved weights are valid before training.
|
|
1136
|
+
*
|
|
1137
|
+
* # Arguments
|
|
1138
|
+
* * `params` - HashMap of parameter names to MxArray values
|
|
1139
|
+
*
|
|
1140
|
+
* # Returns
|
|
1141
|
+
* * Ok(()) if all validations pass
|
|
1142
|
+
* * Err with descriptive message if validation fails
|
|
1143
|
+
*/
|
|
1144
|
+
validateParameters(params: Record<string, MxArray>): void;
|
|
1145
|
+
}
|
|
1146
|
+
|
|
1147
|
+
/** Qwen3 Tokenizer class with NAPI bindings */
|
|
1148
|
+
export declare class Qwen3Tokenizer {
|
|
1149
|
+
/**
|
|
1150
|
+
* Load tokenizer from tokenizer.json file
|
|
1151
|
+
*
|
|
1152
|
+
* # Arguments
|
|
1153
|
+
* * `path` - Path to tokenizer.json file (default: "../.cache/assets/tokenizers/qwen3_tokenizer.json")
|
|
1154
|
+
*
|
|
1155
|
+
* # Example
|
|
1156
|
+
* ```typescript
|
|
1157
|
+
* const tokenizer = await Qwen3Tokenizer.fromPretrained("qwen3_tokenizer.json");
|
|
1158
|
+
* const tokens = await tokenizer.encode("Hello, world!");
|
|
1159
|
+
* ```
|
|
1160
|
+
*/
|
|
1161
|
+
static fromPretrained(tokenizerPath: string): Promise<Qwen3Tokenizer>;
|
|
1162
|
+
/**
|
|
1163
|
+
* Encode text to token IDs
|
|
1164
|
+
*
|
|
1165
|
+
* # Arguments
|
|
1166
|
+
* * `text` - Text to encode
|
|
1167
|
+
* * `add_special_tokens` - Whether to add special tokens (default: true)
|
|
1168
|
+
*
|
|
1169
|
+
* # Returns
|
|
1170
|
+
* Array of token IDs as Uint32Array
|
|
1171
|
+
*
|
|
1172
|
+
* # Example
|
|
1173
|
+
* ```typescript
|
|
1174
|
+
* const tokens = await tokenizer.encode("Hello, world!");
|
|
1175
|
+
* console.log(tokens); // Uint32Array [9906, 11, 1879, 0]
|
|
1176
|
+
* ```
|
|
1177
|
+
*/
|
|
1178
|
+
encode(text: string, addSpecialTokens?: boolean | undefined | null): Promise<Uint32Array>;
|
|
1179
|
+
/**
|
|
1180
|
+
* Encode multiple texts in batch
|
|
1181
|
+
*
|
|
1182
|
+
* # Arguments
|
|
1183
|
+
* * `texts` - Array of texts to encode
|
|
1184
|
+
* * `add_special_tokens` - Whether to add special tokens (default: true)
|
|
1185
|
+
*
|
|
1186
|
+
* # Returns
|
|
1187
|
+
* Array of Uint32Arrays, one for each text
|
|
1188
|
+
*/
|
|
1189
|
+
encodeBatch(texts: Array<string>, addSpecialTokens?: boolean | undefined | null): Promise<Array<Uint32Array>>;
|
|
1190
|
+
/**
|
|
1191
|
+
* Decode token IDs to text
|
|
1192
|
+
*
|
|
1193
|
+
* # Arguments
|
|
1194
|
+
* * `token_ids` - Token IDs to decode
|
|
1195
|
+
* * `skip_special_tokens` - Whether to skip special tokens (default: true)
|
|
1196
|
+
*
|
|
1197
|
+
* # Returns
|
|
1198
|
+
* Decoded text string
|
|
1199
|
+
*
|
|
1200
|
+
* # Example
|
|
1201
|
+
* ```typescript
|
|
1202
|
+
* const text = await tokenizer.decode(new Uint32Array([9906, 11, 1879, 0]));
|
|
1203
|
+
* console.log(text); // "Hello, world!"
|
|
1204
|
+
* ```
|
|
1205
|
+
*/
|
|
1206
|
+
decode(tokenIds: Uint32Array, skipSpecialTokens?: boolean | undefined | null): Promise<string>;
|
|
1207
|
+
/**
|
|
1208
|
+
* Decode multiple token sequences in batch
|
|
1209
|
+
*
|
|
1210
|
+
* # Arguments
|
|
1211
|
+
* * `token_ids_batch` - Array of token ID arrays to decode
|
|
1212
|
+
* * `skip_special_tokens` - Whether to skip special tokens (default: true)
|
|
1213
|
+
*
|
|
1214
|
+
* # Returns
|
|
1215
|
+
* Array of decoded text strings
|
|
1216
|
+
*/
|
|
1217
|
+
decodeBatch(
|
|
1218
|
+
tokenIdsBatch: Array<Uint32Array>,
|
|
1219
|
+
skipSpecialTokens?: boolean | undefined | null,
|
|
1220
|
+
): Promise<Array<string>>;
|
|
1221
|
+
/**
|
|
1222
|
+
* Apply chat template to messages and encode
|
|
1223
|
+
*
|
|
1224
|
+
* Supports both simple ChatML format and full Jinja2 template rendering with tools.
|
|
1225
|
+
* When tools are provided or a chat template exists, uses Jinja2 rendering.
|
|
1226
|
+
* Otherwise falls back to simple ChatML format.
|
|
1227
|
+
*
|
|
1228
|
+
* # Arguments
|
|
1229
|
+
* * `messages` - Array of chat messages
|
|
1230
|
+
* * `add_generation_prompt` - Whether to add assistant prompt at end (default: true)
|
|
1231
|
+
* * `tools` - Optional array of tool definitions for function calling
|
|
1232
|
+
* * `enable_thinking` - Optional flag to enable thinking mode (<think> tags)
|
|
1233
|
+
*
|
|
1234
|
+
* # Returns
|
|
1235
|
+
* Encoded token IDs ready for model input
|
|
1236
|
+
*
|
|
1237
|
+
* # Example
|
|
1238
|
+
* ```typescript
|
|
1239
|
+
* const messages = [
|
|
1240
|
+
* { role: "system", content: "You are a helpful assistant." },
|
|
1241
|
+
* { role: "user", content: "What is 2+2?" }
|
|
1242
|
+
* ];
|
|
1243
|
+
* const tokens = await tokenizer.applyChatTemplate(messages, true);
|
|
1244
|
+
*
|
|
1245
|
+
* // With tools
|
|
1246
|
+
* const tools = [{
|
|
1247
|
+
* type: "function",
|
|
1248
|
+
* function: { name: "get_weather", description: "Get weather info" }
|
|
1249
|
+
* }];
|
|
1250
|
+
* const toolTokens = await tokenizer.applyChatTemplate(messages, true, tools);
|
|
1251
|
+
* ```
|
|
1252
|
+
*/
|
|
1253
|
+
applyChatTemplate(
|
|
1254
|
+
messages: Array<ChatMessage>,
|
|
1255
|
+
addGenerationPrompt?: boolean | undefined | null,
|
|
1256
|
+
tools?: Array<ToolDefinition> | undefined | null,
|
|
1257
|
+
enableThinking?: boolean | undefined | null,
|
|
1258
|
+
): Promise<Uint32Array>;
|
|
1259
|
+
/** Get vocabulary size */
|
|
1260
|
+
vocabSize(): number;
|
|
1261
|
+
/** Get PAD token ID */
|
|
1262
|
+
getPadTokenId(): number;
|
|
1263
|
+
/** Get EOS token ID */
|
|
1264
|
+
getEosTokenId(): number;
|
|
1265
|
+
/** Get BOS token ID (if exists) */
|
|
1266
|
+
getBosTokenId(): number | null;
|
|
1267
|
+
/** Convert token ID to string */
|
|
1268
|
+
idToToken(id: number): string | null;
|
|
1269
|
+
/** Convert token string to ID */
|
|
1270
|
+
tokenToId(token: string): number | null;
|
|
1271
|
+
/** Get the special token for IM_START */
|
|
1272
|
+
getImStartToken(): string;
|
|
1273
|
+
/** Get the special token for IM_END */
|
|
1274
|
+
getImEndToken(): string;
|
|
1275
|
+
/** Get the special token for ENDOFTEXT (used as PAD) */
|
|
1276
|
+
getEndoftextToken(): string;
|
|
1277
|
+
}
|
|
1278
|
+
|
|
1279
|
+
/**
 * SFT (supervised fine-tuning) training engine.
 *
 * Wraps a `Qwen3Model` plus an optimizer/training loop implemented natively.
 * Step/epoch counters live on the native side; use `restoreState` / `reset`
 * to keep them in sync with the TypeScript training loop.
 */
export declare class SftTrainingEngine {
  /** Create a new SFT training engine for `model`, configured by `config`. */
  constructor(model: Qwen3Model, config: SftEngineConfig);
  /**
   * Run a single training step on one batch.
   *
   * @param inputIds - Input token IDs for the batch
   * @param labels - Target token IDs (same shape as `inputIds` — presumably; confirm against the native impl)
   * @returns Metrics for this step (loss, timing, etc.)
   */
  trainStep(inputIds: MxArray, labels: MxArray): Promise<SftStepMetrics>;
  /** Get current step number */
  getStep(): number;
  /** Get current epoch */
  getEpoch(): number;
  /**
   * Flush any accumulated gradients at epoch end.
   *
   * When stepsPerEpoch % gradient_accumulation_steps != 0, there may be
   * leftover gradients from the final micro-batches. This method applies
   * them with proper averaging, matching TRL behavior.
   *
   * @returns A boolean flag — presumably whether leftover gradients were
   *          actually applied; confirm against the native implementation.
   */
  flushGradients(): boolean;
  /**
   * Compute the resume position given current state and dataset info.
   *
   * This centralizes all resume logic in Rust for correctness.
   * Uses i64 math internally to avoid overflow on long runs.
   *
   * @param stepsPerEpoch - Number of optimizer steps per epoch
   */
  computeResumePosition(stepsPerEpoch: number): ResumePosition;
  /** Check if emergency save is needed */
  needsEmergencySave(): boolean;
  /** Clear emergency save flag */
  clearEmergencySave(): void;
  /**
   * Signal start of a new epoch.
   *
   * Takes the epoch number directly from TypeScript to ensure synchronization.
   * The epoch is 0-indexed to match the TypeScript training loop.
   */
  startEpoch(epoch: number): void;
  /**
   * End current epoch and return metrics.
   *
   * @param epochTimeSecs - Wall-clock duration of the epoch in seconds
   */
  endEpoch(epochTimeSecs: number): SftEpochMetrics;
  /** Reset training state (for new training run) */
  reset(): void;
  /** Restore training state (for resuming from checkpoint) */
  restoreState(step: number, epoch: number): void;
  /** Get the underlying model for checkpointing */
  getModel(): Qwen3Model;
}
|
|
1324
|
+
|
|
1325
|
+
/**
 * A tensor that tracks gradients for automatic differentiation.
 *
 * This is a wrapper around MxArray that provides:
 * - Gradient tracking
 * - Automatic gradient accumulation
 * - Integration with manual backward passes
 */
export declare class Tensor {
  /** Create a tensor from float32 data with the given shape. */
  static fromFloat32(data: Float32Array, shape: BigInt64Array, requiresGrad?: boolean | undefined | null): Tensor;
  /** Create a tensor from int32 data with the given shape. */
  static fromInt32(data: Int32Array, shape: BigInt64Array, requiresGrad?: boolean | undefined | null): Tensor;
  /** Get the shape of the underlying data */
  dataShape(): BigInt64Array;
  /** Get the shape of the gradient, or null if no gradient exists. */
  gradShape(): BigInt64Array | null;
  /** Check if gradient exists */
  hasGrad(): boolean;
  /** Check if this tensor requires gradients */
  get requiresGrad(): boolean;
  /** Set whether this tensor requires gradients */
  set requiresGrad(requiresGrad: boolean);
  /** Zero out the gradient */
  zeroGrad(): void;
  /**
   * Accumulate gradient.
   *
   * If gradient already exists, add to it. Otherwise, set it.
   * Note: This takes ownership of the gradient array.
   */
  accumulateGrad(grad: MxArray): void;
  /** Get the shape of the tensor */
  shape(): BigInt64Array;
  /** Convert data to Float32 array */
  toFloat32(): Float32Array;
  /** Convert gradient to Float32 array, or null if no gradient exists. */
  gradToFloat32(): Float32Array | null;
  /** Convert to Int32 array */
  toInt32(): Int32Array;
  /**
   * Detach this tensor from the computation graph.
   *
   * Returns a new tensor with the same data but no gradient tracking.
   */
  detach(): Tensor;
  /** Create a tensor of zeros with the given shape (default dtype if omitted). */
  static zeros(
    shape: BigInt64Array,
    dtype?: DType | undefined | null,
    requiresGrad?: boolean | undefined | null,
  ): Tensor;
  /** Create a tensor of ones with the given shape (default dtype if omitted). */
  static ones(
    shape: BigInt64Array,
    dtype?: DType | undefined | null,
    requiresGrad?: boolean | undefined | null,
  ): Tensor;
  /** Evaluate the underlying array (forces lazy MLX computation to run). */
  eval(): void;
}
|
|
1386
|
+
|
|
1387
|
+
/** Result from VLM chat (see `VLModel.chat`). */
export declare class VlmChatResult {
  /** Get the response text */
  get text(): string;
  /** Get the generated tokens */
  get tokens(): MxArray;
  /** Get the log probabilities */
  get logprobs(): MxArray;
  /** Get the finish reason */
  get finishReason(): 'stop' | 'length' | 'repetition';
  /** Get the number of tokens generated */
  get numTokens(): number;
}
|
|
1400
|
+
/** Alias for {@link VlmChatResult} using the all-caps "VLM" spelling. */
export type VLMChatResult = VlmChatResult;
|
|
1401
|
+
|
|
1402
|
+
/**
 * Vision-Language Model.
 *
 * A generic VLM for OCR and document understanding tasks.
 * Currently supports PaddleOCR-VL architecture (vision encoder + ERNIE language model).
 */
export declare class VLModel {
  /** Create a new PaddleOCR-VL model */
  constructor(config: ModelConfig);
  /** Set the tokenizer */
  setTokenizer(tokenizer: Qwen3Tokenizer): void;
  /** Check if tokenizer is available */
  get hasTokenizer(): boolean;
  /**
   * Chat with the VLM model.
   *
   * High-level API for conversational interaction with images.
   *
   * # Arguments
   * * `messages` - Chat messages (role + content)
   * * `config` - Chat configuration (including image_paths for automatic processing)
   *
   * # Returns
   * * VLMChatResult with generated text
   *
   * # Example
   * ```typescript
   * const result = model.chat(
   *   [{ role: 'user', content: 'Describe this image.' }],
   *   { imagePaths: ['./photo.jpg'], maxNewTokens: 256 }
   * );
   * ```
   */
  chat(messages: Array<VlmChatMessage>, config?: VlmChatConfig | undefined | null): VlmChatResult;
  /**
   * Simple OCR: extract text from an image file.
   *
   * Convenience method that processes an image and extracts all text.
   * Note: this method is synchronous (returns `string`, not a Promise).
   *
   * # Arguments
   * * `image_path` - Path to the image file
   * * `prompt` - Optional custom prompt (default: "Extract all text from this image.")
   *
   * # Returns
   * * Extracted text as a string
   *
   * # Example
   * ```typescript
   * const text = model.ocr('./receipt.jpg');
   * console.log(text);
   * ```
   */
  ocr(imagePath: string, prompt?: string | undefined | null): string;
  /**
   * Get input embeddings with vision features merged.
   *
   * # Arguments
   * * `input_ids` - Token IDs [batch, seq_len]
   * * `pixel_values` - Optional image patches [batch, seq, channels, patch_h, patch_w]
   * * `image_grid_thw` - Optional grid dimensions [num_images, 3]
   *
   * # Returns
   * * Input embeddings with vision features inserted at image token positions
   */
  getInputEmbeddings(
    inputIds: MxArray,
    pixelValues?: MxArray | undefined | null,
    imageGridThw?: MxArray | undefined | null,
  ): MxArray;
  /**
   * Forward pass.
   *
   * # Arguments
   * * `input_ids` - Token IDs [batch, seq_len]
   * * `pixel_values` - Optional image patches
   * * `image_grid_thw` - Optional grid dimensions
   * * `mask` - Optional attention mask
   *
   * # Returns
   * * Logits [batch, seq_len, vocab_size]
   */
  forward(
    inputIds: MxArray,
    pixelValues?: MxArray | undefined | null,
    imageGridThw?: MxArray | undefined | null,
    mask?: MxArray | undefined | null,
  ): MxArray;
  /**
   * Generate text tokens given input tokens and optional image.
   *
   * Uses KV caching for efficient generation - each step only processes the
   * new token(s) while reusing cached key-value states from previous tokens.
   * Vision features are computed once at the start and cached.
   *
   * # Arguments
   * * `input_ids` - Input token IDs [1, seq_len]
   * * `pixel_values` - Optional image patches [1, num_patches, C, H, W]
   * * `image_grid_thw` - Optional grid dimensions [1, 3]
   * * `config` - Generation configuration
   *
   * # Returns
   * * GenerationResult with tokens, logprobs, and finish reason
   */
  generate(
    inputIds: MxArray,
    pixelValues?: MxArray | undefined | null,
    imageGridThw?: MxArray | undefined | null,
    config?: GenerationConfig | undefined | null,
  ): GenerationResult;
  /** Get model configuration */
  get config(): ModelConfig;
  /** Check if model is fully initialized */
  get isInitialized(): boolean;
  /**
   * Load a VLM from disk.
   *
   * Loads a model from a directory containing:
   * - config.json: Model configuration
   * - model.safetensors or model-*.safetensors: Model weights in SafeTensors format
   *
   * # Arguments
   * * `model_path` - Path to the model directory
   *
   * # Returns
   * * A fully initialized VLModel with loaded weights
   *
   * # Example
   * ```typescript
   * import { VLModel } from '@mlx-node/vlm';
   * const model = await VLModel.load('./models/paddleocr-vl');
   * const result = model.chat(messages, { imagePaths: ['./image.jpg'] });
   * ```
   */
  static load(modelPath: string): Promise<VLModel>;
  /**
   * Load model configuration from disk without loading weights.
   *
   * This is useful for inspecting model configuration before loading the full model.
   *
   * # Arguments
   * * `model_path` - Path to the model directory containing config.json
   *
   * # Returns
   * * ModelConfig with vision and text configuration
   *
   * # Example
   * ```typescript
   * import { VLModel } from '@mlx-node/vlm';
   * const config = await VLModel.loadConfig('./models/paddleocr-vl');
   * console.log(config.visionConfig.hiddenSize);
   * ```
   */
  static loadConfig(modelPath: string): Promise<ModelConfig>;
}
|
|
1556
|
+
|
|
1557
|
+
/**
 * Build RewardOutput array from generation results.
 *
 * Parses tool calls and thinking from completions, creating structured outputs
 * aligned with the ChatResult structure.
 *
 * # Arguments
 * * `prompts` - Array of prompt texts (one per unique prompt, will be expanded by group_size)
 * * `completions` - Array of completion texts (prompts.len() * group_size total)
 * * `token_counts` - Array of token counts for each completion
 * * `finish_reasons` - Array of finish reasons from generation ("eos", "length", "stop", "repetition")
 * * `group_size` - Number of completions per prompt
 *
 * # Returns
 * Array of RewardOutput objects with structured completion data
 *
 * # Example
 * ```typescript
 * import { buildRewardOutputs } from '@mlx-node/core';
 *
 * const outputs = buildRewardOutputs(
 *   ['What is 2+2?'],                        // prompts
 *   ['<think>Let me calculate</think>\n4', '4'], // completions (group_size=2)
 *   [10, 5],                                 // token counts
 *   ['eos', 'length'],                       // finish reasons
 *   2                                        // group_size
 * );
 *
 * outputs[0].completion.thinking;     // "Let me calculate"
 * outputs[0].completion.text;         // "4"
 * outputs[0].completion.finishReason; // "eos"
 * ```
 */
export declare function buildRewardOutputs(
  prompts: Array<string>,
  completions: Array<string>,
  tokenCounts: Array<number>,
  finishReasons: Array<string>,
  groupSize: number,
): Array<RewardOutput>;
|
|
1599
|
+
|
|
1600
|
+
/**
 * Configuration for built-in rewards.
 *
 * Only `rewardType` is required; the remaining fields apply to specific
 * reward types as noted on each field.
 */
export interface BuiltinRewardConfig {
  /** Type of reward function */
  rewardType: BuiltinRewardType;
  /** Weight for this reward (default 1.0) */
  weight?: number;
  /** Allowed tool names (for ToolUse) */
  allowedTools?: Array<string>;
  /** Required tags (for XmlFormat) */
  requiredTags?: Array<string>;
  /** Minimum length (for Length) */
  minLength?: number;
  /** Maximum length (for Length) */
  maxLength?: number;
  /** Use character count vs word count (for Length) */
  useChars?: boolean;
  /** Required JSON fields (for JsonSchema) */
  requiredFields?: Array<string>;
  /** Whether tool call is required (for ToolUse) */
  required?: boolean;
}
|
|
1621
|
+
|
|
1622
|
+
/** Built-in reward function types (see {@link BuiltinRewardConfig} for per-type options). */
export declare const enum BuiltinRewardType {
  /** Tool use validation */
  ToolUse = 'ToolUse',
  /** XML format validation */
  XmlFormat = 'XmlFormat',
  /** Length-based scoring */
  Length = 'Length',
  /** JSON schema validation */
  JsonSchema = 'JsonSchema',
}
|
|
1633
|
+
|
|
1634
|
+
/**
 * Configuration for the high-level `chat()` API.
 *
 * Combines tool definitions with generation parameters in a single config object.
 * Tools are optional - when not provided, `chat()` works as a simple conversational API.
 *
 * ## Example
 * ```typescript
 * // Simple chat (no tools)
 * const result = await model.chat(messages);
 *
 * // With tools
 * const result = await model.chat(messages, {
 *   tools: [weatherTool, searchTool],
 *   maxNewTokens: 2048,
 *   temperature: 0.7,
 * });
 * ```
 */
export interface ChatConfig {
  /**
   * Tool definitions for function calling (optional).
   *
   * When provided, the model can invoke these tools during generation.
   * Tool calls are parsed and returned in `ChatResult.toolCalls`.
   */
  tools?: Array<ToolDefinition>;
  /** Maximum number of new tokens to generate (default: 2048 for chat) */
  maxNewTokens?: number;
  /** Sampling temperature (0 = greedy, higher = more random) (default: 0.7) */
  temperature?: number;
  /** Top-k sampling: keep only top k tokens (0 = disabled) (default: 0) */
  topK?: number;
  /** Top-p (nucleus) sampling: keep tokens with cumulative prob < p (default: 0.9) */
  topP?: number;
  /** Min-p sampling: keep tokens with prob > min_p * max_prob (default: 0.0) */
  minP?: number;
  /** Repetition penalty factor (1.0 = no penalty) (default: 1.0) */
  repetitionPenalty?: number;
  /** Number of recent tokens to consider for repetition penalty (default: 20) */
  repetitionContextSize?: number;
  /** Stop if same token repeats this many times consecutively (default: 16) */
  maxConsecutiveTokens?: number;
  /** Stop if an n-gram pattern repeats this many times (default: 8) */
  maxNgramRepeats?: number;
  /** N-gram size for repetition detection (default: 3) */
  ngramSize?: number;
  /** EOS token ID (generation stops when this is generated) */
  eosTokenId?: number;
  /** Whether to return log probabilities (default: true) */
  returnLogprobs?: boolean;
}
|
|
1686
|
+
|
|
1687
|
+
/** Chat message with tool calling support. */
export interface ChatMessage {
  /** Role: "system", "user", "assistant", or "tool" */
  role: string;
  /** Message content */
  content: string;
  /** Tool calls made by the assistant (only meaningful on assistant messages) */
  toolCalls?: Array<ToolCall>;
  /** Tool call ID this message is responding to (only meaningful on tool messages) */
  toolCallId?: string;
  /** Reasoning content for thinking mode (used with `<think>` tags) */
  reasoningContent?: string;
}
|
|
1700
|
+
|
|
1701
|
+
/**
 * Chat message role.
 *
 * NOTE(review): `ChatMessage.role` is a plain lowercase string
 * ("system"/"user"/"assistant"/"tool"); this enum uses capitalized values,
 * so confirm which APIs accept which form before mixing them.
 */
export declare const enum ChatRole {
  /** User message */
  User = 'User',
  /** Assistant response */
  Assistant = 'Assistant',
  /** System prompt */
  System = 'System',
}
|
|
1710
|
+
|
|
1711
|
+
/** Statistics about cleanup operations (NAPI wrapper). Counts of records deleted per category. */
export interface CleanupStats {
  /** Number of training steps deleted */
  stepsDeleted: number;
  /** Number of generations deleted */
  generationsDeleted: number;
  /** Number of tool calls deleted */
  toolCallsDeleted: number;
  /** Number of logs deleted */
  logsDeleted: number;
}
|
|
1722
|
+
|
|
1723
|
+
/**
 * Structured completion information aligned with ChatResult.
 * Contains pre-parsed tool calls, thinking, and clean text.
 */
export interface CompletionInfo {
  /** Clean text with `<tool_call>` and `<think>` tags removed */
  text: string;
  /** Raw output before tag stripping (for debugging/XML parsing) */
  rawText: string;
  /** Parsed tool calls (arguments are already JS objects) */
  toolCalls: Array<ToolCallResult>;
  /** Extracted thinking/reasoning from `<think>` tags (absent if none) */
  thinking?: string;
  /** Number of tokens generated */
  numTokens: number;
  /**
   * Finish reason: "stop" | "length" | "tool_calls".
   * NOTE(review): `buildRewardOutputs` documents the raw generation reasons as
   * "eos"/"length"/"stop"/"repetition" — confirm how they map onto this field.
   */
  finishReason: string;
}
|
|
1741
|
+
|
|
1742
|
+
/** Options for {@link convertModel}. */
export interface ConversionOptions {
  /** Input directory containing model files (config.json, model.safetensors) */
  inputDir: string;
  /** Output directory for converted model */
  outputDir: string;
  /** Target dtype for conversion (default: "float32") */
  dtype?: string;
  /** Whether to enable verbose logging (default: false) */
  verbose?: boolean;
}
|
|
1752
|
+
|
|
1753
|
+
/** Result returned by {@link convertModel}. */
export interface ConversionResult {
  /** Number of tensors converted */
  numTensors: number;
  /** Total number of parameters */
  numParameters: number;
  /** Output model path */
  outputPath: string;
  /** List of converted tensor names */
  tensorNames: Array<string>;
}
|
|
1763
|
+
|
|
1764
|
+
/**
 * Convert a HuggingFace SafeTensors model to MLX format.
 *
 * This function:
 * 1. Loads SafeTensors model from input directory
 * 2. Converts all tensors to specified dtype (default: float32)
 * 3. Saves converted model to output directory
 * 4. Copies config.json and tokenizer files
 *
 * # Arguments
 * * `options` - Conversion options (input_dir, output_dir, dtype, verbose)
 *
 * # Returns
 * * ConversionResult with statistics about the conversion
 *
 * # Example
 * ```typescript
 * import { convertModel } from '../../index.cjs';
 *
 * const result = await convertModel({
 *   inputDir: '.cache/models/qwen3-0.6b',
 *   outputDir: '.cache/models/qwen3-0.6b-mlx',
 *   dtype: 'float32',
 *   verbose: true
 * });
 *
 * console.log(`Converted ${result.numTensors} tensors (${result.numParameters} parameters)`);
 * ```
 */
export declare function convertModel(options: ConversionOptions): Promise<ConversionResult>;
|
|
1794
|
+
|
|
1795
|
+
/**
 * Convert a Parquet file to JSONL (one JSON object per line).
 * Synchronous; throws on I/O or parse errors — presumably; confirm in the native impl.
 */
export declare function convertParquetToJsonl(inputPath: string, outputPath: string): void;
|
|
1796
|
+
|
|
1797
|
+
/** Create a default PaddleOCR-VL 1.5 configuration (JS factory function). */
export declare function createPaddleocrVlConfig(): ModelConfig;
|
|
1799
|
+
|
|
1800
|
+
/**
 * Document element - either a table or paragraph.
 * Exactly one of `table` / `paragraph` is expected to be present,
 * matching `elementType`.
 */
export interface DocumentElement {
  /** Discriminator selecting which payload field is populated */
  elementType: ElementType;
  /** Table data (only present if element_type is Table) */
  table?: Table;
  /** Paragraph data (only present if element_type is Paragraph) */
  paragraph?: Paragraph;
}
|
|
1808
|
+
|
|
1809
|
+
/** Numeric dtype tags used by tensor factory functions (e.g. `Tensor.zeros`). */
export declare const enum DType {
  /** 32-bit float */
  Float32 = 0,
  /** 32-bit signed integer */
  Int32 = 1,
  /** 16-bit float */
  Float16 = 2,
  /** 16-bit brain float */
  BFloat16 = 3,
  /** 32-bit unsigned integer */
  Uint32 = 4,
}
|
|
1816
|
+
|
|
1817
|
+
/** Document element type (discriminator for {@link DocumentElement}). */
export declare const enum ElementType {
  Table = 'Table',
  Paragraph = 'Paragraph',
}
|
|
1822
|
+
|
|
1823
|
+
/** Metrics from a training epoch. */
export interface EngineEpochMetrics {
  /** Epoch number */
  epoch: number;
  /** Average loss for the epoch */
  avgLoss: number;
  /** Average reward for the epoch */
  avgReward: number;
  /** Total steps in the epoch */
  totalSteps: number;
  /** Total tokens processed */
  totalTokens: number;
  /** Time for the epoch (seconds) */
  epochTimeSecs: number;
}
|
|
1838
|
+
|
|
1839
|
+
/** Metrics from a single training step. */
export interface EngineStepMetrics {
  /** Current step number */
  step: number;
  /** GRPO loss value */
  loss: number;
  /** Mean reward across completions */
  meanReward: number;
  /** Standard deviation of rewards */
  stdReward: number;
  /** Mean advantage value */
  meanAdvantage: number;
  /** Standard deviation of advantages */
  stdAdvantage: number;
  /** Total tokens generated this step */
  totalTokens: number;
  /** Whether gradients were applied (false while accumulating micro-batches) */
  gradientsApplied: boolean;
  /** Time for generation (ms) */
  generationTimeMs: number;
  /** Time for training (ms) */
  trainingTimeMs: number;
  /** Peak memory usage this step (MB) */
  peakMemoryMb: number;
  /** Active memory at end of step (MB) */
  activeMemoryMb: number;
}
|
|
1866
|
+
|
|
1867
|
+
/** Format a parsed document as a string according to `config` (defaults used when omitted). */
export declare function formatDocument(doc: ParsedDocument, config?: ParserConfig | undefined | null): string;
|
|
1869
|
+
|
|
1870
|
+
/** Function definition for tool calling. */
export interface FunctionDefinition {
  /** Name of the function */
  name: string;
  /** Description of what the function does */
  description?: string;
  /** Parameter schema */
  parameters?: FunctionParameters;
}
|
|
1879
|
+
|
|
1880
|
+
/** Function parameters schema (JSON Schema subset). */
export interface FunctionParameters {
  /** Type (usually "object") */
  type: string;
  /** Property definitions, serialized as a JSON string (not a parsed object) */
  properties?: string;
  /** List of required parameter names */
  required?: Array<string>;
}
|
|
1889
|
+
|
|
1890
|
+
/**
 * Result from generate_batch_for_training with all data needed for training.
 *
 * Token/logprob arrays are flattened across completions; use
 * `completionLengths` to slice them back into per-completion segments.
 */
export interface GenerateBatchResult {
  /** Generated completion texts */
  completionTexts: Array<string>;
  /** Completion token IDs (flattened, concatenated) */
  completionTokens: Array<number>;
  /** Completion log probabilities (flattened, concatenated) */
  completionLogprobs: Array<number>;
  /** Lengths of each completion (for reconstruction) */
  completionLengths: Array<number>;
  /** Finish reasons for each completion ("eos", "length", or "repetition") */
  finishReasons: Array<string>;
}
|
|
1903
|
+
|
|
1904
|
+
/** Configuration for text generation. */
export interface GenerationConfig {
  /** Maximum number of new tokens to generate (default: 100) */
  maxNewTokens?: number;
  /** Sampling temperature (0 = greedy, higher = more random) (default: 1.0) */
  temperature?: number;
  /** Top-k sampling: keep only top k tokens (0 = disabled) (default: 0) */
  topK?: number;
  /** Top-p (nucleus) sampling: keep tokens with cumulative prob < p (default: 1.0) */
  topP?: number;
  /** Min-p sampling: keep tokens with prob > min_p * max_prob (default: 0.0) */
  minP?: number;
  /** Repetition penalty factor (1.0 = no penalty, 1.1-1.5 typical) (default: 1.0) */
  repetitionPenalty?: number;
  /**
   * Number of recent tokens to consider for repetition penalty (default: 20)
   * Matches mlx-lm default. Larger values catch longer patterns but use more memory
   */
  repetitionContextSize?: number;
  /**
   * Stop if same token repeats this many times consecutively (default: 16)
   * Set to 0 to disable. Prevents OOM from degenerate repetitive generation.
   */
  maxConsecutiveTokens?: number;
  /**
   * Stop if an n-gram pattern repeats this many times (default: 8)
   * Set to 0 to disable. Detects patterns like "A B A B A B A B".
   */
  maxNgramRepeats?: number;
  /**
   * N-gram size for repetition detection (default: 3)
   * Used with max_ngram_repeats to detect repeating patterns.
   */
  ngramSize?: number;
  /** EOS token ID (generation stops when this is generated) */
  eosTokenId?: number;
  /** Whether to return log probabilities (always true for GRPO) */
  returnLogprobs?: boolean;
  /**
   * Prefill step size for chunked processing of long prompts (default: 2048)
   * When the prompt length exceeds this value, it will be processed in chunks
   * to improve memory efficiency and enable async pipelining.
   * Set to 0 to disable chunking and process the entire prompt at once.
   */
  prefillStepSize?: number;
  /**
   * KV cache quantization bits (default: 16 = no quantization)
   * - 16: Full precision (bfloat16/float16), no quantization
   * - 8: 8-bit quantization, ~2x memory savings, minimal quality loss
   * - 4: 4-bit quantization, ~4x memory savings, some quality degradation
   *
   * Quantized KV cache is useful for long sequences where memory becomes a bottleneck.
   * Note: Adds dequantization overhead per forward pass.
   */
  kvCacheBits?: number;
  /**
   * KV cache quantization group size (default: 64)
   * Number of elements per quantization group. Smaller groups = better accuracy
   * but more overhead from storing scales/biases.
   * Only used when kv_cache_bits is 4 or 8.
   */
  kvCacheGroupSize?: number;
  /**
   * Number of draft tokens to generate speculatively (default: 5)
   * Only used when a draft model is provided for speculative decoding.
   * Higher values can increase throughput but may reduce acceptance rate.
   */
  numDraftTokens?: number;
}
|
|
1973
|
+
|
|
1974
|
+
/** A generation record (NAPI wrapper). One completion produced during training/evaluation. */
export interface GenerationRecord {
  /** Index of the prompt within its batch */
  batchIndex: number;
  /** Index of this completion within its group (completions per prompt) */
  groupIndex: number;
  /** Prompt text */
  prompt: string;
  /** Expected/reference answer, when the dataset provides one */
  expectedAnswer?: string;
  /** Completion text after tag stripping */
  completionText: string;
  /** Raw completion text before tag stripping */
  completionRaw: string;
  /** Extracted thinking/reasoning, if any */
  thinking?: string;
  /** Number of tokens generated */
  numTokens: number;
  /** Finish reason for this completion */
  finishReason: string;
  /** Reward assigned to this completion */
  reward: number;
}
|
|
1987
|
+
|
|
1988
|
+
/** A generation with its associated tool calls (NAPI wrapper). */
export interface GenerationWithToolCalls {
  /** The generation record */
  generation: GenerationRecord;
  /** Tool calls parsed from this generation */
  toolCalls: Array<ToolCallRecord>;
}
|
|
1993
|
+
|
|
1994
|
+
/** Get expected weight keys for PaddleOCR-VL model (useful for validating checkpoints). */
export declare function getExpectedWeightKeys(): Array<string>;
|
|
1996
|
+
|
|
1997
|
+
/** Configuration for the GRPO training engine */
|
|
1998
|
+
export interface GrpoEngineConfig {
|
|
1999
|
+
/** Learning rate (default: 1e-6) */
|
|
2000
|
+
learningRate?: number;
|
|
2001
|
+
/** Gradient accumulation steps (default: 1) */
|
|
2002
|
+
gradientAccumulationSteps?: number;
|
|
2003
|
+
/** Maximum gradient norm for clipping (default: 1.0) */
|
|
2004
|
+
gradientClipNorm?: number;
|
|
2005
|
+
/**
|
|
2006
|
+
* Maximum gradient value for element-wise clipping (default: 1.0)
|
|
2007
|
+
* This clamps individual gradient elements to [-value, value]
|
|
2008
|
+
*/
|
|
2009
|
+
gradientClipValue?: number;
|
|
2010
|
+
/** Number of completions per prompt (default: 4) */
|
|
2011
|
+
groupSize?: number;
|
|
2012
|
+
/** PPO clipping epsilon (default: 0.2) */
|
|
2013
|
+
clipEpsilon?: number;
|
|
2014
|
+
/** KL divergence coefficient (default: 0.0) */
|
|
2015
|
+
klCoef?: number;
|
|
2016
|
+
/** Loss type: "grpo", "dapo", "dr_grpo", "bnpo" (default: "grpo") */
|
|
2017
|
+
lossType?: string;
|
|
2018
|
+
/**
|
|
2019
|
+
* Maximum completion length for both generation and training (default: 256)
|
|
2020
|
+
* Matches Python TRL's max_completion_length config.
|
|
2021
|
+
*/
|
|
2022
|
+
maxCompletionLength?: number;
|
|
2023
|
+
/** Sampling temperature (default: 0.8) */
|
|
2024
|
+
temperature?: number;
|
|
2025
|
+
/** Top-p (nucleus) sampling (default: 0.95) */
|
|
2026
|
+
topP?: number;
|
|
2027
|
+
/** Top-k sampling (optional) */
|
|
2028
|
+
topK?: number;
|
|
2029
|
+
/** Repetition penalty (default: 1.1) */
|
|
2030
|
+
repetitionPenalty?: number;
|
|
2031
|
+
/**
|
|
2032
|
+
* Maximum allowed NaN gradient occurrences before stopping training (default: 100)
|
|
2033
|
+
* When exceeded, training will stop with an error to prevent model corruption.
|
|
2034
|
+
*/
|
|
2035
|
+
maxNanGradients?: number;
|
|
2036
|
+
/**
|
|
2037
|
+
* Consecutive NaN gradients that trigger emergency checkpoint (default: 5)
|
|
2038
|
+
* When reached, the needs_emergency_save flag is set for the TypeScript layer.
|
|
2039
|
+
*/
|
|
2040
|
+
emergencySaveThreshold?: number;
|
|
2041
|
+
/**
|
|
2042
|
+
* Enable detailed NaN/Inf detection with per-element counts (default: false)
|
|
2043
|
+
* When false (default), uses GPU-native has_nan_or_inf() which only transfers a single
|
|
2044
|
+
* boolean to CPU. When true, transfers the entire gradient tensor to CPU for detailed
|
|
2045
|
+
* per-element analysis - useful for debugging but has significant performance overhead
|
|
2046
|
+
* for large models (e.g., 2.4GB for Qwen3-0.6B).
|
|
2047
|
+
*/
|
|
2048
|
+
verboseNanDetection?: boolean;
|
|
2049
|
+
/**
|
|
2050
|
+
* Enable thinking mode for Qwen3 models (default: true)
|
|
2051
|
+
* When false, adds empty <think></think> tags to disable model thinking.
|
|
2052
|
+
* This is useful for tool-use training where you want direct outputs.
|
|
2053
|
+
*/
|
|
2054
|
+
enableThinking?: boolean;
|
|
2055
|
+
/**
|
|
2056
|
+
* Tool definitions for function calling
|
|
2057
|
+
* When provided, tools are included in the chat template so the model
|
|
2058
|
+
* can generate tool calls. This is essential for tool-use training.
|
|
2059
|
+
*/
|
|
2060
|
+
tools?: Array<ToolDefinition>;
|
|
2061
|
+
/**
|
|
2062
|
+
* Batch chunk size for LM head computation (memory optimization).
|
|
2063
|
+
* When set, the LM head (hidden_states -> logits) is computed in chunks
|
|
2064
|
+
* of this size to reduce peak memory usage.
|
|
2065
|
+
* Default: None (no chunking, full batch at once)
|
|
2066
|
+
* Recommended: 2 for batch_size >= 4 with large vocabularies (e.g., 151936)
|
|
2067
|
+
* This reduces peak memory from ~1.2GB to ~300MB for Qwen3 (vocab=151936).
|
|
2068
|
+
*/
|
|
2069
|
+
lmHeadChunkSize?: number;
|
|
2070
|
+
/**
|
|
2071
|
+
* Batch chunk size for transformer forward pass (memory optimization).
|
|
2072
|
+
* When set, the transformer layers process the batch in chunks of this size,
|
|
2073
|
+
* reducing peak memory from O(batch × heads × seq²) for attention.
|
|
2074
|
+
* Default: None (no chunking, full batch at once)
|
|
2075
|
+
* Recommended: 4 for batch_size >= 4 with groupSize >= 4
|
|
2076
|
+
* Memory savings: ~70-80% for batch=4, groupSize=4 (16 sequences → 4 at a time)
|
|
2077
|
+
*/
|
|
2078
|
+
forwardChunkSize?: number;
|
|
2079
|
+
/**
|
|
2080
|
+
* Chunk size for vocabulary dimension in cross-entropy computation.
|
|
2081
|
+
* When computing logsumexp over large vocabularies (e.g., Qwen3's 151,936 tokens),
|
|
2082
|
+
* the computation is split into chunks of this size to reduce peak memory usage.
|
|
2083
|
+
* Default: 65536 (2^16)
|
|
2084
|
+
* Recommended: 65536 for Qwen3 (vocab=151936) splits into 3 chunks
|
|
2085
|
+
* Set to a larger value to reduce chunking overhead or smaller for tighter memory constraints.
|
|
2086
|
+
*/
|
|
2087
|
+
vocabChunkSize?: number;
|
|
2088
|
+
/**
|
|
2089
|
+
* Enable true parallel batch generation (default: false).
|
|
2090
|
+
* When true, all N*G sequences are processed in parallel using batched FFI
|
|
2091
|
+
* with per-sequence RoPE offsets. This provides 2-4x speedup for GRPO training.
|
|
2092
|
+
* When false, uses the sequential generation (process one prompt at a time,
|
|
2093
|
+
* then expand KV cache for G completions).
|
|
2094
|
+
*/
|
|
2095
|
+
useParallelBatchGeneration?: boolean;
|
|
2096
|
+
}
|
|
2097
|
+
|
|
2098
|
+
/** Configuration for GRPO loss computation */
|
|
2099
|
+
export interface GrpoLossConfig {
|
|
2100
|
+
/** Lower clipping bound (default: 0.2, means clip to [1-0.2, 1+epsilon_high]) */
|
|
2101
|
+
epsilonLow: number;
|
|
2102
|
+
/** Upper clipping bound (default: same as epsilon_low) */
|
|
2103
|
+
epsilonHigh?: number;
|
|
2104
|
+
/** KL divergence penalty coefficient (default: 0.0, no penalty) */
|
|
2105
|
+
beta: number;
|
|
2106
|
+
/** Loss aggregation type: "grpo", "bnpo", "dr_grpo", or "dapo" */
|
|
2107
|
+
lossType: string;
|
|
2108
|
+
/** Importance sampling level: "token" or "sequence" */
|
|
2109
|
+
importanceSamplingLevel: string;
|
|
2110
|
+
/**
|
|
2111
|
+
* Maximum completion length (legacy, no longer used by dr_grpo)
|
|
2112
|
+
* Kept for backwards compatibility but ignored in current implementation.
|
|
2113
|
+
*/
|
|
2114
|
+
maxCompletionLength?: number;
|
|
2115
|
+
/** Total number of items in batch across all processes (needed for dapo) */
|
|
2116
|
+
numItemsInBatch?: number;
|
|
2117
|
+
/** Current gradient accumulation step (for loss scaling) */
|
|
2118
|
+
gradientAccumulationSteps: number;
|
|
2119
|
+
/**
|
|
2120
|
+
* Batch chunk size for LM head computation (memory optimization).
|
|
2121
|
+
* When set, the LM head (hidden_states -> logits) is computed in chunks
|
|
2122
|
+
* of this size to reduce peak memory usage.
|
|
2123
|
+
* Default: None (no chunking, full batch at once)
|
|
2124
|
+
* Recommended: 2 for batch_size >= 4 with large vocabularies (e.g., 151936)
|
|
2125
|
+
*/
|
|
2126
|
+
lmHeadChunkSize?: number;
|
|
2127
|
+
/**
|
|
2128
|
+
* Batch chunk size for transformer forward pass (memory optimization).
|
|
2129
|
+
* When set, the transformer layers process the batch in chunks of this size,
|
|
2130
|
+
* reducing peak memory from O(batch × heads × seq²) for attention.
|
|
2131
|
+
* Default: None (no chunking, full batch at once)
|
|
2132
|
+
* Recommended: 4 for batch_size >= 4 with groupSize >= 4
|
|
2133
|
+
* Memory savings: ~70-80% for batch=4, groupSize=4 (16 sequences → 4 at a time)
|
|
2134
|
+
*/
|
|
2135
|
+
forwardChunkSize?: number;
|
|
2136
|
+
/**
|
|
2137
|
+
* Chunk size for vocabulary dimension in cross-entropy computation.
|
|
2138
|
+
* When computing logsumexp over large vocabularies (e.g., Qwen3's 151,936 tokens),
|
|
2139
|
+
* the computation is split into chunks of this size to reduce peak memory usage.
|
|
2140
|
+
* Default: 65536 (2^16)
|
|
2141
|
+
* Recommended: 65536 for Qwen3 (vocab=151936) splits into 3 chunks
|
|
2142
|
+
*/
|
|
2143
|
+
vocabChunkSize?: number;
|
|
2144
|
+
}
|
|
2145
|
+
|
|
2146
|
+
/** Full model configuration */
|
|
2147
|
+
export interface ModelConfig {
|
|
2148
|
+
visionConfig: VisionConfig;
|
|
2149
|
+
textConfig: TextConfig;
|
|
2150
|
+
modelType: string;
|
|
2151
|
+
ignoreIndex: number;
|
|
2152
|
+
imageTokenId: number;
|
|
2153
|
+
videoTokenId: number;
|
|
2154
|
+
visionStartTokenId: number;
|
|
2155
|
+
visionEndTokenId: number;
|
|
2156
|
+
eosTokenId: number;
|
|
2157
|
+
}
|
|
2158
|
+
|
|
2159
|
+
/** Output format options */
|
|
2160
|
+
export declare const enum OutputFormat {
|
|
2161
|
+
/** Raw output with minimal processing */
|
|
2162
|
+
Raw = 'Raw',
|
|
2163
|
+
/** Plain text with aligned columns */
|
|
2164
|
+
Plain = 'Plain',
|
|
2165
|
+
/** Markdown tables */
|
|
2166
|
+
Markdown = 'Markdown',
|
|
2167
|
+
/** HTML tables */
|
|
2168
|
+
Html = 'Html',
|
|
2169
|
+
}
|
|
2170
|
+
|
|
2171
|
+
/** Configuration for creating an OutputStore connection */
|
|
2172
|
+
export interface OutputStoreConfig {
|
|
2173
|
+
/** Local SQLite file path (e.g., "training_outputs.db") */
|
|
2174
|
+
localPath: string;
|
|
2175
|
+
}
|
|
2176
|
+
|
|
2177
|
+
/** Paged attention memory statistics (NAPI-compatible) */
|
|
2178
|
+
export interface PagedCacheStats {
|
|
2179
|
+
/** Total number of blocks in the pool */
|
|
2180
|
+
totalBlocks: number;
|
|
2181
|
+
/** Number of free blocks */
|
|
2182
|
+
freeBlocks: number;
|
|
2183
|
+
/** Number of allocated blocks */
|
|
2184
|
+
allocatedBlocks: number;
|
|
2185
|
+
/** Total memory in MB */
|
|
2186
|
+
totalMemoryMb: number;
|
|
2187
|
+
/** Used memory in MB */
|
|
2188
|
+
usedMemoryMb: number;
|
|
2189
|
+
/** Utilization percentage */
|
|
2190
|
+
utilizationPercent: number;
|
|
2191
|
+
}
|
|
2192
|
+
|
|
2193
|
+
/** A completed sequence from paged generation */
|
|
2194
|
+
export interface PagedCompletedSequence {
|
|
2195
|
+
/** Original request ID */
|
|
2196
|
+
requestId: string;
|
|
2197
|
+
/** All generated tokens (excluding prompt) */
|
|
2198
|
+
tokens: Array<number>;
|
|
2199
|
+
/** Reason for completion ("eos", "max_tokens", etc.) */
|
|
2200
|
+
finishReason: string;
|
|
2201
|
+
}
|
|
2202
|
+
|
|
2203
|
+
/** Result of a paged generation step */
|
|
2204
|
+
export interface PagedGenerationStep {
|
|
2205
|
+
/** Token outputs for each sequence in the batch */
|
|
2206
|
+
outputs: Array<PagedTokenOutput>;
|
|
2207
|
+
/** Number of sequences that were in prefill phase */
|
|
2208
|
+
numPrefill: number;
|
|
2209
|
+
/** Number of sequences that were in decode phase */
|
|
2210
|
+
numDecode: number;
|
|
2211
|
+
}
|
|
2212
|
+
|
|
2213
|
+
/** Output from a single token generation step in paged attention */
|
|
2214
|
+
export interface PagedTokenOutput {
|
|
2215
|
+
/** Sequence ID in the scheduler */
|
|
2216
|
+
seqId: number;
|
|
2217
|
+
/** Request ID for this sequence */
|
|
2218
|
+
requestId: string;
|
|
2219
|
+
/** Generated token ID */
|
|
2220
|
+
token: number;
|
|
2221
|
+
/** Log probability of the token (f64 for NAPI compatibility) */
|
|
2222
|
+
logprob: number;
|
|
2223
|
+
/** Whether this sequence has finished */
|
|
2224
|
+
isFinished: boolean;
|
|
2225
|
+
}
|
|
2226
|
+
|
|
2227
|
+
/** A text paragraph */
|
|
2228
|
+
export interface Paragraph {
|
|
2229
|
+
content: string;
|
|
2230
|
+
}
|
|
2231
|
+
|
|
2232
|
+
/** Parsed document structure */
|
|
2233
|
+
export interface ParsedDocument {
|
|
2234
|
+
elements: Array<DocumentElement>;
|
|
2235
|
+
}
|
|
2236
|
+
|
|
2237
|
+
/**
|
|
2238
|
+
* Parse and format PaddleOCR-VL response in one step
|
|
2239
|
+
*
|
|
2240
|
+
* Convenience function that parses the VLM output and formats it
|
|
2241
|
+
* according to the specified configuration.
|
|
2242
|
+
*
|
|
2243
|
+
* # Arguments
|
|
2244
|
+
* * `text` - Raw VLM output containing table tokens
|
|
2245
|
+
* * `config` - Optional parser configuration (format, trim_cells, etc.)
|
|
2246
|
+
*
|
|
2247
|
+
* # Returns
|
|
2248
|
+
* * Formatted string in the requested format (markdown, plain, html, raw)
|
|
2249
|
+
*
|
|
2250
|
+
* # Example
|
|
2251
|
+
* ```typescript
|
|
2252
|
+
* import { parsePaddleResponse } from '@mlx-node/core';
|
|
2253
|
+
*
|
|
2254
|
+
* // Parse and format as markdown (default)
|
|
2255
|
+
* const markdown = parsePaddleResponse(vlmResult.text);
|
|
2256
|
+
*
|
|
2257
|
+
* // Parse and format as HTML
|
|
2258
|
+
* const html = parsePaddleResponse(vlmResult.text, { format: 'html' });
|
|
2259
|
+
*
|
|
2260
|
+
* // Parse and format as plain text
|
|
2261
|
+
* const plain = parsePaddleResponse(vlmResult.text, { format: 'plain' });
|
|
2262
|
+
* ```
|
|
2263
|
+
*/
|
|
2264
|
+
export declare function parsePaddleResponse(text: string, config?: ParserConfig | undefined | null): string;
|
|
2265
|
+
|
|
2266
|
+
/** Parser configuration */
|
|
2267
|
+
export interface ParserConfig {
|
|
2268
|
+
/** Output format (default: 'markdown') */
|
|
2269
|
+
format?: OutputFormat;
|
|
2270
|
+
/** Whether to trim whitespace from cells (default: true) */
|
|
2271
|
+
trimCells?: boolean;
|
|
2272
|
+
/** Whether to collapse empty rows (default: true) */
|
|
2273
|
+
collapseEmptyRows?: boolean;
|
|
2274
|
+
}
|
|
2275
|
+
|
|
2276
|
+
/**
|
|
2277
|
+
* Parse tool calls from text (NAPI export)
|
|
2278
|
+
*
|
|
2279
|
+
* Extracts tool calls from model-generated text and returns both the cleaned text
|
|
2280
|
+
* and the parsed tool calls.
|
|
2281
|
+
*
|
|
2282
|
+
* # Example
|
|
2283
|
+
* ```typescript
|
|
2284
|
+
* import { parseToolCallsFromText } from '@mlx-node/core';
|
|
2285
|
+
*
|
|
2286
|
+
* const result = parseToolCallsFromText('<tool_call>{"name": "search", "arguments": {"q": "test"}}</tool_call>');
|
|
2287
|
+
* console.log(result.text); // ""
|
|
2288
|
+
* console.log(result.toolCalls[0].name); // "search"
|
|
2289
|
+
* console.log(result.toolCalls[0].arguments.q); // "test"
|
|
2290
|
+
* ```
|
|
2291
|
+
*/
|
|
2292
|
+
export declare function parseToolCallsFromText(text: string): ParseToolCallsResult;
|
|
2293
|
+
|
|
2294
|
+
/** Result of parsing tool calls from text */
|
|
2295
|
+
export interface ParseToolCallsResult {
|
|
2296
|
+
/** Cleaned text with tool_call tags removed */
|
|
2297
|
+
text: string;
|
|
2298
|
+
/** Parsed tool calls */
|
|
2299
|
+
toolCalls: Array<ToolCallResult>;
|
|
2300
|
+
}
|
|
2301
|
+
|
|
2302
|
+
/** Parse VLM output into structured document */
|
|
2303
|
+
export declare function parseVlmOutput(text: string): ParsedDocument;
|
|
2304
|
+
|
|
2305
|
+
/** Qwen3 model configuration */
|
|
2306
|
+
export interface Qwen3Config {
|
|
2307
|
+
vocabSize: number;
|
|
2308
|
+
hiddenSize: number;
|
|
2309
|
+
numLayers: number;
|
|
2310
|
+
numHeads: number;
|
|
2311
|
+
numKvHeads: number;
|
|
2312
|
+
intermediateSize: number;
|
|
2313
|
+
rmsNormEps: number;
|
|
2314
|
+
ropeTheta: number;
|
|
2315
|
+
maxPositionEmbeddings: number;
|
|
2316
|
+
headDim: number;
|
|
2317
|
+
useQkNorm: boolean;
|
|
2318
|
+
tieWordEmbeddings: boolean;
|
|
2319
|
+
padTokenId: number;
|
|
2320
|
+
eosTokenId: number;
|
|
2321
|
+
bosTokenId: number;
|
|
2322
|
+
/**
|
|
2323
|
+
* Enable paged attention for memory-efficient inference.
|
|
2324
|
+
* Default: false (use standard KVCache)
|
|
2325
|
+
*/
|
|
2326
|
+
usePagedAttention?: boolean | undefined;
|
|
2327
|
+
/**
|
|
2328
|
+
* GPU memory budget for paged KV cache in megabytes.
|
|
2329
|
+
* Only used when use_paged_attention is true.
|
|
2330
|
+
* Default: 2048 (2GB)
|
|
2331
|
+
*/
|
|
2332
|
+
pagedCacheMemoryMb?: number | undefined;
|
|
2333
|
+
/**
|
|
2334
|
+
* Block size for paged attention (tokens per block).
|
|
2335
|
+
* Only used when use_paged_attention is true.
|
|
2336
|
+
* Default: 16
|
|
2337
|
+
*/
|
|
2338
|
+
pagedBlockSize?: number | undefined;
|
|
2339
|
+
/**
|
|
2340
|
+
* Use FP8 cache for 2x memory reduction (experimental).
|
|
2341
|
+
* Only used when use_paged_attention is true.
|
|
2342
|
+
* Default: false
|
|
2343
|
+
*/
|
|
2344
|
+
useFp8Cache?: boolean | undefined;
|
|
2345
|
+
}
|
|
2346
|
+
|
|
2347
|
+
/** Result of resume position computation */
|
|
2348
|
+
export interface ResumePosition {
|
|
2349
|
+
/** Epoch to start from (0-indexed) */
|
|
2350
|
+
startEpoch: number;
|
|
2351
|
+
/** Batch index within epoch to start from */
|
|
2352
|
+
startBatchIdx: number;
|
|
2353
|
+
/** Whether we're at an epoch boundary */
|
|
2354
|
+
isEpochBoundary: boolean;
|
|
2355
|
+
}
|
|
2356
|
+
|
|
2357
|
+
/**
|
|
2358
|
+
* Reward function input for a single completion.
|
|
2359
|
+
* Provides all context needed to compute a reward score.
|
|
2360
|
+
*/
|
|
2361
|
+
export interface RewardOutput {
|
|
2362
|
+
/** The input prompt text */
|
|
2363
|
+
prompt: string;
|
|
2364
|
+
/** Structured completion data aligned with ChatResult */
|
|
2365
|
+
completion: CompletionInfo;
|
|
2366
|
+
}
|
|
2367
|
+
|
|
2368
|
+
/** Reward distribution statistics (NAPI wrapper) */
|
|
2369
|
+
export interface RewardStats {
|
|
2370
|
+
count: number;
|
|
2371
|
+
mean: number;
|
|
2372
|
+
std: number;
|
|
2373
|
+
min: number;
|
|
2374
|
+
max: number;
|
|
2375
|
+
median: number;
|
|
2376
|
+
p25: number;
|
|
2377
|
+
p75: number;
|
|
2378
|
+
}
|
|
2379
|
+
|
|
2380
|
+
/** Aggregate statistics for a training run for resume state (NAPI wrapper) */
|
|
2381
|
+
export interface RunAggregates {
|
|
2382
|
+
/** Best (highest) reward seen */
|
|
2383
|
+
bestReward: number;
|
|
2384
|
+
/** Average reward */
|
|
2385
|
+
avgReward: number;
|
|
2386
|
+
/** Total reward count */
|
|
2387
|
+
rewardCount: number;
|
|
2388
|
+
/** Best (lowest) loss seen */
|
|
2389
|
+
bestLoss: number;
|
|
2390
|
+
/** Average loss */
|
|
2391
|
+
avgLoss: number;
|
|
2392
|
+
/** Total loss count */
|
|
2393
|
+
lossCount: number;
|
|
2394
|
+
/** Total tokens generated */
|
|
2395
|
+
totalTokens: number;
|
|
2396
|
+
/** Current step number */
|
|
2397
|
+
currentStep: number;
|
|
2398
|
+
/** Average generation time (milliseconds) */
|
|
2399
|
+
avgGenerationTimeMs: number;
|
|
2400
|
+
/** Average training time (milliseconds) */
|
|
2401
|
+
avgTrainingTimeMs: number;
|
|
2402
|
+
}
|
|
2403
|
+
|
|
2404
|
+
/**
|
|
2405
|
+
* Configuration for sampling strategies
|
|
2406
|
+
* ⚡ PERFORMANCE: Made Copy to avoid cloning on every token
|
|
2407
|
+
*/
|
|
2408
|
+
export interface SamplingConfig {
|
|
2409
|
+
/** Temperature for softmax (default: 1.0). Lower = more deterministic */
|
|
2410
|
+
temperature?: number;
|
|
2411
|
+
/** Number of top tokens to keep (top-k sampling). 0 = disabled */
|
|
2412
|
+
topK?: number;
|
|
2413
|
+
/** Cumulative probability threshold (top-p/nucleus sampling). 1.0 = disabled */
|
|
2414
|
+
topP?: number;
|
|
2415
|
+
/** Minimum probability threshold relative to max (min-p sampling). 0 = disabled */
|
|
2416
|
+
minP?: number;
|
|
2417
|
+
}
|
|
2418
|
+
|
|
2419
|
+
/** Scheduler statistics (NAPI-compatible) */
|
|
2420
|
+
export interface SchedulerStatsNapi {
|
|
2421
|
+
/** Number of requests waiting to be scheduled */
|
|
2422
|
+
numWaiting: number;
|
|
2423
|
+
/** Number of sequences currently running */
|
|
2424
|
+
numRunning: number;
|
|
2425
|
+
/** Number of completed sequences */
|
|
2426
|
+
numCompleted: number;
|
|
2427
|
+
/** Number of sequences in prefill phase */
|
|
2428
|
+
numPrefill: number;
|
|
2429
|
+
/** Number of sequences in decode phase */
|
|
2430
|
+
numDecode: number;
|
|
2431
|
+
/** Total tokens across all running sequences */
|
|
2432
|
+
totalRunningTokens: number;
|
|
2433
|
+
}
|
|
2434
|
+
|
|
2435
|
+
/** Configuration for the SFT training engine */
|
|
2436
|
+
export interface SftEngineConfig {
|
|
2437
|
+
/** Learning rate (default: 2e-5) */
|
|
2438
|
+
learningRate?: number;
|
|
2439
|
+
/** Gradient accumulation steps (default: 1) */
|
|
2440
|
+
gradientAccumulationSteps?: number;
|
|
2441
|
+
/** Maximum gradient norm for clipping (default: 1.0) */
|
|
2442
|
+
gradientClipNorm?: number;
|
|
2443
|
+
/** Maximum gradient value for element-wise clipping (optional) */
|
|
2444
|
+
gradientClipValue?: number;
|
|
2445
|
+
/** Weight decay (L2 regularization) (default: 0.01) */
|
|
2446
|
+
weightDecay?: number;
|
|
2447
|
+
/** Label smoothing factor (default: 0.0) */
|
|
2448
|
+
labelSmoothing?: number;
|
|
2449
|
+
/** Steps between heavy cleanup (default: 25) */
|
|
2450
|
+
heavyCleanupInterval?: number;
|
|
2451
|
+
/** Maximum allowed NaN gradient occurrences (default: 100) */
|
|
2452
|
+
maxNanGradients?: number;
|
|
2453
|
+
/** Consecutive NaN gradients that trigger emergency checkpoint (default: 5) */
|
|
2454
|
+
emergencySaveThreshold?: number;
|
|
2455
|
+
/** Compute token accuracy (requires extra forward pass) (default: false) */
|
|
2456
|
+
computeAccuracy?: boolean;
|
|
2457
|
+
/**
|
|
2458
|
+
* Enable detailed NaN/Inf detection with per-element counts (default: false)
|
|
2459
|
+
* When false (default), uses GPU-native has_nan_or_inf() which only transfers a single
|
|
2460
|
+
* boolean to CPU. When true, transfers the entire gradient tensor to CPU for detailed
|
|
2461
|
+
* per-element analysis - useful for debugging but has significant performance overhead.
|
|
2462
|
+
*/
|
|
2463
|
+
verboseNanDetection?: boolean;
|
|
2464
|
+
}
|
|
2465
|
+
|
|
2466
|
+
/** Metrics from a training epoch */
|
|
2467
|
+
export interface SftEpochMetrics {
|
|
2468
|
+
/** Epoch number */
|
|
2469
|
+
epoch: number;
|
|
2470
|
+
/** Average loss for the epoch */
|
|
2471
|
+
avgLoss: number;
|
|
2472
|
+
/** Total steps in the epoch */
|
|
2473
|
+
totalSteps: number;
|
|
2474
|
+
/** Total tokens processed */
|
|
2475
|
+
totalTokens: number;
|
|
2476
|
+
/** Time for the epoch (seconds) */
|
|
2477
|
+
epochTimeSecs: number;
|
|
2478
|
+
}
|
|
2479
|
+
|
|
2480
|
+
/** Metrics from a single training step */
|
|
2481
|
+
export interface SftStepMetrics {
|
|
2482
|
+
/** Current step number */
|
|
2483
|
+
step: number;
|
|
2484
|
+
/** Cross-entropy loss value */
|
|
2485
|
+
loss: number;
|
|
2486
|
+
/** Total tokens processed this step (non-ignored) */
|
|
2487
|
+
totalTokens: number;
|
|
2488
|
+
/** Token-level accuracy (if compute_accuracy enabled) */
|
|
2489
|
+
tokenAccuracy?: number;
|
|
2490
|
+
/** Whether gradients were applied (vs accumulated) */
|
|
2491
|
+
gradientsApplied: boolean;
|
|
2492
|
+
/** Time for training step (ms) */
|
|
2493
|
+
trainingTimeMs: number;
|
|
2494
|
+
}
|
|
2495
|
+
|
|
2496
|
+
/** Metrics from a single training step for sparkline restoration (NAPI wrapper) */
|
|
2497
|
+
export interface StepMetricSummary {
|
|
2498
|
+
/** Step number */
|
|
2499
|
+
step: number;
|
|
2500
|
+
/** Loss value */
|
|
2501
|
+
loss: number;
|
|
2502
|
+
/** Mean reward (GRPO) */
|
|
2503
|
+
meanReward: number;
|
|
2504
|
+
/** Mean advantage (GRPO) */
|
|
2505
|
+
meanAdvantage: number;
|
|
2506
|
+
/** Std advantage (GRPO) - indicates reward variance within groups */
|
|
2507
|
+
stdAdvantage: number;
|
|
2508
|
+
/** Perplexity (SFT, optional) */
|
|
2509
|
+
perplexity?: number;
|
|
2510
|
+
/** Token accuracy (SFT, optional) */
|
|
2511
|
+
tokenAccuracy?: number;
|
|
2512
|
+
/** Total tokens this step */
|
|
2513
|
+
totalTokens: number;
|
|
2514
|
+
/** Time for generation phase (milliseconds) */
|
|
2515
|
+
generationTimeMs?: number;
|
|
2516
|
+
/** Time for training phase (milliseconds) */
|
|
2517
|
+
trainingTimeMs?: number;
|
|
2518
|
+
}
|
|
2519
|
+
|
|
2520
|
+
/** A training step record (NAPI wrapper) */
|
|
2521
|
+
export interface StepRecord {
|
|
2522
|
+
runId: string;
|
|
2523
|
+
step: number;
|
|
2524
|
+
epoch?: number;
|
|
2525
|
+
loss: number;
|
|
2526
|
+
meanReward: number;
|
|
2527
|
+
stdReward: number;
|
|
2528
|
+
meanAdvantage?: number;
|
|
2529
|
+
stdAdvantage: number;
|
|
2530
|
+
totalTokens?: number;
|
|
2531
|
+
generationTimeMs?: number;
|
|
2532
|
+
trainingTimeMs?: number;
|
|
2533
|
+
gradientsApplied: boolean;
|
|
2534
|
+
}
|
|
2535
|
+
|
|
2536
|
+
/** Summary of a training step (NAPI wrapper) */
|
|
2537
|
+
export interface StepSummary {
|
|
2538
|
+
step: number;
|
|
2539
|
+
loss: number;
|
|
2540
|
+
meanReward: number;
|
|
2541
|
+
numGenerations: number;
|
|
2542
|
+
numToolCalls: number;
|
|
2543
|
+
eosCount: number;
|
|
2544
|
+
lengthCount: number;
|
|
2545
|
+
}
|
|
2546
|
+
|
|
2547
|
+
/** A table structure */
|
|
2548
|
+
export interface Table {
|
|
2549
|
+
rows: Array<TableRow>;
|
|
2550
|
+
}
|
|
2551
|
+
|
|
2552
|
+
/** A single cell in a table */
|
|
2553
|
+
export interface TableCell {
|
|
2554
|
+
content: string;
|
|
2555
|
+
isEmpty: boolean;
|
|
2556
|
+
}
|
|
2557
|
+
|
|
2558
|
+
/** A row in a table */
|
|
2559
|
+
export interface TableRow {
|
|
2560
|
+
cells: Array<TableCell>;
|
|
2561
|
+
}
|
|
2562
|
+
|
|
2563
|
+
/** Language model (text decoder) configuration */
|
|
2564
|
+
export interface TextConfig {
|
|
2565
|
+
modelType: string;
|
|
2566
|
+
hiddenSize: number;
|
|
2567
|
+
numHiddenLayers: number;
|
|
2568
|
+
intermediateSize: number;
|
|
2569
|
+
numAttentionHeads: number;
|
|
2570
|
+
rmsNormEps: number;
|
|
2571
|
+
vocabSize: number;
|
|
2572
|
+
numKeyValueHeads: number;
|
|
2573
|
+
maxPositionEmbeddings: number;
|
|
2574
|
+
ropeTheta: number;
|
|
2575
|
+
ropeTraditional: boolean;
|
|
2576
|
+
useBias: boolean;
|
|
2577
|
+
headDim: number;
|
|
2578
|
+
/**
|
|
2579
|
+
* Multimodal RoPE sections: [temporal, height, width]
|
|
2580
|
+
* These define how the head_dim is split for 3D position encoding
|
|
2581
|
+
*/
|
|
2582
|
+
mropeSection: Array<number>;
|
|
2583
|
+
}
|
|
2584
|
+
|
|
2585
|
+
/** Tool call made by an assistant */
|
|
2586
|
+
export interface ToolCall {
|
|
2587
|
+
/** Optional unique identifier for the tool call */
|
|
2588
|
+
id?: string;
|
|
2589
|
+
/** Name of the tool/function to call */
|
|
2590
|
+
name: string;
|
|
2591
|
+
/** JSON string of arguments to pass to the tool */
|
|
2592
|
+
arguments: string;
|
|
2593
|
+
}
|
|
2594
|
+
|
|
2595
|
+
/** A tool call record (NAPI wrapper) */
|
|
2596
|
+
export interface ToolCallRecord {
|
|
2597
|
+
callIndex: number;
|
|
2598
|
+
status: string;
|
|
2599
|
+
toolName?: string;
|
|
2600
|
+
arguments?: string;
|
|
2601
|
+
rawContent: string;
|
|
2602
|
+
errorMessage?: string;
|
|
2603
|
+
}
|
|
2604
|
+
|
|
2605
|
+
/** Structured tool call with parsed arguments */
|
|
2606
|
+
export interface ToolCallResult {
|
|
2607
|
+
/** Unique identifier for this tool call (format: call_<uuid>) */
|
|
2608
|
+
id: string;
|
|
2609
|
+
/** Name of the tool/function to call */
|
|
2610
|
+
name: string;
|
|
2611
|
+
/**
|
|
2612
|
+
* Parsed arguments as native object (serde_json::Value -> JS object)
|
|
2613
|
+
*
|
|
2614
|
+
* When status is "ok", this contains the parsed arguments object.
|
|
2615
|
+
* When status is "parse_error", this contains the original unparsed string.
|
|
2616
|
+
* Otherwise, this is an empty object {}.
|
|
2617
|
+
*/
|
|
2618
|
+
arguments: Record<string, unknown> | string;
|
|
2619
|
+
/**
|
|
2620
|
+
* Parsing status: "ok" | "invalid_json" | "missing_name" | "parse_error"
|
|
2621
|
+
*
|
|
2622
|
+
* - "ok": Successfully parsed tool call
|
|
2623
|
+
* - "invalid_json": The tool_call tag content was not valid JSON
|
|
2624
|
+
* - "missing_name": Valid JSON but no "name" field
|
|
2625
|
+
* - "parse_error": Valid JSON but the "arguments" string field couldn't be parsed as JSON
|
|
2626
|
+
*/
|
|
2627
|
+
status: string;
|
|
2628
|
+
/** Error message if status != "ok" */
|
|
2629
|
+
error?: string;
|
|
2630
|
+
/**
|
|
2631
|
+
* Raw content from <tool_call> tag (preserved for debugging/persistence)
|
|
2632
|
+
* Defaults to empty string for backward compatibility with older JSON
|
|
2633
|
+
*/
|
|
2634
|
+
rawContent: string;
|
|
2635
|
+
}
|
|
2636
|
+
|
|
2637
|
+
/** OpenAI-compatible tool definition */
|
|
2638
|
+
export interface ToolDefinition {
|
|
2639
|
+
/** Tool type (currently only "function" is supported) */
|
|
2640
|
+
type: string;
|
|
2641
|
+
/** Function definition */
|
|
2642
|
+
function: FunctionDefinition;
|
|
2643
|
+
}
|
|
2644
|
+
|
|
2645
|
+
/** A training run record (NAPI wrapper) */
|
|
2646
|
+
export interface TrainingRunRecord {
|
|
2647
|
+
id: string;
|
|
2648
|
+
name?: string;
|
|
2649
|
+
modelName: string;
|
|
2650
|
+
modelPath?: string;
|
|
2651
|
+
config: string;
|
|
2652
|
+
startedAt: number;
|
|
2653
|
+
endedAt?: number;
|
|
2654
|
+
totalSteps: number;
|
|
2655
|
+
status: string;
|
|
2656
|
+
}
|
|
2657
|
+
|
|
2658
|
+
/** Result from train_step_auto including metrics, completions, and rewards */
|
|
2659
|
+
export interface TrainStepResult {
|
|
2660
|
+
/** Training metrics */
|
|
2661
|
+
metrics: EngineStepMetrics;
|
|
2662
|
+
/** Generated completion texts (for TUI logging) */
|
|
2663
|
+
completions: Array<string>;
|
|
2664
|
+
/** Computed reward values (for TUI logging) */
|
|
2665
|
+
rewards: Array<number>;
|
|
2666
|
+
}
|
|
2667
|
+
|
|
2668
|
+
/** Result from train_step_auto_with_recording including optional full RewardOutput data */
|
|
2669
|
+
export interface TrainStepResultWithOutputs {
|
|
2670
|
+
/** Training metrics */
|
|
2671
|
+
metrics: EngineStepMetrics;
|
|
2672
|
+
/** Generated completion texts (for TUI logging) */
|
|
2673
|
+
completions: Array<string>;
|
|
2674
|
+
/** Computed reward values (for TUI logging) */
|
|
2675
|
+
rewards: Array<number>;
|
|
2676
|
+
/**
|
|
2677
|
+
* Full RewardOutput data as JSON (only populated when record_outputs is true)
|
|
2678
|
+
* This enables zero-copy persistence of training outputs
|
|
2679
|
+
*/
|
|
2680
|
+
outputsJson?: string;
|
|
2681
|
+
/** Actual token counts for each completion (for accurate TUI display) */
|
|
2682
|
+
completionLengths: Array<number>;
|
|
2683
|
+
}
|
|
2684
|
+
|
|
2685
|
+
/** Vision encoder configuration */
|
|
2686
|
+
export interface VisionConfig {
|
|
2687
|
+
modelType: string;
|
|
2688
|
+
hiddenSize: number;
|
|
2689
|
+
intermediateSize: number;
|
|
2690
|
+
numHiddenLayers: number;
|
|
2691
|
+
numAttentionHeads: number;
|
|
2692
|
+
numChannels: number;
|
|
2693
|
+
imageSize: number;
|
|
2694
|
+
patchSize: number;
|
|
2695
|
+
hiddenAct: string;
|
|
2696
|
+
layerNormEps: number;
|
|
2697
|
+
attentionDropout: number;
|
|
2698
|
+
spatialMergeSize: number;
|
|
2699
|
+
}
|
|
2700
|
+
|
|
2701
|
+
/** Configuration for VLM chat */
|
|
2702
|
+
export interface VlmChatConfig {
|
|
2703
|
+
/**
|
|
2704
|
+
* Image paths to process (alternative to passing pre-processed images)
|
|
2705
|
+
* These will be automatically processed using the ImageProcessor
|
|
2706
|
+
*/
|
|
2707
|
+
imagePaths?: Array<string>;
|
|
2708
|
+
/** Maximum number of new tokens to generate (default: 512) */
|
|
2709
|
+
maxNewTokens?: number;
|
|
2710
|
+
/** Sampling temperature (0 = greedy, higher = more random) (default: 0.0 for OCR) */
|
|
2711
|
+
temperature?: number;
|
|
2712
|
+
/** Top-k sampling (default: 0) */
|
|
2713
|
+
topK?: number;
|
|
2714
|
+
/** Top-p (nucleus) sampling (default: 1.0) */
|
|
2715
|
+
topP?: number;
|
|
2716
|
+
/** Repetition penalty (default: 1.5) */
|
|
2717
|
+
repetitionPenalty?: number;
|
|
2718
|
+
/** Whether to return log probabilities (default: false) */
|
|
2719
|
+
returnLogprobs?: boolean;
|
|
2720
|
+
}
|
|
2721
|
+
|
|
2722
|
+
/** A chat message with optional image */
|
|
2723
|
+
export interface VlmChatMessage {
|
|
2724
|
+
/** Role of the message sender */
|
|
2725
|
+
role: ChatRole;
|
|
2726
|
+
/** Text content of the message */
|
|
2727
|
+
content: string;
|
|
2728
|
+
}
|