@kernel.chat/kbot 2.24.0 → 2.25.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2313 @@
1
+ // K:BOT Training & Fine-Tuning Tools
2
+ //
3
+ // Prepare datasets, launch fine-tuning jobs (cloud + local), monitor training,
4
+ // evaluate results, export/convert models, deploy, and estimate costs.
5
+ //
6
+ // Tools:
7
+ // train_prepare — Convert data into training formats (JSONL/Alpaca/ShareGPT)
8
+ // train_validate — Validate a training dataset before launching
9
+ // train_start — Launch fine-tuning (OpenAI, Together, Mistral, MLX, Unsloth, llama.cpp)
10
+ // train_status — Check training job status (cloud or local)
11
+ // train_evaluate — Evaluate a fine-tuned model against test data
12
+ // train_export — Merge LoRA, convert to GGUF, quantize
13
+ // train_deploy — Deploy to Ollama, HuggingFace, or K:BOT local
14
+ // train_cost — Estimate training cost, time, and VRAM
15
+ import { execSync, spawn } from 'node:child_process';
16
+ import { existsSync, readFileSync, writeFileSync, mkdirSync, readdirSync, statSync } from 'node:fs';
17
+ import { resolve, join, basename, extname, dirname } from 'node:path';
18
+ import { homedir, cpus } from 'node:os';
19
+ import { registerTool } from './index.js';
20
+ // ── Helpers ──────────────────────────────────────────────────────────
21
// Run `cmd` synchronously through the shell and return its trimmed stdout.
// Options: timeout (ms, default 60s), cwd (default process.cwd()), env.
// Throws (propagates execSync's error) on non-zero exit or timeout.
function shell(cmd, opts) {
    const options = {
        encoding: 'utf-8',
        timeout: opts?.timeout ?? 60_000,
        maxBuffer: 50 * 1024 * 1024, // 50 MB — training logs can be large
        cwd: opts?.cwd ?? process.cwd(),
        env: opts?.env ?? process.env,
        stdio: ['pipe', 'pipe', 'pipe'],
    };
    const stdout = execSync(cmd, options);
    return stdout.trim();
}
31
// Like shell(), but never throws: returns { ok, output } where `output` is
// stdout on success, or combined stdout+stderr (falling back to the error
// message) on failure.
function shellSafe(cmd, opts) {
    try {
        return { ok: true, output: shell(cmd, opts) };
    }
    catch (err) {
        const combined = [err.stdout, err.stderr]
            .filter(Boolean)
            .join('\n')
            .trim();
        return { ok: false, output: combined || err.message || 'Command failed' };
    }
}
42
/**
 * Check whether an executable named `cmd` is on PATH.
 *
 * `cmd` is interpolated directly into a shell command line, so names
 * containing shell metacharacters are rejected outright — this closes a
 * command-injection vector if a caller ever forwards untrusted input.
 * Returns false for anything that is not a plain command name.
 */
function isCommandAvailable(cmd) {
    if (typeof cmd !== 'string' || !/^[A-Za-z0-9._+-]+$/.test(cmd)) {
        return false;
    }
    try {
        execSync(`which ${cmd}`, { encoding: 'utf-8', timeout: 5_000, stdio: ['pipe', 'pipe', 'pipe'] });
        return true;
    }
    catch {
        // `which` exits non-zero when the command is not found.
        return false;
    }
}
51
// Rough token estimate for English text: roughly four characters per token.
function estimateTokens(text) {
    const CHARS_PER_TOKEN = 4;
    return Math.ceil(text.length / CHARS_PER_TOKEN);
}
55
// Load ~/.kbot/config.json, returning {} when the file is missing or
// malformed — callers always get a usable (possibly empty) object.
function readKbotConfig() {
    const configFile = join(homedir(), '.kbot', 'config.json');
    try {
        if (existsSync(configFile)) {
            return JSON.parse(readFileSync(configFile, 'utf-8'));
        }
    }
    catch {
        // Unreadable or invalid JSON — fall through to the empty config.
    }
    return {};
}
66
/**
 * Resolve an API key for a cloud training backend ('openai' | 'together' | 'mistral').
 *
 * Lookup order: explicit argument → environment variables → ~/.kbot/config.json.
 * Returns null when nothing is configured.
 *
 * The lookup tables are guarded with `?? []` so an unrecognized backend
 * returns null instead of throwing (the original iterated `undefined`,
 * raising a TypeError).
 */
function getApiKey(backend, explicit) {
    if (explicit)
        return explicit;
    // Environment variables, checked in order.
    const envVarMap = {
        openai: ['OPENAI_API_KEY'],
        together: ['TOGETHER_API_KEY', 'TOGETHER_AI_KEY'],
        mistral: ['MISTRAL_API_KEY'],
    };
    for (const envVar of envVarMap[backend] ?? []) {
        if (process.env[envVar])
            return process.env[envVar];
    }
    // Fall back to ~/.kbot/config.json (snake_case and camelCase keys).
    const config = readKbotConfig();
    const configKeyMap = {
        openai: ['openai_api_key', 'openaiApiKey'],
        together: ['together_api_key', 'togetherApiKey'],
        mistral: ['mistral_api_key', 'mistralApiKey'],
    };
    for (const key of configKeyMap[backend] ?? []) {
        if (config[key] && typeof config[key] === 'string')
            return config[key];
    }
    return null;
}
92
/** Lowercase text and collapse every whitespace run to a single space. */
function normalizeText(text) {
    const collapsed = text.replace(/\s+/g, ' ');
    return collapsed.toLowerCase().trim();
}
95
/**
 * Recursively collect files under `dirPath` whose extension is in
 * `extensions` (lowercase, with dot; an empty list matches everything).
 * Dotfiles and node_modules are skipped. If `dirPath` is itself a file,
 * it is returned (subject to the extension filter).
 *
 * Entries whose stat fails (broken symlinks, permission errors) are
 * skipped instead of aborting the whole walk — the original let
 * statSync throw, which discarded every file found so far.
 */
function collectFiles(dirPath, extensions) {
    const results = [];
    if (!existsSync(dirPath))
        return results;
    let stat;
    try {
        stat = statSync(dirPath);
    }
    catch {
        return results; // root itself unreadable
    }
    if (stat.isFile()) {
        const ext = extname(dirPath).toLowerCase();
        if (extensions.length === 0 || extensions.includes(ext)) {
            results.push(dirPath);
        }
        return results;
    }
    if (!stat.isDirectory())
        return results;
    for (const entry of readdirSync(dirPath)) {
        if (entry.startsWith('.') || entry === 'node_modules')
            continue;
        const fullPath = join(dirPath, entry);
        let entryStat;
        try {
            entryStat = statSync(fullPath);
        }
        catch {
            continue; // broken symlink or unreadable entry — skip
        }
        if (entryStat.isDirectory()) {
            results.push(...collectFiles(fullPath, extensions));
        }
        else {
            const ext = extname(entry).toLowerCase();
            if (extensions.length === 0 || extensions.includes(ext)) {
                results.push(fullPath);
            }
        }
    }
    return results;
}
128
+ // ── Source parsing helpers ────────────────────────────────────────────
129
/**
 * Extract instruction/response pairs from a JSON conversation payload.
 *
 * Accepted shapes:
 *  - array of messages with role/content (OpenAI) or from/value (ShareGPT);
 *    consecutive user→assistant turns are paired
 *  - array of conversation objects, each carrying messages/conversations
 *  - object with a top-level `messages` or `conversations` array
 *
 * Invalid JSON, JSON null, and primitive payloads yield [] — the original
 * relied on a TypeError being swallowed for those cases; this guards
 * explicitly (and also skips null entries inside arrays).
 */
function parseConversationJson(content) {
    const pairs = [];
    let data;
    try {
        data = JSON.parse(content);
    }
    catch {
        return pairs; // not valid JSON — nothing to extract
    }
    if (data === null || typeof data !== 'object')
        return pairs;
    if (Array.isArray(data)) {
        // Flat message list: pair each user turn with the following assistant turn.
        for (let i = 0; i < data.length - 1; i++) {
            const curr = data[i];
            const next = data[i + 1];
            if (!curr || !next)
                continue;
            const isUserTurn = curr.role === 'user' || curr.from === 'human';
            const isReplyTurn = next.role === 'assistant' || next.from === 'gpt' || next.from === 'assistant';
            if (isUserTurn && isReplyTurn) {
                pairs.push({
                    instruction: curr.content || curr.value || curr.text || '',
                    response: next.content || next.value || next.text || '',
                });
                i++; // skip the response message
            }
        }
        // Array of conversation objects, each with its own message list.
        if (data.length > 0 && data[0] && (data[0].messages || data[0].conversations)) {
            for (const item of data) {
                const msgs = item && (item.messages || item.conversations);
                if (Array.isArray(msgs)) {
                    pairs.push(...parseConversationJson(JSON.stringify(msgs)));
                }
            }
        }
        return pairs;
    }
    // Object wrappers around a message list.
    if (Array.isArray(data.messages)) {
        return parseConversationJson(JSON.stringify(data.messages));
    }
    if (Array.isArray(data.conversations)) {
        return parseConversationJson(JSON.stringify(data.conversations));
    }
    return pairs;
}
172
/**
 * Turn markdown into instruction/response pairs: each heading becomes the
 * instruction and the text below it (until the next heading) the response.
 * Sections with bodies of 20 characters or fewer are dropped.
 */
function parseMarkdown(content) {
    const pairs = [];
    // split() with a capturing group yields [preamble, heading, body, heading, body, ...]
    const sections = content.split(/^(#{1,6}\s+.+)$/m);
    for (let i = 1; i + 1 < sections.length; i += 2) {
        const instruction = sections[i].replace(/^#+\s+/, '').trim();
        const response = sections[i + 1].trim();
        if (instruction && response.length > 20) {
            pairs.push({ instruction, response });
        }
    }
    return pairs;
}
185
/**
 * Extract instruction/response pairs from source-code files by pairing each
 * documented function with its implementation.
 *
 * Supported extensions: .py (docstrings), .ts/.tsx/.js/.jsx (JSDoc),
 * .rs (/// doc comments), .go (// doc comments). Other extensions yield [].
 * The instruction is a natural-language "Write a <lang> function ..." prompt
 * built from the doc text; the response is the matched source text.
 *
 * NOTE(review): the lookahead patterns assume functions are separated at
 * column 0 by the next `def`/`class`/`export`/`fn`/`func` keyword — nested or
 * oddly formatted definitions may be merged or missed; extraction is
 * best-effort by design.
 */
function parseCodeFile(content, ext) {
    const pairs = [];
    if (['.py'].includes(ext)) {
        // Python: signature, then a triple-quoted docstring, then the body up
        // to the next top-level def/class (or end of file).
        const funcPattern = /^(def\s+\w+\s*\([^)]*\)(?:\s*->\s*[^:]+)?)\s*:\s*\n\s*("""[\s\S]*?"""|'''[\s\S]*?''')\s*\n([\s\S]*?)(?=\ndef\s|\nclass\s|$)/gm;
        let match;
        while ((match = funcPattern.exec(content)) !== null) {
            const signature = match[1].trim();
            // Strip the surrounding triple quotes from the docstring.
            const docstring = match[2].replace(/^("""|''')|("""|''')$/g, '').trim();
            // match[0] is the whole definition: signature + docstring + body.
            const body = match[0].trim();
            if (docstring && body) {
                pairs.push({
                    instruction: `Write a Python function: ${signature}\n\nDescription: ${docstring}`,
                    response: body,
                });
            }
        }
    }
    if (['.ts', '.tsx', '.js', '.jsx'].includes(ext)) {
        // TypeScript/JavaScript: a /** JSDoc */ block followed by a function
        // declaration or an arrow-function assignment, up to the next JSDoc,
        // export, or class.
        const funcPattern = /(\/\*\*[\s\S]*?\*\/)\s*\n\s*((?:export\s+)?(?:async\s+)?(?:function\s+\w+|(?:const|let)\s+\w+\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=]+)=>)[\s\S]*?)(?=\n\/\*\*|\nexport\s+(?:default\s+)?(?:function|const|class|interface|type)|\nclass\s|$)/gm;
        let match;
        while ((match = funcPattern.exec(content)) !== null) {
            // Strip the comment delimiters and leading `* ` from each JSDoc line.
            const jsdoc = match[1].replace(/\/\*\*|\*\/|\*\s?/g, '').trim();
            const impl = match[2].trim();
            // Length floor filters out trivial one-liner matches.
            if (jsdoc && impl && impl.length > 30) {
                // Extract the function name from the implementation text.
                const nameMatch = impl.match(/(?:function\s+(\w+)|(?:const|let)\s+(\w+))/);
                const funcName = nameMatch ? (nameMatch[1] || nameMatch[2]) : 'function';
                pairs.push({
                    instruction: `Write a TypeScript function \`${funcName}\`: ${jsdoc}`,
                    response: impl,
                });
            }
        }
    }
    if (['.rs'].includes(ext)) {
        // Rust: one or more /// doc-comment lines, then a pub fn, up to the
        // next doc comment, fn, or impl block.
        const funcPattern = /((?:\/\/\/.*\n)+)\s*(pub\s+(?:async\s+)?fn\s+\w+[\s\S]*?)(?=\n\/\/\/|\npub\s+(?:async\s+)?fn|\nfn\s|\nimpl\s|$)/gm;
        let match;
        while ((match = funcPattern.exec(content)) !== null) {
            const docComment = match[1].replace(/\/\/\/\s?/g, '').trim();
            const impl = match[2].trim();
            if (docComment && impl) {
                const nameMatch = impl.match(/fn\s+(\w+)/);
                const funcName = nameMatch ? nameMatch[1] : 'function';
                pairs.push({
                    instruction: `Write a Rust function \`${funcName}\`: ${docComment}`,
                    response: impl,
                });
            }
        }
    }
    if (['.go'].includes(ext)) {
        // Go: a run of // doc-comment lines, then a func (optionally with a
        // receiver), up to the next comment or func.
        const funcPattern = /((?:\/\/.*\n)+)\s*(func\s+(?:\([^)]*\)\s+)?\w+[\s\S]*?)(?=\n\/\/|\nfunc\s|$)/gm;
        let match;
        while ((match = funcPattern.exec(content)) !== null) {
            const docComment = match[1].replace(/\/\/\s?/g, '').trim();
            const impl = match[2].trim();
            if (docComment && impl) {
                // Name comes after the optional `(receiver)` group.
                const nameMatch = impl.match(/func\s+(?:\([^)]*\)\s+)?(\w+)/);
                const funcName = nameMatch ? nameMatch[1] : 'function';
                pairs.push({
                    instruction: `Write a Go function \`${funcName}\`: ${docComment}`,
                    response: impl,
                });
            }
        }
    }
    return pairs;
}
258
/**
 * Serialize one instruction/response pair as a single training example.
 *
 * Formats:
 *  - 'jsonl'    — OpenAI chat: {"messages":[{role,content},...]}
 *  - 'alpaca'   — {"instruction","input","output"}
 *  - 'sharegpt' — {"conversations":[{from,value},...]}
 *
 * The optional systemPrompt is prepended as a system turn (jsonl/sharegpt).
 * Throws on an unknown format — the original fell through the switch and
 * returned undefined, which callers would have written into the dataset
 * as the literal text "undefined".
 */
function formatExample(pair, format, systemPrompt) {
    switch (format) {
        case 'jsonl': {
            const messages = [];
            if (systemPrompt)
                messages.push({ role: 'system', content: systemPrompt });
            messages.push({ role: 'user', content: pair.instruction });
            messages.push({ role: 'assistant', content: pair.response });
            return JSON.stringify({ messages });
        }
        case 'alpaca': {
            const example = {
                instruction: pair.instruction,
                input: '',
                output: pair.response,
            };
            return JSON.stringify(example);
        }
        case 'sharegpt': {
            const conversations = [];
            if (systemPrompt)
                conversations.push({ from: 'system', value: systemPrompt });
            conversations.push({ from: 'human', value: pair.instruction });
            conversations.push({ from: 'gpt', value: pair.response });
            return JSON.stringify({ conversations });
        }
        default:
            throw new Error(`Unknown training format: ${format}`);
    }
}
287
+ // ── BLEU score helper ────────────────────────────────────────────────
288
/** Count n-gram occurrences in a token list, keyed by the space-joined n-gram. */
function computeNgrams(tokens, n) {
    const counts = new Map();
    const lastStart = tokens.length - n;
    for (let start = 0; start <= lastStart; start++) {
        const key = tokens.slice(start, start + n).join(' ');
        counts.set(key, (counts.get(key) ?? 0) + 1);
    }
    return counts;
}
296
/**
 * Smoothed sentence-level BLEU in [0, 1]: geometric mean of clipped n-gram
 * precisions (orders 1..maxN, capped at candidate length) multiplied by the
 * brevity penalty. Case-insensitive, whitespace-tokenized.
 */
function bleuScore(reference, candidate, maxN = 4) {
    const refTokens = reference.toLowerCase().split(/\s+/).filter(Boolean);
    const candTokens = candidate.toLowerCase().split(/\s+/).filter(Boolean);
    if (refTokens.length === 0 || candTokens.length === 0)
        return 0;
    let sumLogPrecision = 0;
    let orders = 0;
    const highestOrder = Math.min(maxN, candTokens.length);
    for (let n = 1; n <= highestOrder; n++) {
        const refCounts = computeNgrams(refTokens, n);
        const candCounts = computeNgrams(candTokens, n);
        let matched = 0;
        let total = 0;
        for (const [gram, cnt] of candCounts) {
            matched += Math.min(cnt, refCounts.get(gram) || 0);
            total += cnt;
        }
        if (total === 0)
            continue;
        if (matched === 0)
            return 0; // zero precision at any order zeroes the whole score
        sumLogPrecision += Math.log(matched / total);
        orders++;
    }
    if (orders === 0)
        return 0;
    // Brevity penalty: punish candidates shorter than the reference.
    const brevityPenalty = candTokens.length >= refTokens.length
        ? 1
        : Math.exp(1 - refTokens.length / candTokens.length);
    return brevityPenalty * Math.exp(sumLogPrecision / orders);
}
329
+ // ── Model size estimation helpers ────────────────────────────────────
330
/**
 * Estimate a model's parameter count from its name.
 *
 * An explicit size marker ("7b", "13B", "1.8b", "70 billion") wins outright;
 * otherwise known model families are matched, most specific pattern first
 * — the original listed /gpt-4/ before /gpt-4.1-mini/, so "gpt-4.1-mini"
 * wrongly resolved to 200e9. Dead `[regex, 0]` entries (filtered out by a
 * size > 0 check) were removed. Unknown models default to 7B.
 */
function estimateModelParams(modelName) {
    const name = modelName.toLowerCase();
    // Explicit size marker, e.g. "7b", "13B", "70 billion".
    const sizeMatch = name.match(/(\d+\.?\d*)\s*b(?:illion)?/i);
    if (sizeMatch) {
        return parseFloat(sizeMatch[1]) * 1e9;
    }
    // Known families, most specific first so narrower names are not
    // swallowed by broader prefixes (gpt-4.1-mini vs gpt-4.1 vs gpt-4).
    const sizePatterns = [
        [/gpt-4\.1-mini/, 8e9],
        [/gpt-4\.1/, 200e9],
        [/gpt-3\.5/, 7e9],
        [/gpt-4/, 200e9],
        [/llama.*70b/, 70e9],
        [/llama.*13b/, 13e9],
        [/llama.*8b/, 8e9],
        [/llama.*7b/, 7e9],
        [/llama.*3b/, 3e9],
        [/llama.*1b/, 1e9],
        [/mistral.*7b/, 7e9],
        [/mixtral/, 47e9],
        [/phi.*3/, 3.8e9],
        [/gemma.*7b/, 7e9],
        [/gemma.*2b/, 2e9],
        [/qwen.*72b/, 72e9],
        [/qwen.*14b/, 14e9],
        [/qwen.*7b/, 7e9],
        [/qwen.*1\.8b/, 1.8e9],
    ];
    for (const [pattern, size] of sizePatterns) {
        if (pattern.test(name))
            return size;
    }
    // Default estimate for unknown models.
    return 7e9;
}
369
+ // ── Tool Registration ────────────────────────────────────────────────
370
+ export function registerTrainingTools() {
371
+ // ── train_prepare ──────────────────────────────────────────────────
372
+ registerTool({
373
+ name: 'train_prepare',
374
+ description: 'Convert data into training formats for fine-tuning. Supports JSONL chat (OpenAI SFT), Alpaca, and ShareGPT formats. ' +
375
+ 'Auto-detects source type: conversation JSON, markdown, or code files with docstrings. ' +
376
+ 'Extracts instruction/response pairs and writes the output dataset.',
377
+ parameters: {
378
+ source: {
379
+ type: 'string',
380
+ description: 'Path to source file or directory containing training data',
381
+ required: true,
382
+ },
383
+ format: {
384
+ type: 'string',
385
+ description: 'Output format: jsonl (OpenAI chat), alpaca, sharegpt (default: jsonl)',
386
+ },
387
+ output: {
388
+ type: 'string',
389
+ description: 'Output file path (default: <source>_prepared.<format>)',
390
+ },
391
+ system_prompt: {
392
+ type: 'string',
393
+ description: 'System prompt to prepend to each example (optional)',
394
+ },
395
+ },
396
+ tier: 'pro',
397
+ timeout: 300_000,
398
+ async execute(args) {
399
+ try {
400
+ const sourcePath = resolve(String(args.source));
401
+ const format = (String(args.format || 'jsonl').toLowerCase());
402
+ const systemPrompt = args.system_prompt ? String(args.system_prompt) : undefined;
403
+ if (!['jsonl', 'alpaca', 'sharegpt'].includes(format)) {
404
+ return `Error: Invalid format "${format}". Use: jsonl, alpaca, sharegpt`;
405
+ }
406
+ if (!existsSync(sourcePath)) {
407
+ return `Error: Source path not found: ${sourcePath}`;
408
+ }
409
+ // Collect all source files
410
+ const stat = statSync(sourcePath);
411
+ let sourceFiles;
412
+ if (stat.isFile()) {
413
+ sourceFiles = [sourcePath];
414
+ }
415
+ else {
416
+ sourceFiles = collectFiles(sourcePath, [
417
+ '.json', '.jsonl', '.md', '.markdown',
418
+ '.py', '.ts', '.tsx', '.js', '.jsx', '.rs', '.go',
419
+ ]);
420
+ }
421
+ if (sourceFiles.length === 0) {
422
+ return `Error: No supported source files found in ${sourcePath}\n\nSupported: .json, .jsonl, .md, .py, .ts, .js, .rs, .go`;
423
+ }
424
+ // Extract pairs from all sources
425
+ const allPairs = [];
426
+ for (const file of sourceFiles) {
427
+ const content = readFileSync(file, 'utf-8');
428
+ const ext = extname(file).toLowerCase();
429
+ if (ext === '.json') {
430
+ allPairs.push(...parseConversationJson(content));
431
+ }
432
+ else if (ext === '.jsonl') {
433
+ // Each line is a JSON object
434
+ const lines = content.split('\n').filter(l => l.trim());
435
+ for (const line of lines) {
436
+ try {
437
+ const obj = JSON.parse(line);
438
+ if (obj.messages) {
439
+ allPairs.push(...parseConversationJson(JSON.stringify(obj.messages)));
440
+ }
441
+ else if (obj.instruction && obj.output) {
442
+ allPairs.push({ instruction: obj.instruction, response: obj.output });
443
+ }
444
+ else if (obj.conversations) {
445
+ allPairs.push(...parseConversationJson(JSON.stringify(obj.conversations)));
446
+ }
447
+ }
448
+ catch { /* skip malformed lines */ }
449
+ }
450
+ }
451
+ else if (ext === '.md' || ext === '.markdown') {
452
+ allPairs.push(...parseMarkdown(content));
453
+ }
454
+ else {
455
+ // Code files
456
+ allPairs.push(...parseCodeFile(content, ext));
457
+ }
458
+ }
459
+ if (allPairs.length === 0) {
460
+ return [
461
+ `No instruction/response pairs extracted from ${sourceFiles.length} files.`,
462
+ '',
463
+ 'Tips:',
464
+ ' - JSON files should have messages with role/content or from/value fields',
465
+ ' - Markdown files should have headings followed by content paragraphs',
466
+ ' - Code files should have functions with docstrings/JSDoc/doc comments',
467
+ ].join('\n');
468
+ }
469
+ // Format all pairs
470
+ const outputLines = allPairs.map(pair => formatExample(pair, format, systemPrompt));
471
+ // Determine output path
472
+ const outputExt = format === 'jsonl' ? '.jsonl' : '.json';
473
+ let outputPath;
474
+ if (args.output) {
475
+ outputPath = resolve(String(args.output));
476
+ }
477
+ else {
478
+ const srcBase = stat.isFile()
479
+ ? sourcePath.replace(extname(sourcePath), '')
480
+ : sourcePath.replace(/\/$/, '');
481
+ outputPath = `${srcBase}_prepared${outputExt}`;
482
+ }
483
+ // Write output
484
+ mkdirSync(dirname(outputPath), { recursive: true });
485
+ if (format === 'jsonl') {
486
+ writeFileSync(outputPath, outputLines.join('\n') + '\n', 'utf-8');
487
+ }
488
+ else {
489
+ // For alpaca and sharegpt, write as a JSON array
490
+ const parsed = outputLines.map(l => JSON.parse(l));
491
+ writeFileSync(outputPath, JSON.stringify(parsed, null, 2) + '\n', 'utf-8');
492
+ }
493
+ // Calculate stats
494
+ const totalChars = allPairs.reduce((sum, p) => sum + p.instruction.length + p.response.length, 0);
495
+ const estimatedTokens = estimateTokens(allPairs.reduce((s, p) => s + p.instruction + p.response, ''));
496
+ return [
497
+ `Training data prepared successfully.`,
498
+ '',
499
+ ` Format: ${format}`,
500
+ ` Source: ${sourceFiles.length} file(s)`,
501
+ ` Examples: ${allPairs.length}`,
502
+ ` Est. tokens: ~${estimatedTokens.toLocaleString()} (~${totalChars.toLocaleString()} chars)`,
503
+ ` Output: ${outputPath}`,
504
+ '',
505
+ `File size: ${(statSync(outputPath).size / 1024).toFixed(1)} KB`,
506
+ ].join('\n');
507
+ }
508
+ catch (err) {
509
+ return `Error preparing training data: ${err instanceof Error ? err.message : String(err)}`;
510
+ }
511
+ },
512
+ });
513
+ // ── train_validate ─────────────────────────────────────────────────
514
+ registerTool({
515
+ name: 'train_validate',
516
+ description: 'Validate a training dataset before starting a fine-tuning job. Checks format correctness, ' +
517
+ 'duplicates, token length distribution, empty examples, and system prompt consistency. ' +
518
+ 'Reports warnings and errors with line numbers.',
519
+ parameters: {
520
+ dataset: {
521
+ type: 'string',
522
+ description: 'Path to the training dataset file',
523
+ required: true,
524
+ },
525
+ format: {
526
+ type: 'string',
527
+ description: 'Dataset format: jsonl, alpaca, sharegpt (default: jsonl)',
528
+ },
529
+ },
530
+ tier: 'pro',
531
+ timeout: 120_000,
532
+ async execute(args) {
533
+ try {
534
+ const datasetPath = resolve(String(args.dataset));
535
+ const format = (String(args.format || 'jsonl').toLowerCase());
536
+ if (!existsSync(datasetPath)) {
537
+ return `Error: Dataset not found: ${datasetPath}`;
538
+ }
539
+ const content = readFileSync(datasetPath, 'utf-8');
540
+ const issues = [];
541
+ const tokenCounts = [];
542
+ const normalizedTexts = new Set();
543
+ const systemPrompts = new Map();
544
+ let totalExamples = 0;
545
+ let emptyCount = 0;
546
+ let duplicateCount = 0;
547
+ const lines = format === 'jsonl'
548
+ ? content.split('\n').filter(l => l.trim())
549
+ : (() => {
550
+ try {
551
+ const parsed = JSON.parse(content);
552
+ return Array.isArray(parsed) ? parsed.map((item) => JSON.stringify(item)) : [content];
553
+ }
554
+ catch {
555
+ return content.split('\n').filter(l => l.trim());
556
+ }
557
+ })();
558
+ for (let i = 0; i < lines.length; i++) {
559
+ const lineNum = i + 1;
560
+ const line = typeof lines[i] === 'string' ? lines[i] : String(lines[i]);
561
+ // Parse JSON
562
+ let obj;
563
+ try {
564
+ obj = JSON.parse(line);
565
+ }
566
+ catch {
567
+ issues.push({ line: lineNum, severity: 'error', message: 'Invalid JSON' });
568
+ continue;
569
+ }
570
+ totalExamples++;
571
+ // Format-specific validation
572
+ if (format === 'jsonl') {
573
+ if (!obj.messages || !Array.isArray(obj.messages)) {
574
+ issues.push({ line: lineNum, severity: 'error', message: 'Missing "messages" array' });
575
+ continue;
576
+ }
577
+ const msgs = obj.messages;
578
+ if (msgs.length < 2) {
579
+ issues.push({ line: lineNum, severity: 'error', message: `Only ${msgs.length} message(s) — need at least user + assistant` });
580
+ continue;
581
+ }
582
+ // Check roles
583
+ const hasUser = msgs.some(m => m.role === 'user');
584
+ const hasAssistant = msgs.some(m => m.role === 'assistant');
585
+ if (!hasUser)
586
+ issues.push({ line: lineNum, severity: 'error', message: 'No "user" role message found' });
587
+ if (!hasAssistant)
588
+ issues.push({ line: lineNum, severity: 'error', message: 'No "assistant" role message found' });
589
+ for (let j = 0; j < msgs.length; j++) {
590
+ const m = msgs[j];
591
+ if (!m.role || !['system', 'user', 'assistant'].includes(m.role)) {
592
+ issues.push({ line: lineNum, severity: 'error', message: `Message ${j}: invalid role "${m.role}"` });
593
+ }
594
+ if (!m.content || typeof m.content !== 'string') {
595
+ issues.push({ line: lineNum, severity: 'error', message: `Message ${j}: missing or non-string content` });
596
+ }
597
+ }
598
+ // Track system prompts
599
+ const sysMsg = msgs.find(m => m.role === 'system');
600
+ if (sysMsg?.content) {
601
+ const key = normalizeText(sysMsg.content);
602
+ systemPrompts.set(key, (systemPrompts.get(key) || 0) + 1);
603
+ }
604
+ // Token count
605
+ const fullText = msgs.map(m => m.content || '').join(' ');
606
+ const tokens = estimateTokens(fullText);
607
+ tokenCounts.push(tokens);
608
+ // Empty check
609
+ const userMsg = msgs.find(m => m.role === 'user');
610
+ const assistantMsg = msgs.find(m => m.role === 'assistant');
611
+ if ((userMsg?.content?.length || 0) < 5 || (assistantMsg?.content?.length || 0) < 5) {
612
+ issues.push({ line: lineNum, severity: 'warning', message: 'Very short user or assistant message (< 5 chars)' });
613
+ emptyCount++;
614
+ }
615
+ // Duplicate check
616
+ const normalized = normalizeText(fullText);
617
+ if (normalizedTexts.has(normalized)) {
618
+ issues.push({ line: lineNum, severity: 'warning', message: 'Duplicate or near-duplicate example' });
619
+ duplicateCount++;
620
+ }
621
+ normalizedTexts.add(normalized);
622
+ }
623
+ else if (format === 'alpaca') {
624
+ if (!obj.instruction || typeof obj.instruction !== 'string') {
625
+ issues.push({ line: lineNum, severity: 'error', message: 'Missing or non-string "instruction" field' });
626
+ }
627
+ if (!obj.output || typeof obj.output !== 'string') {
628
+ issues.push({ line: lineNum, severity: 'error', message: 'Missing or non-string "output" field' });
629
+ }
630
+ if (obj.instruction && obj.instruction.length < 5) {
631
+ issues.push({ line: lineNum, severity: 'warning', message: 'Very short instruction (< 5 chars)' });
632
+ emptyCount++;
633
+ }
634
+ if (obj.output && obj.output.length < 5) {
635
+ issues.push({ line: lineNum, severity: 'warning', message: 'Very short output (< 5 chars)' });
636
+ }
637
+ const fullText = `${obj.instruction || ''} ${obj.input || ''} ${obj.output || ''}`;
638
+ tokenCounts.push(estimateTokens(fullText));
639
+ const normalized = normalizeText(fullText);
640
+ if (normalizedTexts.has(normalized)) {
641
+ issues.push({ line: lineNum, severity: 'warning', message: 'Duplicate or near-duplicate example' });
642
+ duplicateCount++;
643
+ }
644
+ normalizedTexts.add(normalized);
645
+ }
646
+ else if (format === 'sharegpt') {
647
+ if (!obj.conversations || !Array.isArray(obj.conversations)) {
648
+ issues.push({ line: lineNum, severity: 'error', message: 'Missing "conversations" array' });
649
+ continue;
650
+ }
651
+ const convs = obj.conversations;
652
+ if (convs.length < 2) {
653
+ issues.push({ line: lineNum, severity: 'error', message: `Only ${convs.length} turn(s) — need at least human + gpt` });
654
+ }
655
+ const hasHuman = convs.some(c => c.from === 'human');
656
+ const hasGpt = convs.some(c => c.from === 'gpt');
657
+ if (!hasHuman)
658
+ issues.push({ line: lineNum, severity: 'error', message: 'No "human" turn found' });
659
+ if (!hasGpt)
660
+ issues.push({ line: lineNum, severity: 'error', message: 'No "gpt" turn found' });
661
+ for (let j = 0; j < convs.length; j++) {
662
+ if (!convs[j].from || !['human', 'gpt', 'system'].includes(convs[j].from)) {
663
+ issues.push({ line: lineNum, severity: 'error', message: `Turn ${j}: invalid "from" value "${convs[j].from}"` });
664
+ }
665
+ if (!convs[j].value || typeof convs[j].value !== 'string') {
666
+ issues.push({ line: lineNum, severity: 'error', message: `Turn ${j}: missing or non-string "value"` });
667
+ }
668
+ }
669
+ // Track system prompts
670
+ const sysTurn = convs.find(c => c.from === 'system');
671
+ if (sysTurn?.value) {
672
+ const key = normalizeText(sysTurn.value);
673
+ systemPrompts.set(key, (systemPrompts.get(key) || 0) + 1);
674
+ }
675
+ const fullText = convs.map(c => c.value || '').join(' ');
676
+ tokenCounts.push(estimateTokens(fullText));
677
+ const normalized = normalizeText(fullText);
678
+ if (normalizedTexts.has(normalized)) {
679
+ issues.push({ line: lineNum, severity: 'warning', message: 'Duplicate or near-duplicate example' });
680
+ duplicateCount++;
681
+ }
682
+ normalizedTexts.add(normalized);
683
+ }
684
+ }
685
+ // Compute token stats
686
+ const sortedTokens = [...tokenCounts].sort((a, b) => a - b);
687
+ const minTokens = sortedTokens[0] || 0;
688
+ const maxTokens = sortedTokens[sortedTokens.length - 1] || 0;
689
+ const meanTokens = tokenCounts.length > 0
690
+ ? Math.round(tokenCounts.reduce((a, b) => a + b, 0) / tokenCounts.length)
691
+ : 0;
692
+ const p95Index = Math.floor(tokenCounts.length * 0.95);
693
+ const p95Tokens = sortedTokens[p95Index] || maxTokens;
694
+ const errors = issues.filter(i => i.severity === 'error');
695
+ const warnings = issues.filter(i => i.severity === 'warning');
696
+ // System prompt consistency
697
+ let systemPromptNote = '';
698
+ if (systemPrompts.size > 1) {
699
+ systemPromptNote = `\n System prompts: ${systemPrompts.size} unique variants (inconsistent — consider standardizing)`;
700
+ }
701
+ else if (systemPrompts.size === 1) {
702
+ const [prompt, count] = [...systemPrompts.entries()][0];
703
+ systemPromptNote = `\n System prompt: consistent across ${count} examples ("${prompt.slice(0, 60)}${prompt.length > 60 ? '...' : ''}")`;
704
+ }
705
+ else {
706
+ systemPromptNote = '\n System prompt: none found';
707
+ }
708
+ // Build report
709
+ const report = [
710
+ `Dataset Validation Report`,
711
+ `${'='.repeat(50)}`,
712
+ '',
713
+ ` File: ${datasetPath}`,
714
+ ` Format: ${format}`,
715
+ ` Examples: ${totalExamples}`,
716
+ ` Errors: ${errors.length}`,
717
+ ` Warnings: ${warnings.length}`,
718
+ ` Duplicates: ${duplicateCount}`,
719
+ ` Short/empty: ${emptyCount}`,
720
+ systemPromptNote,
721
+ '',
722
+ `Token Distribution`,
723
+ `${'─'.repeat(30)}`,
724
+ ` Min: ${minTokens}`,
725
+ ` Max: ${maxTokens}`,
726
+ ` Mean: ${meanTokens}`,
727
+ ` P95: ${p95Tokens}`,
728
+ ` Total: ~${tokenCounts.reduce((a, b) => a + b, 0).toLocaleString()} tokens`,
729
+ ];
730
+ if (errors.length > 0) {
731
+ report.push('', `Errors (${errors.length})`, '─'.repeat(30));
732
+ for (const issue of errors.slice(0, 25)) {
733
+ report.push(` Line ${issue.line}: ${issue.message}`);
734
+ }
735
+ if (errors.length > 25) {
736
+ report.push(` ... and ${errors.length - 25} more errors`);
737
+ }
738
+ }
739
+ if (warnings.length > 0) {
740
+ report.push('', `Warnings (${warnings.length})`, '─'.repeat(30));
741
+ for (const issue of warnings.slice(0, 15)) {
742
+ report.push(` Line ${issue.line}: ${issue.message}`);
743
+ }
744
+ if (warnings.length > 15) {
745
+ report.push(` ... and ${warnings.length - 15} more warnings`);
746
+ }
747
+ }
748
+ // Verdict
749
+ report.push('');
750
+ if (errors.length === 0 && warnings.length === 0) {
751
+ report.push('Verdict: PASS — Dataset is clean and ready for training.');
752
+ }
753
+ else if (errors.length === 0) {
754
+ report.push(`Verdict: PASS with warnings — ${warnings.length} warning(s) found but no blocking errors.`);
755
+ }
756
+ else {
757
+ report.push(`Verdict: FAIL — ${errors.length} error(s) must be fixed before training.`);
758
+ }
759
+ return report.join('\n');
760
+ }
761
+ catch (err) {
762
+ return `Validation error: ${err instanceof Error ? err.message : String(err)}`;
763
+ }
764
+ },
765
+ });
766
// ── train_start ────────────────────────────────────────────────────
registerTool({
    name: 'train_start',
    description: 'Launch a fine-tuning job. Supports cloud backends (OpenAI, Together AI, Mistral) and ' +
        'local backends (MLX on Apple Silicon, Unsloth, llama.cpp). For cloud: uploads dataset and creates job. ' +
        'For local: detects tool installation and launches the training process.',
    parameters: {
        dataset: {
            type: 'string',
            description: 'Path to training dataset file',
            required: true,
        },
        backend: {
            type: 'string',
            description: 'Training backend: openai, together, mistral, mlx, unsloth, llama-cpp',
            required: true,
        },
        base_model: {
            type: 'string',
            description: 'Base model to fine-tune (e.g., gpt-4.1-mini, meta-llama/Llama-3-8B, mlx-community/Llama-3-8B-4bit)',
            required: true,
        },
        output: {
            type: 'string',
            description: 'Output path for local training (default: ./output/<model>-ft)',
        },
        epochs: {
            type: 'number',
            description: 'Number of training epochs (default: 3)',
        },
        learning_rate: {
            type: 'number',
            description: 'Learning rate (default: 1e-4 for local, auto for cloud)',
        },
        batch_size: {
            type: 'number',
            description: 'Batch size (default: 4)',
        },
        lora_rank: {
            type: 'number',
            description: 'LoRA rank for local training (default: 16)',
        },
        lora_alpha: {
            type: 'number',
            description: 'LoRA alpha for local training (default: 32)',
        },
        api_key: {
            type: 'string',
            description: 'API key for cloud backends (optional — reads from env or ~/.kbot/config.json)',
        },
    },
    tier: 'pro',
    timeout: 600_000,
    async execute(args) {
        try {
            const datasetPath = resolve(String(args.dataset));
            const backend = String(args.backend).toLowerCase();
            const baseModel = String(args.base_model);
            const epochs = typeof args.epochs === 'number' ? args.epochs : 3;
            const batchSize = typeof args.batch_size === 'number' ? args.batch_size : 4;
            const loraRank = typeof args.lora_rank === 'number' ? args.lora_rank : 16;
            const loraAlpha = typeof args.lora_alpha === 'number' ? args.lora_alpha : 32;
            const learningRate = typeof args.learning_rate === 'number' ? args.learning_rate : 1e-4;
            if (!existsSync(datasetPath)) {
                return `Error: Dataset not found: ${datasetPath}`;
            }
            const validBackends = ['openai', 'together', 'mistral', 'mlx', 'unsloth', 'llama-cpp'];
            if (!validBackends.includes(backend)) {
                return `Error: Invalid backend "${backend}". Supported: ${validBackends.join(', ')}`;
            }
            // Single-quote a value for safe interpolation into generated bash
            // scripts — paths with spaces or quotes would otherwise break the
            // launch script or the chmod command.
            const sq = (v) => `'${String(v).replace(/'/g, `'\\''`)}'`;
            // Write a launch script, mark it executable, and run it detached so
            // training keeps going after this tool call returns. Returns the child
            // process (for its PID).
            const launchDetached = (scriptPath, scriptLines) => {
                writeFileSync(scriptPath, scriptLines.join('\n'), 'utf-8');
                shell(`chmod +x ${sq(scriptPath)}`);
                const child = spawn('bash', [scriptPath], {
                    detached: true,
                    stdio: ['ignore', 'ignore', 'ignore'],
                });
                child.unref();
                return child;
            };
            // Determine output directory for local backends
            const modelSlug = baseModel.replace(/[^a-zA-Z0-9-]/g, '-').replace(/-+/g, '-').slice(0, 50);
            const outputDir = args.output
                ? resolve(String(args.output))
                : resolve(process.cwd(), 'output', `${modelSlug}-ft`);
            // ── Cloud backends ───────────────────────────────────────────
            if (backend === 'openai' || backend === 'together' || backend === 'mistral') {
                const apiKey = getApiKey(backend, args.api_key ? String(args.api_key) : undefined);
                if (!apiKey) {
                    const envVarHint = backend === 'openai' ? 'OPENAI_API_KEY'
                        : backend === 'together' ? 'TOGETHER_API_KEY'
                            : 'MISTRAL_API_KEY';
                    return `Error: No API key found for ${backend}. Set ${envVarHint} env var, add to ~/.kbot/config.json, or pass api_key parameter.`;
                }
                const datasetContent = readFileSync(datasetPath, 'utf-8');
                // All three providers expose the same multipart "files" endpoint
                // shape. Returns { data } on success or { error } with a
                // user-facing message on failure.
                const uploadDataset = async (filesUrl, providerName) => {
                    const formData = new FormData();
                    formData.append('purpose', 'fine-tune');
                    formData.append('file', new Blob([datasetContent], { type: 'application/jsonl' }), basename(datasetPath));
                    const res = await fetch(filesUrl, {
                        method: 'POST',
                        headers: { 'Authorization': `Bearer ${apiKey}` },
                        body: formData,
                    });
                    if (!res.ok) {
                        const errBody = await res.text();
                        return { error: `Error uploading file to ${providerName}: ${res.status} ${errBody}` };
                    }
                    return { data: await res.json() };
                };
                if (backend === 'openai') {
                    // Step 1: Upload file
                    const up = await uploadDataset('https://api.openai.com/v1/files', 'OpenAI');
                    if (up.error)
                        return up.error;
                    const fileId = up.data.id;
                    // Step 2: Create fine-tuning job
                    const jobBody = {
                        training_file: fileId,
                        model: baseModel,
                        hyperparameters: {
                            n_epochs: epochs,
                            batch_size: batchSize,
                        },
                    };
                    // Only forward the LR multiplier when explicitly provided —
                    // OpenAI auto-tunes it otherwise.
                    if (args.learning_rate !== undefined) {
                        jobBody.hyperparameters.learning_rate_multiplier = learningRate;
                    }
                    const jobRes = await fetch('https://api.openai.com/v1/fine_tuning/jobs', {
                        method: 'POST',
                        headers: {
                            'Authorization': `Bearer ${apiKey}`,
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify(jobBody),
                    });
                    if (!jobRes.ok) {
                        const errBody = await jobRes.text();
                        return `Error creating OpenAI fine-tuning job: ${jobRes.status} ${errBody}`;
                    }
                    const jobData = await jobRes.json();
                    return [
                        `OpenAI fine-tuning job created.`,
                        '',
                        `  Job ID:       ${jobData.id}`,
                        `  Base model:   ${baseModel}`,
                        `  Status:       ${jobData.status}`,
                        `  File ID:      ${fileId}`,
                        `  Epochs:       ${epochs}`,
                        `  Batch size:   ${batchSize}`,
                        '',
                        `Check status with: train_status --job_id ${jobData.id} --backend openai`,
                    ].join('\n');
                }
                else if (backend === 'together') {
                    // Together AI fine-tuning: upload first, then create the job
                    // referencing the uploaded file id.
                    const up = await uploadDataset('https://api.together.xyz/v1/files', 'Together AI');
                    if (up.error)
                        return up.error;
                    const jobRes = await fetch('https://api.together.xyz/v1/fine-tunes', {
                        method: 'POST',
                        headers: {
                            'Authorization': `Bearer ${apiKey}`,
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify({
                            training_file: up.data.id,
                            model: baseModel,
                            n_epochs: epochs,
                            learning_rate: learningRate,
                            batch_size: batchSize,
                        }),
                    });
                    if (!jobRes.ok) {
                        const errBody = await jobRes.text();
                        return `Error creating Together AI fine-tuning job: ${jobRes.status} ${errBody}`;
                    }
                    const jobData = await jobRes.json();
                    return [
                        `Together AI fine-tuning job created.`,
                        '',
                        `  Job ID:        ${jobData.id}`,
                        `  Base model:    ${baseModel}`,
                        `  Status:        ${jobData.status}`,
                        `  Epochs:        ${epochs}`,
                        `  Learning rate: ${learningRate}`,
                        `  Batch size:    ${batchSize}`,
                        '',
                        `Check status with: train_status --job_id ${jobData.id} --backend together`,
                    ].join('\n');
                }
                else if (backend === 'mistral') {
                    // Mistral fine-tuning
                    const up = await uploadDataset('https://api.mistral.ai/v1/files', 'Mistral');
                    if (up.error)
                        return up.error;
                    const jobRes = await fetch('https://api.mistral.ai/v1/fine_tuning/jobs', {
                        method: 'POST',
                        headers: {
                            'Authorization': `Bearer ${apiKey}`,
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify({
                            model: baseModel,
                            training_files: [{ file_id: up.data.id, weight: 1 }],
                            hyperparameters: {
                                training_steps: epochs * 100, // Mistral uses steps, not epochs
                                learning_rate: learningRate,
                            },
                        }),
                    });
                    if (!jobRes.ok) {
                        const errBody = await jobRes.text();
                        return `Error creating Mistral fine-tuning job: ${jobRes.status} ${errBody}`;
                    }
                    const jobData = await jobRes.json();
                    return [
                        `Mistral fine-tuning job created.`,
                        '',
                        `  Job ID:        ${jobData.id}`,
                        `  Base model:    ${baseModel}`,
                        `  Status:        ${jobData.status}`,
                        `  Epochs:        ${epochs}`,
                        `  Learning rate: ${learningRate}`,
                        '',
                        `Check status with: train_status --job_id ${jobData.id} --backend mistral`,
                    ].join('\n');
                }
            }
            // ── Local backends ───────────────────────────────────────────
            mkdirSync(outputDir, { recursive: true });
            if (backend === 'mlx') {
                // Check if mlx_lm is available
                const hasMlx = shellSafe('python3 -c "import mlx_lm; print(mlx_lm.__version__)"');
                if (!hasMlx.ok) {
                    return [
                        'MLX LM is not installed. Install it with:',
                        '',
                        '  pip install mlx-lm',
                        '',
                        'Requirements:',
                        '  - Apple Silicon Mac (M1/M2/M3/M4)',
                        '  - macOS 14+ (Sonoma)',
                        '  - Python 3.10+',
                    ].join('\n');
                }
                const iters = epochs * 100; // Rough: iters = epochs * (dataset_size / batch_size)
                const adapterPath = join(outputDir, 'adapters');
                mkdirSync(adapterPath, { recursive: true });
                // NOTE(review): mlx_lm's --lora-layers flag counts the number of
                // layers to adapt, not the LoRA rank; lora_rank is forwarded here
                // as before — confirm the intended mapping before changing it.
                const cmd = [
                    'python3 -m mlx_lm.lora',
                    `--model ${sq(baseModel)}`,
                    `--data ${sq(datasetPath)}`,
                    '--train',
                    `--iters ${iters}`,
                    `--batch-size ${batchSize}`,
                    `--lora-layers ${loraRank}`,
                    `--adapter-path ${sq(adapterPath)}`,
                ].join(' ');
                const scriptPath = join(outputDir, 'train.sh');
                const logPath = join(outputDir, 'train.log');
                const child = launchDetached(scriptPath, [
                    '#!/bin/bash',
                    `echo "Training started at $(date)" > ${sq(logPath)}`,
                    `echo "Command: ${cmd}" >> ${sq(logPath)}`,
                    `${cmd} 2>&1 | tee -a ${sq(logPath)}`,
                    `echo "Training finished at $(date)" >> ${sq(logPath)}`,
                ]);
                return [
                    `MLX LoRA training launched in background.`,
                    '',
                    `  Base model:   ${baseModel}`,
                    `  Dataset:      ${datasetPath}`,
                    `  Iterations:   ${iters}`,
                    `  Batch size:   ${batchSize}`,
                    `  LoRA layers:  ${loraRank}`,
                    `  Adapter path: ${adapterPath}`,
                    `  Log file:     ${logPath}`,
                    `  PID:          ${child.pid}`,
                    '',
                    `Monitor progress: train_status --backend mlx --log_path ${logPath}`,
                    `When done, merge adapter: train_export --model_path ${adapterPath} --operation merge_lora --base_model ${baseModel}`,
                ].join('\n');
            }
            else if (backend === 'unsloth') {
                // Check if unsloth is available
                const hasUnsloth = shellSafe('python3 -c "import unsloth; print(unsloth.__version__)"');
                if (!hasUnsloth.ok) {
                    return [
                        'Unsloth is not installed. Install it with:',
                        '',
                        '  pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"',
                        '  pip install --no-deps "trl<0.9.0" peft accelerate bitsandbytes',
                        '',
                        'Requirements:',
                        '  - NVIDIA GPU with CUDA support',
                        '  - Python 3.10+',
                        '  - PyTorch 2.0+',
                    ].join('\n');
                }
                // Generate a Python training script. Values are interpolated into
                // Python string literals — assumes paths/model names contain no
                // double quotes (TODO confirm upstream sanitization).
                const scriptContent = `#!/usr/bin/env python3
"""Unsloth fine-tuning script generated by K:BOT"""
import json
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset

# Configuration
BASE_MODEL = "${baseModel}"
DATASET_PATH = "${datasetPath}"
OUTPUT_DIR = "${outputDir}"
EPOCHS = ${epochs}
BATCH_SIZE = ${batchSize}
LEARNING_RATE = ${learningRate}
LORA_RANK = ${loraRank}
LORA_ALPHA = ${loraAlpha}
MAX_SEQ_LENGTH = 2048

print(f"Loading model: {BASE_MODEL}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,  # auto-detect
    load_in_4bit=True,
)

print("Applying LoRA adapters...")
model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_RANK,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=LORA_ALPHA,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=42,
)

print(f"Loading dataset: {DATASET_PATH}")
dataset = load_dataset("json", data_files=DATASET_PATH, split="train")

def formatting_func(examples):
    texts = []
    for msgs in examples.get("messages", [[]]):
        text_parts = []
        for msg in msgs:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                text_parts.append(f"### System:\\n{content}")
            elif role == "user":
                text_parts.append(f"### User:\\n{content}")
            elif role == "assistant":
                text_parts.append(f"### Assistant:\\n{content}")
        texts.append("\\n\\n".join(text_parts))
    return {"text": texts}

dataset = dataset.map(formatting_func, batched=True, remove_columns=dataset.column_names)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    dataset_num_proc=2,
    packing=True,
    args=TrainingArguments(
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=4,
        warmup_steps=10,
        num_train_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=42,
        output_dir=OUTPUT_DIR,
        save_strategy="epoch",
    ),
)

print("Starting training...")
trainer_stats = trainer.train()
print(f"Training complete. Loss: {trainer_stats.training_loss:.4f}")

print(f"Saving model to {OUTPUT_DIR}")
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Done!")
`;
                const scriptPath = join(outputDir, 'train_unsloth.py');
                const logPath = join(outputDir, 'train.log');
                writeFileSync(scriptPath, scriptContent, 'utf-8');
                // Launch in background via a small wrapper that captures output
                const launchScript = join(outputDir, 'train.sh');
                const child = launchDetached(launchScript, [
                    '#!/bin/bash',
                    `echo "Training started at $(date)" > ${sq(logPath)}`,
                    `python3 ${sq(scriptPath)} 2>&1 | tee -a ${sq(logPath)}`,
                    `echo "Training finished at $(date)" >> ${sq(logPath)}`,
                ]);
                return [
                    `Unsloth fine-tuning launched in background.`,
                    '',
                    `  Base model:    ${baseModel}`,
                    `  Dataset:       ${datasetPath}`,
                    `  Epochs:        ${epochs}`,
                    `  Batch size:    ${batchSize}`,
                    `  Learning rate: ${learningRate}`,
                    `  LoRA rank:     ${loraRank}`,
                    `  LoRA alpha:    ${loraAlpha}`,
                    `  Output:        ${outputDir}`,
                    `  Log file:      ${logPath}`,
                    `  Training script: ${scriptPath}`,
                    `  PID:           ${child.pid}`,
                    '',
                    `Monitor progress: train_status --backend unsloth --log_path ${logPath}`,
                ].join('\n');
            }
            else if (backend === 'llama-cpp') {
                // Check if llama-finetune is available
                if (!isCommandAvailable('llama-finetune')) {
                    return [
                        'llama-finetune is not installed. Build it from llama.cpp:',
                        '',
                        '  git clone https://github.com/ggerganov/llama.cpp',
                        '  cd llama.cpp',
                        '  make llama-finetune',
                        '',
                        'Then add the build directory to your PATH.',
                        '',
                        'Alternative: use the MLX backend on Apple Silicon (faster, easier setup).',
                    ].join('\n');
                }
                // Leave a couple of cores free for the rest of the system
                const threads = Math.max(1, cpus().length - 2);
                const loraOutPath = join(outputDir, 'lora-adapter.bin');
                const logPath = join(outputDir, 'train.log');
                const cmd = [
                    'llama-finetune',
                    `--model-base ${sq(baseModel)}`,
                    `--lora-out ${sq(loraOutPath)}`,
                    `--train-data ${sq(datasetPath)}`,
                    `--threads ${threads}`,
                    `--epochs ${epochs}`,
                    `--batch ${batchSize}`,
                    `--lora-r ${loraRank}`,
                    `--lora-alpha ${loraAlpha}`,
                ].join(' ');
                const scriptPath = join(outputDir, 'train.sh');
                const child = launchDetached(scriptPath, [
                    '#!/bin/bash',
                    `echo "Training started at $(date)" > ${sq(logPath)}`,
                    `echo "Command: ${cmd}" >> ${sq(logPath)}`,
                    `${cmd} 2>&1 | tee -a ${sq(logPath)}`,
                    `echo "Training finished at $(date)" >> ${sq(logPath)}`,
                ]);
                return [
                    `llama.cpp fine-tuning launched in background.`,
                    '',
                    `  Base model:  ${baseModel}`,
                    `  Dataset:     ${datasetPath}`,
                    `  Epochs:      ${epochs}`,
                    `  Batch size:  ${batchSize}`,
                    `  LoRA rank:   ${loraRank}`,
                    `  LoRA alpha:  ${loraAlpha}`,
                    `  Threads:     ${threads}`,
                    `  LoRA output: ${loraOutPath}`,
                    `  Log file:    ${logPath}`,
                    `  PID:         ${child.pid}`,
                    '',
                    `Monitor progress: train_status --backend llama-cpp --log_path ${logPath}`,
                ].join('\n');
            }
            return `Error: Backend "${backend}" not handled.`;
        }
        catch (err) {
            return `Error starting training: ${err instanceof Error ? err.message : String(err)}`;
        }
    },
});
1282
// ── train_status ───────────────────────────────────────────────────
registerTool({
    name: 'train_status',
    description: 'Check the status of a fine-tuning job. For cloud backends, polls the provider API. ' +
        'For local backends, checks process state and reads the latest loss from the log file.',
    parameters: {
        job_id: {
            type: 'string',
            description: 'Job ID for cloud backends (required for openai/together/mistral)',
        },
        backend: {
            type: 'string',
            description: 'Backend: openai, together, mistral, mlx, unsloth, llama-cpp',
            required: true,
        },
        log_path: {
            type: 'string',
            description: 'Path to training log file (for local backends)',
        },
        api_key: {
            type: 'string',
            description: 'API key for cloud backends (optional)',
        },
    },
    tier: 'pro',
    timeout: 30_000,
    async execute(args) {
        try {
            const backend = String(args.backend).toLowerCase();
            const jobId = args.job_id ? String(args.job_id) : '';
            const logPath = args.log_path ? resolve(String(args.log_path)) : '';
            // ── Cloud backends ─────────────────────────────────────────
            if (backend === 'openai' || backend === 'together' || backend === 'mistral') {
                if (!jobId) {
                    return `Error: job_id is required for cloud backend "${backend}".`;
                }
                const apiKey = getApiKey(backend, args.api_key ? String(args.api_key) : undefined);
                if (!apiKey) {
                    return `Error: No API key found for ${backend}. Set the appropriate environment variable or pass api_key.`;
                }
                let url;
                if (backend === 'openai') {
                    url = `https://api.openai.com/v1/fine_tuning/jobs/${jobId}`;
                }
                else if (backend === 'together') {
                    url = `https://api.together.xyz/v1/fine-tunes/${jobId}`;
                }
                else {
                    url = `https://api.mistral.ai/v1/fine_tuning/jobs/${jobId}`;
                }
                const res = await fetch(url, {
                    headers: { 'Authorization': `Bearer ${apiKey}` },
                });
                if (!res.ok) {
                    const errBody = await res.text();
                    return `Error fetching job status from ${backend}: ${res.status} ${errBody}`;
                }
                const data = await res.json();
                const status = data.status || data.state || 'unknown';
                const model = data.model || data.fine_tuned_model || data.output_name || 'N/A';
                // Providers disagree on timestamp format: OpenAI returns Unix
                // epoch seconds while others may return ISO-8601 strings. Naively
                // doing `new Date(v * 1000).toISOString()` on a string produced
                // an Invalid Date and threw RangeError, failing the whole check.
                const fmtTime = (v) => {
                    if (v === undefined || v === null)
                        return 'N/A';
                    const d = typeof v === 'number' ? new Date(v * 1000) : new Date(v);
                    return Number.isNaN(d.getTime()) ? String(v) : d.toISOString();
                };
                const createdAt = fmtTime(data.created_at);
                const finishedAt = fmtTime(data.finished_at);
                // Extract training metrics if available
                let metricsInfo = '';
                const trainedTokens = data.trained_tokens || data.training_tokens;
                if (trainedTokens)
                    metricsInfo += `\n  Trained tokens: ${trainedTokens.toLocaleString()}`;
                const resultFiles = data.result_files;
                if (Array.isArray(resultFiles) && resultFiles.length > 0) {
                    metricsInfo += `\n  Result files: ${resultFiles.length}`;
                }
                const error = data.error;
                if (error) {
                    metricsInfo += `\n  Error: ${error.message || JSON.stringify(error)}`;
                }
                // OpenAI specific: get recent job events (best-effort)
                let eventsInfo = '';
                if (backend === 'openai') {
                    try {
                        const eventsRes = await fetch(`${url}/events?limit=5`, {
                            headers: { 'Authorization': `Bearer ${apiKey}` },
                        });
                        if (eventsRes.ok) {
                            const eventsData = await eventsRes.json();
                            if (eventsData.data && eventsData.data.length > 0) {
                                eventsInfo = '\n\nRecent events:\n' + eventsData.data
                                    .map(e => `  [${new Date(e.created_at * 1000).toLocaleTimeString()}] ${e.message}`)
                                    .join('\n');
                            }
                        }
                    }
                    catch { /* ignore events fetch errors */ }
                }
                return [
                    `${backend.charAt(0).toUpperCase() + backend.slice(1)} Fine-Tuning Job Status`,
                    '='.repeat(40),
                    '',
                    `  Job ID:     ${jobId}`,
                    `  Status:     ${status}`,
                    `  Base model: ${model}`,
                    `  Created:    ${createdAt}`,
                    `  Finished:   ${finishedAt}`,
                    metricsInfo,
                    eventsInfo,
                ].filter(Boolean).join('\n');
            }
            // ── Local backends ─────────────────────────────────────────
            if (!logPath) {
                return `Error: log_path is required for local backend "${backend}". Pass the path to the training log file.`;
            }
            if (!existsSync(logPath)) {
                return `Error: Log file not found: ${logPath}`;
            }
            const logContent = readFileSync(logPath, 'utf-8');
            const logLines = logContent.split('\n').filter(l => l.trim());
            // Determine status from the tail of the log. The generated launch
            // scripts always append a "Training finished" footer even after a
            // crash, so checking only the literal last line would report a
            // traceback as completed; scan the recent tail and let failure
            // markers override completion markers.
            let status = 'running';
            const tail = logLines.slice(-15);
            if (tail.some(l => l.includes('Training finished') || l.includes('Done!') || l.includes('Saving model'))) {
                status = 'completed';
            }
            if (tail.some(l => l.includes('Error') || l.includes('Traceback') || l.includes('FAILED'))) {
                status = 'failed';
            }
            // Extract training metrics from log
            let currentLoss = 'N/A';
            let currentStep = 'N/A';
            let totalSteps = 'N/A';
            // Search for loss values (common patterns across frameworks)
            const lossPatterns = [
                /loss[:\s=]+(\d+\.?\d*)/i,
                /train_loss[:\s=]+(\d+\.?\d*)/i,
                /training_loss[:\s=]+(\d+\.?\d*)/i,
            ];
            const stepPatterns = [
                /(?:step|iter|iteration)[:\s=]+(\d+)/i,
                /(\d+)\/(\d+)/, // step/total
            ];
            // Read from end of file for latest values
            for (let i = logLines.length - 1; i >= Math.max(0, logLines.length - 50); i--) {
                const line = logLines[i];
                if (currentLoss === 'N/A') {
                    for (const pattern of lossPatterns) {
                        const match = line.match(pattern);
                        if (match) {
                            currentLoss = match[1];
                            break;
                        }
                    }
                }
                if (currentStep === 'N/A') {
                    for (const pattern of stepPatterns) {
                        const match = line.match(pattern);
                        if (match) {
                            currentStep = match[1];
                            if (match[2])
                                totalSteps = match[2];
                            break;
                        }
                    }
                }
                if (currentLoss !== 'N/A' && currentStep !== 'N/A')
                    break;
            }
            // Extract start time written by the launch script header
            let elapsed = 'N/A';
            let eta = 'N/A';
            const startMatch = logContent.match(/Training started at (.+)/);
            if (startMatch) {
                const startTime = new Date(startMatch[1]);
                const elapsedMs = Date.now() - startTime.getTime();
                const elapsedMin = (elapsedMs / 60_000).toFixed(1);
                elapsed = `${elapsedMin} min`;
                // Estimate ETA by linear extrapolation of elapsed time per step
                if (currentStep !== 'N/A' && totalSteps !== 'N/A') {
                    const step = parseInt(currentStep, 10);
                    const total = parseInt(totalSteps, 10);
                    if (step > 0 && total > step) {
                        const msPerStep = elapsedMs / step;
                        const remainingMs = msPerStep * (total - step);
                        const remainingMin = (remainingMs / 60_000).toFixed(1);
                        eta = `~${remainingMin} min remaining`;
                    }
                }
            }
            // Last few log lines for context
            const recentLines = logLines.slice(-8).map(l => `  ${l}`).join('\n');
            return [
                `Local Training Status (${backend})`,
                '='.repeat(40),
                '',
                `  Status:       ${status}`,
                `  Current step: ${currentStep}${totalSteps !== 'N/A' ? ` / ${totalSteps}` : ''}`,
                `  Current loss: ${currentLoss}`,
                `  Elapsed:      ${elapsed}`,
                `  ETA:          ${eta}`,
                `  Log file:     ${logPath}`,
                '',
                'Recent log output:',
                recentLines,
            ].join('\n');
        }
        catch (err) {
            return `Error checking training status: ${err instanceof Error ? err.message : String(err)}`;
        }
    },
});
1489
// ── train_evaluate ─────────────────────────────────────────────────
registerTool({
    name: 'train_evaluate',
    description: 'Evaluate a fine-tuned model against a test dataset. Runs test prompts through the model ' +
        'and measures BLEU score, exact match rate, and average response length. ' +
        'Supports local inference via Ollama or llama.cpp, and cloud via provider API.',
    parameters: {
        model: {
            type: 'string',
            description: 'Model name or path to evaluate',
            required: true,
        },
        test_data: {
            type: 'string',
            description: 'Path to test dataset (same format as training data)',
            required: true,
        },
        backend: {
            type: 'string',
            description: 'Inference backend: ollama, llama-cpp, openai, together, mistral (default: ollama)',
        },
        samples: {
            type: 'number',
            description: 'Maximum number of test samples to evaluate (default: 50)',
        },
        api_key: {
            type: 'string',
            description: 'API key for cloud backends (optional)',
        },
    },
    tier: 'pro',
    timeout: 600_000,
    async execute(args) {
        try {
            const model = String(args.model);
            const testDataPath = resolve(String(args.test_data));
            const backend = String(args.backend || 'ollama').toLowerCase();
            const maxSamples = typeof args.samples === 'number' ? args.samples : 50;
            if (!existsSync(testDataPath)) {
                return `Error: Test data not found: ${testDataPath}`;
            }
            // Parse test data — extract prompt/expected pairs from JSONL
            // (chat/Alpaca/ShareGPT) or a plain JSON array.
            const content = readFileSync(testDataPath, 'utf-8');
            let testCases = [];
            // Try parsing as JSONL first
            const lines = content.split('\n').filter(l => l.trim());
            for (const line of lines) {
                try {
                    const obj = JSON.parse(line);
                    if (obj.messages && Array.isArray(obj.messages)) {
                        const msgs = obj.messages;
                        const systemMsg = msgs.find(m => m.role === 'system');
                        const userMsg = msgs.find(m => m.role === 'user');
                        const assistantMsg = msgs.find(m => m.role === 'assistant');
                        if (userMsg && assistantMsg) {
                            testCases.push({
                                prompt: userMsg.content,
                                expected: assistantMsg.content,
                                system: systemMsg?.content,
                            });
                        }
                    }
                    else if (obj.instruction && obj.output) {
                        testCases.push({ prompt: obj.instruction, expected: obj.output });
                    }
                    else if (obj.conversations) {
                        const convs = obj.conversations;
                        const human = convs.find(c => c.from === 'human');
                        const gpt = convs.find(c => c.from === 'gpt');
                        const sys = convs.find(c => c.from === 'system');
                        if (human && gpt) {
                            testCases.push({ prompt: human.value, expected: gpt.value, system: sys?.value });
                        }
                    }
                }
                catch { /* skip malformed lines */ }
            }
            // If no JSONL, try as JSON array
            if (testCases.length === 0) {
                try {
                    const parsed = JSON.parse(content);
                    if (Array.isArray(parsed)) {
                        for (const item of parsed) {
                            if (item.messages) {
                                const msgs = item.messages;
                                const userMsg = msgs.find((m) => m.role === 'user');
                                const assistantMsg = msgs.find((m) => m.role === 'assistant');
                                if (userMsg && assistantMsg) {
                                    testCases.push({ prompt: userMsg.content, expected: assistantMsg.content });
                                }
                            }
                            else if (item.instruction && item.output) {
                                testCases.push({ prompt: item.instruction, expected: item.output });
                            }
                        }
                    }
                }
                catch { /* not a JSON array */ }
            }
            if (testCases.length === 0) {
                return `Error: No test cases extracted from ${testDataPath}. Ensure the file contains valid JSONL/Alpaca/ShareGPT examples.`;
            }
            // Limit samples
            testCases = testCases.slice(0, maxSamples);
            // Validate the backend once up front instead of on the first loop
            // iteration (same messages; the loop always runs at least once).
            if (backend === 'ollama' && !isCommandAvailable('ollama')) {
                return 'Error: Ollama is not installed. Install from https://ollama.ai';
            }
            if (backend === 'llama-cpp' && !isCommandAvailable('llama-cli')) {
                return 'Error: llama-cli is not installed. Build from https://github.com/ggerganov/llama.cpp';
            }
            let cloudApiKey = '';
            if (backend === 'openai' || backend === 'together' || backend === 'mistral') {
                cloudApiKey = getApiKey(backend, args.api_key ? String(args.api_key) : undefined);
                if (!cloudApiKey) {
                    return `Error: No API key for ${backend}. Set the appropriate env var or pass api_key.`;
                }
            }
            // Run inference for each test case
            const results = [];
            let errorCount = 0;
            for (const tc of testCases) {
                let actual = '';
                try {
                    if (backend === 'ollama') {
                        const prompt = tc.system
                            ? `System: ${tc.system}\n\nUser: ${tc.prompt}\n\nAssistant:`
                            : `User: ${tc.prompt}\n\nAssistant:`;
                        const result = shellSafe(`ollama run ${model} ${JSON.stringify(prompt)}`, { timeout: 60_000 });
                        actual = result.ok ? result.output : '';
                    }
                    else if (backend === 'llama-cpp') {
                        const prompt = tc.prompt;
                        const result = shellSafe(`llama-cli -m ${model} -p ${JSON.stringify(prompt)} -n 512 --temp 0.1`, { timeout: 120_000 });
                        actual = result.ok ? result.output : '';
                    }
                    else if (backend === 'openai' || backend === 'together' || backend === 'mistral') {
                        const apiUrl = backend === 'openai' ? 'https://api.openai.com/v1/chat/completions'
                            : backend === 'together' ? 'https://api.together.xyz/v1/chat/completions'
                                : 'https://api.mistral.ai/v1/chat/completions';
                        const messages = [];
                        if (tc.system)
                            messages.push({ role: 'system', content: tc.system });
                        messages.push({ role: 'user', content: tc.prompt });
                        const res = await fetch(apiUrl, {
                            method: 'POST',
                            headers: {
                                'Authorization': `Bearer ${cloudApiKey}`,
                                'Content-Type': 'application/json',
                            },
                            body: JSON.stringify({
                                model,
                                messages,
                                max_tokens: 1024,
                                temperature: 0.1,
                            }),
                        });
                        if (res.ok) {
                            const data = await res.json();
                            actual = data.choices?.[0]?.message?.content || '';
                        }
                    }
                }
                catch {
                    errorCount++;
                    continue;
                }
                if (!actual) {
                    errorCount++;
                    continue;
                }
                // Score against the FULL strings; store truncated snippets only
                // for display. The length metrics keep the full lengths — the
                // previous version averaged the 80-char snippets, which silently
                // capped "Avg response"/"Avg expected" and skewed the ratio.
                const bleu = bleuScore(tc.expected, actual);
                const exactMatch = normalizeText(actual) === normalizeText(tc.expected);
                results.push({
                    prompt: tc.prompt.slice(0, 80),
                    expected: tc.expected.slice(0, 80),
                    actual: actual.slice(0, 80),
                    actualLen: actual.length,
                    expectedLen: tc.expected.length,
                    bleu,
                    exactMatch,
                });
            }
            if (results.length === 0) {
                return `Error: No successful evaluations. ${errorCount} inference calls failed. Check model name and backend.`;
            }
            // Compute aggregate metrics (lengths from the untruncated strings)
            const avgBleu = results.reduce((s, r) => s + r.bleu, 0) / results.length;
            const exactMatchRate = results.filter(r => r.exactMatch).length / results.length;
            const avgResponseLen = results.reduce((s, r) => s + r.actualLen, 0) / results.length;
            const avgExpectedLen = results.reduce((s, r) => s + r.expectedLen, 0) / results.length;
            // Build report
            const report = [
                `Model Evaluation Report`,
                '='.repeat(50),
                '',
                `  Model:        ${model}`,
                `  Backend:      ${backend}`,
                `  Test samples: ${results.length} / ${testCases.length}`,
                `  Errors:       ${errorCount}`,
                '',
                `Metrics`,
                '─'.repeat(30),
                `  BLEU score:   ${(avgBleu * 100).toFixed(1)}%`,
                `  Exact match:  ${(exactMatchRate * 100).toFixed(1)}%`,
                `  Avg response: ${Math.round(avgResponseLen)} chars`,
                `  Avg expected: ${Math.round(avgExpectedLen)} chars`,
                `  Length ratio: ${(avgResponseLen / Math.max(avgExpectedLen, 1)).toFixed(2)}x`,
            ];
            // Show a few examples
            report.push('', 'Sample Results', '─'.repeat(30));
            for (const r of results.slice(0, 5)) {
                report.push(`  Prompt:   ${r.prompt}...`, `  Expected: ${r.expected}...`, `  Actual:   ${r.actual}...`, `  BLEU: ${(r.bleu * 100).toFixed(1)}% | Match: ${r.exactMatch ? 'YES' : 'NO'}`, '');
            }
            return report.join('\n');
        }
        catch (err) {
            return `Evaluation error: ${err instanceof Error ? err.message : String(err)}`;
        }
    },
});
1704
// ── train_export ───────────────────────────────────────────────────
registerTool({
    name: 'train_export',
    description: 'Convert and export models between formats. Merge LoRA adapters back into base models, ' +
        'convert HuggingFace models to GGUF, or quantize GGUF files for efficient deployment.',
    parameters: {
        model_path: {
            type: 'string',
            description: 'Path to model or adapter directory',
            required: true,
        },
        operation: {
            type: 'string',
            description: 'Operation: merge_lora, to_gguf, quantize',
            required: true,
        },
        output: {
            type: 'string',
            description: 'Output path (default: auto-generated)',
        },
        base_model: {
            type: 'string',
            description: 'Base model name/path (required for merge_lora)',
        },
        quantization: {
            type: 'string',
            description: 'Quantization type for to_gguf and quantize: q4_K_M, q5_K_M, q8_0, f16 (default: q4_K_M)',
        },
    },
    tier: 'pro',
    timeout: 600_000,
    // Dispatches on `operation`:
    //   merge_lora — fuse a LoRA adapter into its base model (MLX first, PEFT fallback)
    //   to_gguf   — convert a HuggingFace model directory to a GGUF file
    //   quantize  — requantize an existing GGUF with llama-quantize
    // Returns a human-readable report string; never throws (errors are returned as text).
    async execute(args) {
        // Single-quote a value for POSIX shells so user-supplied paths with
        // spaces or shell metacharacters cannot break (or inject into) the
        // command line built below.
        const sq = (s) => `'${String(s).replace(/'/g, `'\\''`)}'`;
        // Escape a value for interpolation inside a double-quoted Python
        // string literal (used by the generated PEFT merge script).
        const pq = (s) => String(s).replace(/\\/g, '\\\\').replace(/"/g, '\\"');
        try {
            const modelPath = resolve(String(args.model_path));
            const operation = String(args.operation).toLowerCase();
            const quantType = String(args.quantization || 'q4_K_M');
            if (!existsSync(modelPath)) {
                return `Error: Model path not found: ${modelPath}`;
            }
            if (!['merge_lora', 'to_gguf', 'quantize'].includes(operation)) {
                return `Error: Invalid operation "${operation}". Use: merge_lora, to_gguf, quantize`;
            }
            if (operation === 'merge_lora') {
                const baseModel = args.base_model ? String(args.base_model) : '';
                if (!baseModel) {
                    return 'Error: base_model is required for merge_lora operation.';
                }
                const outputPath = args.output
                    ? resolve(String(args.output))
                    : resolve(dirname(modelPath), `${basename(modelPath)}-merged`);
                // Try MLX merge first (Apple Silicon); fall back to PEFT below.
                const hasMlx = shellSafe('python3 -c "import mlx_lm"');
                if (hasMlx.ok) {
                    const cmd = `python3 -m mlx_lm.fuse --model ${sq(baseModel)} --adapter-path ${sq(modelPath)} --save-path ${sq(outputPath)}`;
                    const result = shellSafe(cmd, { timeout: 300_000 });
                    if (result.ok) {
                        return [
                            `LoRA merge completed (MLX).`,
                            '',
                            ` Base model: ${baseModel}`,
                            ` Adapter: ${modelPath}`,
                            ` Merged output: ${outputPath}`,
                            '',
                            result.output,
                        ].join('\n');
                    }
                    // Fall through to try other methods
                }
                // Try with PEFT (PyTorch). Paths are escaped for Python string
                // literals so quotes/backslashes in paths can't corrupt the script.
                const mergeScript = `
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

print("Loading base model: ${pq(baseModel)}")
model = AutoModelForCausalLM.from_pretrained("${pq(baseModel)}", torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("${pq(baseModel)}")

print("Loading LoRA adapter: ${pq(modelPath)}")
model = PeftModel.from_pretrained(model, "${pq(modelPath)}")

print("Merging LoRA weights into base model...")
model = model.merge_and_unload()

print("Saving merged model to: ${pq(outputPath)}")
model.save_pretrained("${pq(outputPath)}")
tokenizer.save_pretrained("${pq(outputPath)}")
print("Done!")
`;
                const scriptPath = join(dirname(modelPath), '_merge_lora.py');
                writeFileSync(scriptPath, mergeScript, 'utf-8');
                const result = shellSafe(`python3 ${sq(scriptPath)}`, { timeout: 600_000 });
                // Clean up the temporary script; failure to delete is non-fatal.
                try {
                    execSync(`rm -f ${sq(scriptPath)}`, { stdio: 'pipe' });
                }
                catch { /* ignore */ }
                if (!result.ok) {
                    return `Error merging LoRA:\n${result.output}\n\nEnsure transformers and peft are installed: pip install transformers peft torch`;
                }
                return [
                    `LoRA merge completed (PEFT).`,
                    '',
                    ` Base model: ${baseModel}`,
                    ` Adapter: ${modelPath}`,
                    ` Merged output: ${outputPath}`,
                    '',
                    result.output,
                ].join('\n');
            }
            else if (operation === 'to_gguf') {
                const outputPath = args.output
                    ? resolve(String(args.output))
                    : resolve(dirname(modelPath), `${basename(modelPath)}.${quantType}.gguf`);
                // Try llama.cpp's convert script (name changed from dashes to
                // underscores across llama.cpp versions, so probe both).
                const convertScript = shellSafe('which convert_hf_to_gguf.py || which convert-hf-to-gguf.py');
                let convertCmd;
                if (convertScript.ok && convertScript.output) {
                    convertCmd = `python3 ${sq(convertScript.output)} ${sq(modelPath)} --outfile ${sq(outputPath)} --outtype ${quantType}`;
                }
                else {
                    // Try finding it in common install locations
                    const commonPaths = [
                        join(homedir(), 'llama.cpp/convert_hf_to_gguf.py'),
                        join(homedir(), 'llama.cpp/convert-hf-to-gguf.py'),
                        '/opt/llama.cpp/convert_hf_to_gguf.py',
                        '/usr/local/share/llama.cpp/convert_hf_to_gguf.py',
                    ];
                    const found = commonPaths.find(p => existsSync(p));
                    if (!found) {
                        return [
                            'GGUF conversion script not found. Ensure llama.cpp is built:',
                            '',
                            ' git clone https://github.com/ggerganov/llama.cpp',
                            ' cd llama.cpp',
                            ' pip install -r requirements.txt',
                            '',
                            'Then run:',
                            ` python3 convert_hf_to_gguf.py ${modelPath} --outfile ${outputPath} --outtype ${quantType}`,
                        ].join('\n');
                    }
                    convertCmd = `python3 ${sq(found)} ${sq(modelPath)} --outfile ${sq(outputPath)} --outtype ${quantType}`;
                }
                const result = shellSafe(convertCmd, { timeout: 600_000 });
                if (!result.ok) {
                    return `Error converting to GGUF:\n${result.output}`;
                }
                const fileSize = existsSync(outputPath) ? (statSync(outputPath).size / (1024 * 1024 * 1024)).toFixed(2) : '?';
                return [
                    `GGUF conversion completed.`,
                    '',
                    ` Input: ${modelPath}`,
                    ` Output: ${outputPath}`,
                    ` Quantization: ${quantType}`,
                    ` File size: ${fileSize} GB`,
                    '',
                    result.output,
                ].join('\n');
            }
            else if (operation === 'quantize') {
                if (!isCommandAvailable('llama-quantize')) {
                    return [
                        'llama-quantize is not installed. Build from llama.cpp:',
                        '',
                        ' git clone https://github.com/ggerganov/llama.cpp',
                        ' cd llama.cpp',
                        ' make llama-quantize',
                    ].join('\n');
                }
                const outputPath = args.output
                    ? resolve(String(args.output))
                    : modelPath.replace(/\.gguf$/, '') + `.${quantType}.gguf`;
                const result = shellSafe(`llama-quantize ${sq(modelPath)} ${sq(outputPath)} ${quantType}`, { timeout: 600_000 });
                if (!result.ok) {
                    return `Error quantizing model:\n${result.output}`;
                }
                const inputSize = (statSync(modelPath).size / (1024 * 1024 * 1024)).toFixed(2);
                const outputSize = existsSync(outputPath) ? (statSync(outputPath).size / (1024 * 1024 * 1024)).toFixed(2) : '?';
                return [
                    `Quantization completed.`,
                    '',
                    ` Input: ${modelPath} (${inputSize} GB)`,
                    ` Output: ${outputPath} (${outputSize} GB)`,
                    ` Type: ${quantType}`,
                    // inputSize is always a numeric string (statSync above), so
                    // only outputSize can be unknown here.
                    ` Compression: ${outputSize !== '?'
                        ? ((1 - parseFloat(outputSize) / parseFloat(inputSize)) * 100).toFixed(1) + '%'
                        : 'N/A'}`,
                    '',
                    result.output,
                ].join('\n');
            }
            return `Error: Operation "${operation}" not handled.`;
        }
        catch (err) {
            return `Export error: ${err instanceof Error ? err.message : String(err)}`;
        }
    },
});
1902
// ── train_deploy ───────────────────────────────────────────────────
registerTool({
    name: 'train_deploy',
    description: 'Deploy a fine-tuned model. Targets: Ollama (local serving), HuggingFace Hub (public/private repo), ' +
        'or K:BOT local models directory (~/.kbot/models/).',
    parameters: {
        model_path: {
            type: 'string',
            description: 'Path to the model file (GGUF) or directory',
            required: true,
        },
        target: {
            type: 'string',
            description: 'Deployment target: ollama, huggingface, kbot-local',
            required: true,
        },
        name: {
            type: 'string',
            description: 'Model name for the deployment',
            required: true,
        },
        description: {
            type: 'string',
            description: 'Model description (optional)',
        },
    },
    tier: 'pro',
    timeout: 600_000,
    // Deploys a trained model to one of three targets. Returns a report
    // string; never throws (errors are returned as text).
    async execute(args) {
        // Single-quote a value for POSIX shells so user-supplied paths/names
        // with spaces or metacharacters cannot break (or inject into) the
        // commands built below.
        const sq = (s) => `'${String(s).replace(/'/g, `'\\''`)}'`;
        try {
            const modelPath = resolve(String(args.model_path));
            const target = String(args.target).toLowerCase();
            const name = String(args.name);
            const description = args.description ? String(args.description) : `Fine-tuned model: ${name}`;
            if (!existsSync(modelPath)) {
                return `Error: Model path not found: ${modelPath}`;
            }
            if (!['ollama', 'huggingface', 'kbot-local'].includes(target)) {
                return `Error: Invalid target "${target}". Use: ollama, huggingface, kbot-local`;
            }
            if (target === 'ollama') {
                if (!isCommandAvailable('ollama')) {
                    return 'Error: Ollama is not installed. Download from https://ollama.ai';
                }
                // Write a Modelfile next to the model. Ollama's FROM directive
                // accepts both a GGUF file and a model directory, so no
                // branching on file type is needed (the previous ternary had
                // two identical arms).
                const modelfileContent = [
                    `FROM ${modelPath}`,
                    '',
                    `PARAMETER temperature 0.7`,
                    `PARAMETER top_p 0.9`,
                    `PARAMETER top_k 40`,
                    '',
                    // NOTE(review): a description containing `"""` would break
                    // this Modelfile — assumed not to occur in practice.
                    `SYSTEM """${description}"""`,
                ].join('\n');
                const modelfilePath = join(dirname(modelPath), 'Modelfile');
                writeFileSync(modelfilePath, modelfileContent, 'utf-8');
                // Create the Ollama model
                const result = shellSafe(`ollama create ${sq(name)} -f ${sq(modelfilePath)}`, { timeout: 300_000 });
                if (!result.ok) {
                    return `Error creating Ollama model:\n${result.output}\n\nModelfile written to: ${modelfilePath}`;
                }
                return [
                    `Model deployed to Ollama.`,
                    '',
                    ` Name: ${name}`,
                    ` Source: ${modelPath}`,
                    ` Modelfile: ${modelfilePath}`,
                    '',
                    `Run it with: ollama run ${name}`,
                    `Use in K:BOT: kbot --model ${name}`,
                    '',
                    result.output,
                ].join('\n');
            }
            else if (target === 'huggingface') {
                if (!isCommandAvailable('huggingface-cli')) {
                    return [
                        'HuggingFace CLI is not installed. Install with:',
                        '',
                        ' pip install huggingface_hub[cli]',
                        '',
                        'Then authenticate:',
                        ' huggingface-cli login',
                    ].join('\n');
                }
                // Check if HF_TOKEN is available; otherwise require a cached login.
                const hasToken = process.env.HF_TOKEN || process.env.HUGGING_FACE_HUB_TOKEN;
                if (!hasToken) {
                    const loginCheck = shellSafe('huggingface-cli whoami');
                    if (!loginCheck.ok) {
                        return 'Error: Not authenticated with HuggingFace. Run: huggingface-cli login\nOr set HF_TOKEN environment variable.';
                    }
                }
                const result = shellSafe(`huggingface-cli upload ${sq(name)} ${sq(modelPath)}`, { timeout: 600_000 });
                if (!result.ok) {
                    return `Error uploading to HuggingFace:\n${result.output}`;
                }
                return [
                    `Model uploaded to HuggingFace Hub.`,
                    '',
                    ` Repository: ${name}`,
                    ` Source: ${modelPath}`,
                    ` URL: https://huggingface.co/${name}`,
                    '',
                    result.output,
                ].join('\n');
            }
            else if (target === 'kbot-local') {
                const modelsDir = join(homedir(), '.kbot', 'models');
                mkdirSync(modelsDir, { recursive: true });
                const isGguf = modelPath.endsWith('.gguf');
                const destFilename = isGguf ? `${name}.gguf` : name;
                const destPath = join(modelsDir, destFilename);
                // Copy model file(s)
                const stat = statSync(modelPath);
                if (stat.isFile()) {
                    const result = shellSafe(`cp ${sq(modelPath)} ${sq(destPath)}`, { timeout: 120_000 });
                    if (!result.ok) {
                        return `Error copying model: ${result.output}`;
                    }
                }
                else if (stat.isDirectory()) {
                    const destDir = join(modelsDir, name);
                    mkdirSync(destDir, { recursive: true });
                    // The glob stays outside the quotes so the shell can expand it.
                    const result = shellSafe(`cp -r ${sq(modelPath)}/* ${sq(destDir)}/`, { timeout: 300_000 });
                    if (!result.ok) {
                        return `Error copying model directory: ${result.output}`;
                    }
                }
                // Register in K:BOT config (~/.kbot/config.json)
                const configPath = join(homedir(), '.kbot', 'config.json');
                let config = {};
                if (existsSync(configPath)) {
                    try {
                        config = JSON.parse(readFileSync(configPath, 'utf-8'));
                    }
                    catch { /* start fresh */ }
                }
                if (!config.local_models || !Array.isArray(config.local_models)) {
                    config.local_models = [];
                }
                const modelEntry = {
                    name,
                    path: stat.isDirectory() ? join(modelsDir, name) : destPath,
                    description,
                    added: new Date().toISOString(),
                    type: isGguf ? 'gguf' : 'directory',
                };
                // Remove existing entry with same name (re-deploy replaces it)
                config.local_models = config.local_models.filter(m => m.name !== name);
                config.local_models.push(modelEntry);
                writeFileSync(configPath, JSON.stringify(config, null, 2), 'utf-8');
                const fileSize = stat.isFile()
                    ? `${(stat.size / (1024 * 1024 * 1024)).toFixed(2)} GB`
                    : 'directory';
                return [
                    `Model registered in K:BOT local models.`,
                    '',
                    ` Name: ${name}`,
                    ` Path: ${modelEntry.path}`,
                    ` Size: ${fileSize}`,
                    ` Description: ${description}`,
                    '',
                    `Use in K:BOT: kbot --model ${name}`,
                    `List models: kbot models`,
                ].join('\n');
            }
            return `Error: Target "${target}" not handled.`;
        }
        catch (err) {
            return `Deploy error: ${err instanceof Error ? err.message : String(err)}`;
        }
    },
});
2079
// ── train_cost ─────────────────────────────────────────────────────
// Pure estimation tool: reads the dataset once, then computes token counts,
// cloud pricing, local VRAM/time/electricity heuristics, and renders a report.
// No training job is started and nothing is written.
registerTool({
    name: 'train_cost',
    description: 'Estimate the cost, time, and VRAM requirements for a fine-tuning job before starting. ' +
        'For cloud backends, calculates token-based pricing. For local, estimates GPU hours and VRAM usage.',
    parameters: {
        dataset: {
            type: 'string',
            description: 'Path to training dataset file',
            required: true,
        },
        base_model: {
            type: 'string',
            description: 'Base model to fine-tune (e.g., gpt-4.1-mini, llama-3-8b)',
            required: true,
        },
        backend: {
            type: 'string',
            description: 'Training backend: openai, together, mistral, mlx, unsloth, llama-cpp',
            required: true,
        },
        epochs: {
            type: 'number',
            description: 'Number of training epochs (default: 3)',
        },
    },
    tier: 'free',
    timeout: 30_000,
    // Returns a multi-section plain-text report; errors are returned as text,
    // never thrown.
    async execute(args) {
        try {
            const datasetPath = resolve(String(args.dataset));
            const baseModel = String(args.base_model);
            const backend = String(args.backend).toLowerCase();
            const epochs = typeof args.epochs === 'number' ? args.epochs : 3;
            if (!existsSync(datasetPath)) {
                return `Error: Dataset not found: ${datasetPath}`;
            }
            const validBackends = ['openai', 'together', 'mistral', 'mlx', 'unsloth', 'llama-cpp'];
            if (!validBackends.includes(backend)) {
                return `Error: Invalid backend "${backend}". Supported: ${validBackends.join(', ')}`;
            }
            // Calculate dataset size. estimateTokens/estimateModelParams are
            // helpers defined elsewhere in this module.
            const content = readFileSync(datasetPath, 'utf-8');
            const fileSize = statSync(datasetPath).size;
            const totalTokens = estimateTokens(content);
            // Every epoch re-processes the full dataset.
            const trainingTokens = totalTokens * epochs;
            // Estimate model parameters
            const modelParams = estimateModelParams(baseModel);
            const modelParamsB = (modelParams / 1e9).toFixed(1);
            // ── Cloud pricing ────────────────────────────────────────────
            // Hard-coded $/1K-token training rates; presumably snapshotted
            // from provider pricing pages — verify before relying on them.
            const cloudPricing = {
                openai: {
                    // mini is checked first so plain "gpt-4.1" only matches non-mini.
                    perKToken: baseModel.includes('gpt-4.1-mini') ? 0.008 : baseModel.includes('gpt-4.1') ? 0.025 : 0.008,
                    label: baseModel.includes('gpt-4.1') && !baseModel.includes('mini') ? 'GPT-4.1' : 'GPT-4.1 Mini',
                },
                together: {
                    perKToken: 0.004,
                    label: 'Together AI',
                },
                mistral: {
                    perKToken: 0.008,
                    label: 'Mistral',
                },
            };
            // ── VRAM estimation ──────────────────────────────────────────
            // Full fine-tuning: ~18-20 bytes per param (fp16 model + optimizer states + gradients)
            // QLoRA: ~4-6 bytes per param (4-bit model + small LoRA overhead)
            // LoRA (fp16): ~10-12 bytes per param (fp16 model + small LoRA)
            // The bytes-per-param multipliers below are rough heuristics.
            let vramGB;
            let vramMethod;
            if (backend === 'mlx') {
                // MLX uses unified memory, fp16 model + LoRA
                vramGB = (modelParams * 4) / (1024 * 1024 * 1024); // ~4 bytes for 4-bit + LoRA overhead
                vramMethod = '4-bit + LoRA (MLX unified memory)';
            }
            else if (backend === 'unsloth') {
                // Unsloth uses 4-bit QLoRA
                vramGB = (modelParams * 5) / (1024 * 1024 * 1024);
                vramMethod = '4-bit QLoRA (Unsloth)';
            }
            else if (backend === 'llama-cpp') {
                // llama.cpp LoRA fine-tuning
                vramGB = (modelParams * 6) / (1024 * 1024 * 1024);
                vramMethod = 'LoRA (llama.cpp, CPU/GPU mixed)';
            }
            else {
                // Cloud — user doesn't need to worry about VRAM
                vramGB = 0;
                vramMethod = 'Cloud-managed (no local VRAM needed)';
            }
            // ── Time estimation ──────────────────────────────────────────
            // Rough estimates based on model size and training tokens.
            // Throughput numbers (tokens/sec, tokens/hour) are assumptions,
            // not measurements — treat the output as order-of-magnitude only.
            let estimatedMinutes;
            let timeNote;
            if (backend === 'openai') {
                // OpenAI typically processes ~1M tokens/hour for fine-tuning
                estimatedMinutes = (trainingTokens / 1_000_000) * 60;
                timeNote = 'OpenAI job queue + training';
            }
            else if (backend === 'together') {
                estimatedMinutes = (trainingTokens / 800_000) * 60;
                timeNote = 'Together AI queue + training';
            }
            else if (backend === 'mistral') {
                estimatedMinutes = (trainingTokens / 700_000) * 60;
                timeNote = 'Mistral queue + training';
            }
            else if (backend === 'mlx') {
                // Apple Silicon: ~200-500 tokens/sec for 7B model LoRA
                const tokPerSec = modelParams <= 8e9 ? 400 : modelParams <= 14e9 ? 150 : 50;
                estimatedMinutes = (trainingTokens / tokPerSec) / 60;
                timeNote = `~${tokPerSec} tok/s estimated for ${modelParamsB}B on Apple Silicon`;
            }
            else if (backend === 'unsloth') {
                // Unsloth with consumer GPU: ~500-1000 tokens/sec for 7B
                const tokPerSec = modelParams <= 8e9 ? 800 : modelParams <= 14e9 ? 300 : 100;
                estimatedMinutes = (trainingTokens / tokPerSec) / 60;
                timeNote = `~${tokPerSec} tok/s estimated for ${modelParamsB}B with Unsloth`;
            }
            else {
                // llama.cpp CPU: ~50-200 tokens/sec
                const tokPerSec = modelParams <= 8e9 ? 150 : modelParams <= 14e9 ? 50 : 15;
                estimatedMinutes = (trainingTokens / tokPerSec) / 60;
                timeNote = `~${tokPerSec} tok/s estimated for ${modelParamsB}B on CPU`;
            }
            // Format time as minutes / hours / days depending on magnitude.
            let timeStr;
            if (estimatedMinutes < 60) {
                timeStr = `${Math.ceil(estimatedMinutes)} minutes`;
            }
            else if (estimatedMinutes < 1440) {
                timeStr = `${(estimatedMinutes / 60).toFixed(1)} hours`;
            }
            else {
                timeStr = `${(estimatedMinutes / 1440).toFixed(1)} days`;
            }
            // ── Cost for cloud ───────────────────────────────────────────
            let costStr;
            if (backend === 'openai' || backend === 'together' || backend === 'mistral') {
                const pricing = cloudPricing[backend];
                const cost = (trainingTokens / 1000) * pricing.perKToken;
                costStr = `$${cost.toFixed(2)} (${pricing.label} @ $${pricing.perKToken}/1K tokens)`;
            }
            else {
                // Local: estimate electricity cost from an assumed device wattage.
                let wattage;
                if (backend === 'mlx')
                    wattage = 30; // Apple Silicon TDP
                else if (backend === 'unsloth')
                    wattage = 300; // GPU TDP
                else
                    wattage = 150; // CPU
                const kWh = (wattage * estimatedMinutes / 60) / 1000;
                const electricityCost = kWh * 0.15; // $0.15/kWh average
                costStr = `~$${electricityCost.toFixed(2)} electricity (${kWh.toFixed(2)} kWh @ $0.15/kWh)`;
            }
            // ── GPU recommendations ──────────────────────────────────────
            // Maps the VRAM estimate onto representative hardware tiers.
            let gpuRec;
            if (backend === 'mlx') {
                if (vramGB <= 8)
                    gpuRec = 'M1/M2/M3 with 8GB+ unified memory';
                else if (vramGB <= 16)
                    gpuRec = 'M1 Pro/Max/M2 Pro/Max with 16GB+ unified memory';
                else if (vramGB <= 36)
                    gpuRec = 'M2 Max/M3 Max with 36GB+ unified memory';
                else if (vramGB <= 64)
                    gpuRec = 'M2 Ultra/M3 Ultra with 64GB+ unified memory';
                else
                    gpuRec = 'M2 Ultra/M3 Ultra with 128GB+ unified memory (or use cloud)';
            }
            else if (backend === 'unsloth') {
                if (vramGB <= 8)
                    gpuRec = 'RTX 3060 12GB / RTX 4060 Ti';
                else if (vramGB <= 16)
                    gpuRec = 'RTX 3090 / RTX 4080 / A5000';
                else if (vramGB <= 24)
                    gpuRec = 'RTX 3090 Ti / RTX 4090 / A5000';
                else if (vramGB <= 48)
                    gpuRec = 'A6000 / 2x RTX 4090';
                else
                    gpuRec = 'A100 80GB / H100 (or use cloud)';
            }
            else if (backend === 'llama-cpp') {
                // 1.5x headroom over the model footprint for system RAM.
                const ramGB = Math.ceil(vramGB * 1.5);
                gpuRec = `${ramGB}GB+ system RAM (llama.cpp uses CPU + optional GPU offload)`;
            }
            else {
                gpuRec = 'N/A (cloud-managed)';
            }
            // ── Build report ─────────────────────────────────────────────
            // Non-empty line count approximates the number of JSONL examples.
            const lines = content.split('\n').filter(l => l.trim()).length;
            return [
                `Training Cost Estimate`,
                '='.repeat(50),
                '',
                `Dataset`,
                '─'.repeat(30),
                ` File: ${datasetPath}`,
                ` File size: ${(fileSize / 1024).toFixed(1)} KB`,
                ` Examples: ~${lines.toLocaleString()}`,
                ` Tokens: ~${totalTokens.toLocaleString()}`,
                ` Training tokens: ~${trainingTokens.toLocaleString()} (${epochs} epochs)`,
                '',
                `Model`,
                '─'.repeat(30),
                ` Base model: ${baseModel}`,
                ` Est. params: ${modelParamsB}B`,
                ` Backend: ${backend}`,
                '',
                `Cost`,
                '─'.repeat(30),
                ` Estimated: ${costStr}`,
                '',
                `Time`,
                '─'.repeat(30),
                ` Estimated: ${timeStr}`,
                ` Note: ${timeNote}`,
                '',
                `Resources`,
                '─'.repeat(30),
                ` VRAM needed: ${vramGB > 0 ? `${vramGB.toFixed(1)} GB` : 'N/A (cloud)'}`,
                ` Method: ${vramMethod}`,
                ` Recommended: ${gpuRec}`,
                '',
                `Note: These are rough estimates. Actual cost and time depend on dataset complexity,`,
                `hardware configuration, queue times (cloud), and hyperparameters.`,
            ].join('\n');
        }
        catch (err) {
            return `Cost estimation error: ${err instanceof Error ? err.message : String(err)}`;
        }
    },
});
2312
+ }
2313
+ //# sourceMappingURL=training.js.map