@kernel.chat/kbot 2.23.2 → 2.25.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/pair.d.ts +81 -0
- package/dist/pair.d.ts.map +1 -0
- package/dist/pair.js +993 -0
- package/dist/pair.js.map +1 -0
- package/dist/plugin-sdk.d.ts +136 -0
- package/dist/plugin-sdk.d.ts.map +1 -0
- package/dist/plugin-sdk.js +946 -0
- package/dist/plugin-sdk.js.map +1 -0
- package/dist/record.d.ts +174 -0
- package/dist/record.d.ts.map +1 -0
- package/dist/record.js +1182 -0
- package/dist/record.js.map +1 -0
- package/dist/team.d.ts +106 -0
- package/dist/team.d.ts.map +1 -0
- package/dist/team.js +917 -0
- package/dist/team.js.map +1 -0
- package/dist/tools/database.d.ts +2 -0
- package/dist/tools/database.d.ts.map +1 -0
- package/dist/tools/database.js +751 -0
- package/dist/tools/database.js.map +1 -0
- package/dist/tools/deploy.d.ts +2 -0
- package/dist/tools/deploy.d.ts.map +1 -0
- package/dist/tools/deploy.js +824 -0
- package/dist/tools/deploy.js.map +1 -0
- package/dist/tools/index.d.ts.map +1 -1
- package/dist/tools/index.js +13 -1
- package/dist/tools/index.js.map +1 -1
- package/dist/tools/mcp-marketplace.d.ts +2 -0
- package/dist/tools/mcp-marketplace.d.ts.map +1 -0
- package/dist/tools/mcp-marketplace.js +759 -0
- package/dist/tools/mcp-marketplace.js.map +1 -0
- package/dist/tools/training.d.ts +2 -0
- package/dist/tools/training.d.ts.map +1 -0
- package/dist/tools/training.js +2313 -0
- package/dist/tools/training.js.map +1 -0
- package/package.json +35 -3
|
@@ -0,0 +1,2313 @@
|
|
|
1
|
+
// K:BOT Training & Fine-Tuning Tools
|
|
2
|
+
//
|
|
3
|
+
// Prepare datasets, launch fine-tuning jobs (cloud + local), monitor training,
|
|
4
|
+
// evaluate results, export/convert models, deploy, and estimate costs.
|
|
5
|
+
//
|
|
6
|
+
// Tools:
|
|
7
|
+
// train_prepare — Convert data into training formats (JSONL/Alpaca/ShareGPT)
|
|
8
|
+
// train_validate — Validate a training dataset before launching
|
|
9
|
+
// train_start — Launch fine-tuning (OpenAI, Together, Mistral, MLX, Unsloth, llama.cpp)
|
|
10
|
+
// train_status — Check training job status (cloud or local)
|
|
11
|
+
// train_evaluate — Evaluate a fine-tuned model against test data
|
|
12
|
+
// train_export — Merge LoRA, convert to GGUF, quantize
|
|
13
|
+
// train_deploy — Deploy to Ollama, HuggingFace, or K:BOT local
|
|
14
|
+
// train_cost — Estimate training cost, time, and VRAM
|
|
15
|
+
import { execSync, spawn } from 'node:child_process';
|
|
16
|
+
import { existsSync, readFileSync, writeFileSync, mkdirSync, readdirSync, statSync } from 'node:fs';
|
|
17
|
+
import { resolve, join, basename, extname, dirname } from 'node:path';
|
|
18
|
+
import { homedir, cpus } from 'node:os';
|
|
19
|
+
import { registerTool } from './index.js';
|
|
20
|
+
// ── Helpers ──────────────────────────────────────────────────────────
|
|
21
|
+
/**
 * Run a command synchronously through the system shell and return its
 * trimmed stdout.
 *
 * @param {string} cmd - Command line to execute.
 * @param {{timeout?: number, cwd?: string, env?: object}} [opts] - Optional
 *   overrides for timeout (ms), working directory, and environment.
 * @returns {string} Trimmed stdout of the command.
 * @throws If the command exits non-zero, times out, or exceeds the buffer.
 */
function shell(cmd, opts) {
    const execOptions = {
        encoding: 'utf-8',
        timeout: opts?.timeout ?? 60_000,
        maxBuffer: 50 * 1024 * 1024, // 50 MB — tool output can be large
        cwd: opts?.cwd ?? process.cwd(),
        env: opts?.env ?? process.env,
        stdio: ['pipe', 'pipe', 'pipe'], // keep child output off our terminal
    };
    const stdout = execSync(cmd, execOptions);
    return stdout.trim();
}
|
|
31
|
+
/**
 * Non-throwing wrapper around shell(): runs a command and reports success
 * or failure as a result object instead of an exception.
 *
 * @param {string} cmd - Command line to execute.
 * @param {{timeout?: number, cwd?: string, env?: object}} [opts] - Passed through to shell().
 * @returns {{ok: boolean, output: string}} On failure, output is the combined
 *   stdout/stderr of the child, or the error message as a fallback.
 */
function shellSafe(cmd, opts) {
    try {
        return { ok: true, output: shell(cmd, opts) };
    }
    catch (err) {
        const e = err;
        // execSync errors carry the child's captured stdout/stderr.
        const combined = [e.stdout, e.stderr].filter(Boolean).join('\n').trim();
        return { ok: false, output: combined || e.message || 'Command failed' };
    }
}
|
|
42
|
+
/**
 * Check whether a command is available on PATH.
 *
 * The command name is interpolated into a shell string, so reject anything
 * that is not a plain executable name — this prevents shell injection when
 * the name originates from untrusted tool arguments.
 *
 * @param {string} cmd - Bare command name (letters, digits, '.', '_', '+', '-').
 * @returns {boolean} True if `which` resolves the command.
 */
function isCommandAvailable(cmd) {
    // Guard: only plain command names may reach the shell.
    if (typeof cmd !== 'string' || !/^[A-Za-z0-9._+-]+$/.test(cmd)) {
        return false;
    }
    try {
        execSync(`which ${cmd}`, { encoding: 'utf-8', timeout: 5_000, stdio: ['pipe', 'pipe', 'pipe'] });
        return true;
    }
    catch {
        // `which` exits non-zero when the command is not found.
        return false;
    }
}
|
|
51
|
+
/**
 * Rough token-count estimate for a piece of text.
 * Heuristic: English text averages about four characters per token.
 *
 * @param {string} text
 * @returns {number} Estimated token count (rounded up).
 */
function estimateTokens(text) {
    const CHARS_PER_TOKEN = 4;
    return Math.ceil(text.length / CHARS_PER_TOKEN);
}
|
|
55
|
+
/**
 * Load ~/.kbot/config.json.
 *
 * @returns {object} Parsed config, or an empty object when the file is
 *   missing or unparseable (best-effort read — never throws).
 */
function readKbotConfig() {
    const configPath = join(homedir(), '.kbot', 'config.json');
    if (!existsSync(configPath)) {
        return {};
    }
    try {
        const raw = readFileSync(configPath, 'utf-8');
        return JSON.parse(raw);
    }
    catch {
        // Unreadable or malformed config is treated the same as no config.
        return {};
    }
}
|
|
66
|
+
/**
 * Resolve an API key for a fine-tuning backend.
 *
 * Resolution order: explicit argument → environment variables →
 * ~/.kbot/config.json. Unknown backends now resolve to null instead of
 * crashing (previously `envVarMap[backend]` was iterated unguarded, so any
 * backend outside the map threw a TypeError on `for...of undefined`).
 *
 * @param {string} backend - 'openai', 'together', or 'mistral'.
 * @param {string} [explicit] - Caller-supplied key; returned as-is when truthy.
 * @returns {string|null} The resolved API key, or null when none is found.
 */
function getApiKey(backend, explicit) {
    if (explicit)
        return explicit;
    // Check environment variables
    const envVarMap = {
        openai: ['OPENAI_API_KEY'],
        together: ['TOGETHER_API_KEY', 'TOGETHER_AI_KEY'],
        mistral: ['MISTRAL_API_KEY'],
    };
    for (const envVar of envVarMap[backend] ?? []) {
        if (process.env[envVar])
            return process.env[envVar];
    }
    // Check ~/.kbot/config.json (supports snake_case and camelCase keys)
    const config = readKbotConfig();
    const configKeyMap = {
        openai: ['openai_api_key', 'openaiApiKey'],
        together: ['together_api_key', 'togetherApiKey'],
        mistral: ['mistral_api_key', 'mistralApiKey'],
    };
    for (const key of configKeyMap[backend] ?? []) {
        if (config[key] && typeof config[key] === 'string')
            return config[key];
    }
    return null;
}
|
|
92
|
+
/**
 * Canonicalize text for duplicate detection: lowercase it and collapse all
 * whitespace runs (including newlines) into single spaces.
 *
 * @param {string} text
 * @returns {string} Normalized text with no leading/trailing whitespace.
 */
function normalizeText(text) {
    const collapsed = text.replace(/\s+/g, ' ');
    return collapsed.toLowerCase().trim();
}
|
|
95
|
+
/**
 * Recursively collect files from a directory matching given extensions.
 *
 * Dot-entries and node_modules are skipped. If `dirPath` is itself a file,
 * it is returned (subject to the extension filter). Entries that cannot be
 * stat'ed — broken symlinks, permission errors — are skipped instead of
 * aborting the whole walk (previously an unguarded statSync threw and lost
 * everything collected so far).
 *
 * @param {string} dirPath - File or directory to scan.
 * @param {string[]} extensions - Lowercase extensions like '.js';
 *   an empty array matches every file.
 * @returns {string[]} Paths of matching files (same base form as `dirPath`).
 */
function collectFiles(dirPath, extensions) {
    const results = [];
    if (!existsSync(dirPath))
        return results;
    let stat;
    try {
        stat = statSync(dirPath);
    }
    catch {
        return results; // unreadable root — treat as empty
    }
    if (stat.isFile()) {
        const ext = extname(dirPath).toLowerCase();
        if (extensions.length === 0 || extensions.includes(ext)) {
            results.push(dirPath);
        }
        return results;
    }
    if (!stat.isDirectory())
        return results;
    for (const entry of readdirSync(dirPath)) {
        // Skip hidden entries and dependency trees.
        if (entry.startsWith('.') || entry === 'node_modules')
            continue;
        const fullPath = join(dirPath, entry);
        let entryStat;
        try {
            entryStat = statSync(fullPath);
        }
        catch {
            continue; // broken symlink / unreadable — skip, keep walking
        }
        if (entryStat.isDirectory()) {
            results.push(...collectFiles(fullPath, extensions));
        }
        else {
            const ext = extname(entry).toLowerCase();
            if (extensions.length === 0 || extensions.includes(ext)) {
                results.push(fullPath);
            }
        }
    }
    return results;
}
|
|
128
|
+
// ── Source parsing helpers ────────────────────────────────────────────
|
|
129
|
+
/**
 * Extract instruction/response pairs from a JSON conversation file.
 *
 * Accepted shapes:
 *  - a flat array of messages using {role, content} or {from, value/text};
 *  - an object wrapping a `messages` or `conversations` array (handled by
 *    re-serializing the inner array and recursing);
 *  - an array of conversation objects that each carry their own
 *    `messages`/`conversations` array.
 * Invalid JSON yields an empty list rather than throwing.
 *
 * @param {string} content - Raw JSON text.
 * @returns {{instruction: string, response: string}[]} Extracted pairs.
 */
function parseConversationJson(content) {
    const pairs = [];
    try {
        const data = JSON.parse(content);
        // Flat array of messages: pair each user/human turn with the
        // assistant/gpt turn that immediately follows it.
        if (Array.isArray(data)) {
            for (let i = 0; i < data.length - 1; i++) {
                const curr = data[i];
                const next = data[i + 1];
                if ((curr.role === 'user' || curr.from === 'human') &&
                    (next.role === 'assistant' || next.from === 'gpt' || next.from === 'assistant')) {
                    pairs.push({
                        // Message text may live under different keys depending on the format.
                        instruction: curr.content || curr.value || curr.text || '',
                        response: next.content || next.value || next.text || '',
                    });
                    i++; // skip the response message
                }
            }
        }
        // Object wrapping a messages array — recurse on the inner array.
        if (data.messages && Array.isArray(data.messages)) {
            return parseConversationJson(JSON.stringify(data.messages));
        }
        // Object wrapping a conversations array (ShareGPT-style).
        if (data.conversations && Array.isArray(data.conversations)) {
            return parseConversationJson(JSON.stringify(data.conversations));
        }
        // Array of conversation objects, each with its own messages/conversations.
        // NOTE(review): for such arrays the flat-array branch above also ran,
        // but it matches nothing because the outer items have no role/from keys.
        if (Array.isArray(data) && data.length > 0 && (data[0].messages || data[0].conversations)) {
            for (const item of data) {
                const msgs = item.messages || item.conversations;
                if (Array.isArray(msgs)) {
                    pairs.push(...parseConversationJson(JSON.stringify(msgs)));
                }
            }
        }
    }
    catch {
        // Not valid JSON — skip
    }
    return pairs;
}
|
|
172
|
+
/**
 * Extract instruction/response pairs from markdown: each heading becomes an
 * instruction and the content below it becomes the response.
 *
 * Sections whose body is 20 characters or shorter are dropped as too thin
 * to be useful training examples.
 *
 * @param {string} content - Raw markdown text.
 * @returns {{instruction: string, response: string}[]} Extracted pairs.
 */
function parseMarkdown(content) {
    // Splitting on a captured heading pattern yields an alternating list:
    // [preamble, heading, body, heading, body, ...]
    const segments = content.split(/^(#{1,6}\s+.+)$/m);
    const pairs = [];
    for (let idx = 1; idx + 1 < segments.length; idx += 2) {
        const instruction = segments[idx].replace(/^#+\s+/, '').trim();
        const response = segments[idx + 1].trim();
        if (instruction && response && response.length > 20) {
            pairs.push({ instruction, response });
        }
    }
    return pairs;
}
|
|
185
|
+
/**
 * Extract instruction/response pairs from code files by pairing doc comments
 * with the implementation that follows them.
 *
 * Language handling is keyed on file extension:
 *  - .py               — `def` signatures with triple-quoted docstrings
 *  - .ts/.tsx/.js/.jsx — JSDoc blocks before function/const declarations
 *  - .rs               — `///` doc-comment runs before `pub fn`
 *  - .go               — `//` comment runs before `func`
 *
 * NOTE(review): extraction is regex-based, not a real parser — unconventional
 * formatting (nested definitions, unusual whitespace) may be missed; this is
 * best-effort dataset mining.
 *
 * @param {string} content - Source file text.
 * @param {string} ext - Lowercased file extension including the leading dot.
 * @returns {{instruction: string, response: string}[]} Extracted pairs.
 */
function parseCodeFile(content, ext) {
    const pairs = [];
    if (['.py'].includes(ext)) {
        // Python: extract function definitions with docstrings.
        // Groups: 1 = signature, 2 = quoted docstring, 3 = body remainder.
        const funcPattern = /^(def\s+\w+\s*\([^)]*\)(?:\s*->\s*[^:]+)?)\s*:\s*\n\s*("""[\s\S]*?"""|'''[\s\S]*?''')\s*\n([\s\S]*?)(?=\ndef\s|\nclass\s|$)/gm;
        let match;
        while ((match = funcPattern.exec(content)) !== null) {
            const signature = match[1].trim();
            // Strip the surrounding triple quotes from the docstring.
            const docstring = match[2].replace(/^("""|''')|("""|''')$/g, '').trim();
            const body = match[0].trim(); // full match: signature + docstring + body
            if (docstring && body) {
                pairs.push({
                    instruction: `Write a Python function: ${signature}\n\nDescription: ${docstring}`,
                    response: body,
                });
            }
        }
    }
    if (['.ts', '.tsx', '.js', '.jsx'].includes(ext)) {
        // TypeScript/JavaScript: extract functions with JSDoc.
        // Groups: 1 = JSDoc block, 2 = implementation up to the next top-level decl.
        const funcPattern = /(\/\*\*[\s\S]*?\*\/)\s*\n\s*((?:export\s+)?(?:async\s+)?(?:function\s+\w+|(?:const|let)\s+\w+\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=]+)=>)[\s\S]*?)(?=\n\/\*\*|\nexport\s+(?:default\s+)?(?:function|const|class|interface|type)|\nclass\s|$)/gm;
        let match;
        while ((match = funcPattern.exec(content)) !== null) {
            // Strip JSDoc delimiters and leading asterisks to get plain prose.
            const jsdoc = match[1].replace(/\/\*\*|\*\/|\*\s?/g, '').trim();
            const impl = match[2].trim();
            // Require a non-trivial implementation (> 30 chars).
            if (jsdoc && impl && impl.length > 30) {
                // Extract function name from implementation
                const nameMatch = impl.match(/(?:function\s+(\w+)|(?:const|let)\s+(\w+))/);
                const funcName = nameMatch ? (nameMatch[1] || nameMatch[2]) : 'function';
                pairs.push({
                    instruction: `Write a TypeScript function \`${funcName}\`: ${jsdoc}`,
                    response: impl,
                });
            }
        }
    }
    if (['.rs'].includes(ext)) {
        // Rust: extract functions with doc comments.
        // Groups: 1 = run of /// lines, 2 = pub fn implementation.
        const funcPattern = /((?:\/\/\/.*\n)+)\s*(pub\s+(?:async\s+)?fn\s+\w+[\s\S]*?)(?=\n\/\/\/|\npub\s+(?:async\s+)?fn|\nfn\s|\nimpl\s|$)/gm;
        let match;
        while ((match = funcPattern.exec(content)) !== null) {
            const docComment = match[1].replace(/\/\/\/\s?/g, '').trim();
            const impl = match[2].trim();
            if (docComment && impl) {
                const nameMatch = impl.match(/fn\s+(\w+)/);
                const funcName = nameMatch ? nameMatch[1] : 'function';
                pairs.push({
                    instruction: `Write a Rust function \`${funcName}\`: ${docComment}`,
                    response: impl,
                });
            }
        }
    }
    if (['.go'].includes(ext)) {
        // Go: extract functions with doc comments.
        // Groups: 1 = run of // lines, 2 = func (optionally with receiver).
        const funcPattern = /((?:\/\/.*\n)+)\s*(func\s+(?:\([^)]*\)\s+)?\w+[\s\S]*?)(?=\n\/\/|\nfunc\s|$)/gm;
        let match;
        while ((match = funcPattern.exec(content)) !== null) {
            const docComment = match[1].replace(/\/\/\s?/g, '').trim();
            const impl = match[2].trim();
            if (docComment && impl) {
                const nameMatch = impl.match(/func\s+(?:\([^)]*\)\s+)?(\w+)/);
                const funcName = nameMatch ? nameMatch[1] : 'function';
                pairs.push({
                    instruction: `Write a Go function \`${funcName}\`: ${docComment}`,
                    response: impl,
                });
            }
        }
    }
    return pairs;
}
|
|
258
|
+
/**
 * Serialize one instruction/response pair into a single training example in
 * the requested format.
 *
 * @param {{instruction: string, response: string}} pair - The example pair.
 * @param {'jsonl'|'alpaca'|'sharegpt'} format - Target dataset format.
 * @param {string} [systemPrompt] - Optional system prompt prepended to chat formats.
 * @returns {string|undefined} JSON string for one example; undefined for an
 *   unrecognized format (callers validate the format beforehand).
 */
function formatExample(pair, format, systemPrompt) {
    switch (format) {
        case 'jsonl': {
            // OpenAI chat-SFT shape: {"messages": [...]}
            const messages = systemPrompt
                ? [{ role: 'system', content: systemPrompt }]
                : [];
            messages.push({ role: 'user', content: pair.instruction }, { role: 'assistant', content: pair.response });
            return JSON.stringify({ messages });
        }
        case 'alpaca':
            // Alpaca shape: instruction/input/output (input left empty).
            return JSON.stringify({ instruction: pair.instruction, input: '', output: pair.response });
        case 'sharegpt': {
            // ShareGPT shape: {"conversations": [{from, value}, ...]}
            const conversations = systemPrompt
                ? [{ from: 'system', value: systemPrompt }]
                : [];
            conversations.push({ from: 'human', value: pair.instruction }, { from: 'gpt', value: pair.response });
            return JSON.stringify({ conversations });
        }
    }
}
|
|
287
|
+
// ── BLEU score helper ────────────────────────────────────────────────
|
|
288
|
+
/**
 * Count the n-grams of a token sequence.
 *
 * @param {string[]} tokens - Token sequence.
 * @param {number} n - N-gram order (1 = unigrams, 2 = bigrams, ...).
 * @returns {Map<string, number>} Space-joined n-gram -> occurrence count.
 */
function computeNgrams(tokens, n) {
    const counts = new Map();
    const lastStart = tokens.length - n;
    for (let start = 0; start <= lastStart; start++) {
        const key = tokens.slice(start, start + n).join(' ');
        const previous = counts.get(key) ?? 0;
        counts.set(key, previous + 1);
    }
    return counts;
}
|
|
296
|
+
/**
 * Compute a sentence-level BLEU score between a reference and a candidate.
 *
 * Uses clipped n-gram precision averaged in log space (geometric mean) up to
 * `maxN`, with the standard brevity penalty for short candidates. Returns 0
 * as soon as any n-gram order has zero precision, or if either side is empty.
 *
 * @param {string} reference - Reference text.
 * @param {string} candidate - Candidate text to score.
 * @param {number} [maxN=4] - Maximum n-gram order.
 * @returns {number} Score in [0, 1].
 */
function bleuScore(reference, candidate, maxN = 4) {
    const tokenize = (text) => text.toLowerCase().split(/\s+/).filter(Boolean);
    const refTokens = tokenize(reference);
    const candTokens = tokenize(candidate);
    if (candTokens.length === 0 || refTokens.length === 0) {
        return 0;
    }
    let logPrecisionSum = 0;
    let orders = 0;
    const highestOrder = Math.min(maxN, candTokens.length);
    for (let n = 1; n <= highestOrder; n++) {
        const refNgrams = computeNgrams(refTokens, n);
        const candNgrams = computeNgrams(candTokens, n);
        let matched = 0; // clipped matches against the reference counts
        let total = 0;
        for (const [gram, occurrences] of candNgrams) {
            matched += Math.min(occurrences, refNgrams.get(gram) || 0);
            total += occurrences;
        }
        if (total === 0) {
            continue;
        }
        if (matched === 0) {
            return 0; // zero precision at any order zeroes the geometric mean
        }
        logPrecisionSum += Math.log(matched / total);
        orders++;
    }
    if (orders === 0) {
        return 0;
    }
    // Brevity penalty: punish candidates shorter than the reference.
    const brevityPenalty = candTokens.length >= refTokens.length
        ? 1
        : Math.exp(1 - refTokens.length / candTokens.length);
    return brevityPenalty * Math.exp(logPrecisionSum / orders);
}
|
|
329
|
+
// ── Model size estimation helpers ────────────────────────────────────
|
|
330
|
+
/**
 * Estimate parameter count from a model name.
 *
 * Order of resolution:
 *  1. Known model families whose names carry no usable "<N>b" size, or whose
 *     literal size would mislead (MoE models) — checked most-specific first.
 *     Previously /gpt-4/ was listed before /gpt-4.1-mini/, so the mini
 *     variant could never resolve to 8e9, and "mixtral-8x7b" hit the
 *     explicit-size path and returned 7e9 instead of the MoE total.
 *  2. An explicit size embedded in the name ("7b", "13B", "1.8b", "70 billion").
 *  3. A 7B default for unknown models.
 *
 * @param {string} modelName - Model identifier, any casing.
 * @returns {number} Estimated parameter count (e.g. 8e9).
 */
function estimateModelParams(modelName) {
    const name = modelName.toLowerCase();
    // Most-specific patterns first so prefixes don't shadow variants.
    const namedSizes = [
        [/gpt-4\.1-mini/i, 8e9],
        [/gpt-4\.1/i, 200e9],
        [/gpt-4/i, 200e9],
        [/gpt-3\.5/i, 7e9],
        [/mixtral/i, 47e9], // 8x7B MoE — total parameters, not per-expert
        [/phi.*3/i, 3.8e9],
    ];
    for (const [pattern, size] of namedSizes) {
        if (pattern.test(name))
            return size;
    }
    // Try to extract explicit size (e.g., "7b", "13B", "70 billion")
    const sizeMatch = name.match(/(\d+\.?\d*)\s*b(?:illion)?/i);
    if (sizeMatch) {
        return parseFloat(sizeMatch[1]) * 1e9;
    }
    // Default estimate for unknown models
    return 7e9;
}
|
|
369
|
+
// ── Tool Registration ────────────────────────────────────────────────
|
|
370
|
+
export function registerTrainingTools() {
|
|
371
|
+
// ── train_prepare ──────────────────────────────────────────────────
|
|
372
|
+
registerTool({
|
|
373
|
+
name: 'train_prepare',
|
|
374
|
+
description: 'Convert data into training formats for fine-tuning. Supports JSONL chat (OpenAI SFT), Alpaca, and ShareGPT formats. ' +
|
|
375
|
+
'Auto-detects source type: conversation JSON, markdown, or code files with docstrings. ' +
|
|
376
|
+
'Extracts instruction/response pairs and writes the output dataset.',
|
|
377
|
+
parameters: {
|
|
378
|
+
source: {
|
|
379
|
+
type: 'string',
|
|
380
|
+
description: 'Path to source file or directory containing training data',
|
|
381
|
+
required: true,
|
|
382
|
+
},
|
|
383
|
+
format: {
|
|
384
|
+
type: 'string',
|
|
385
|
+
description: 'Output format: jsonl (OpenAI chat), alpaca, sharegpt (default: jsonl)',
|
|
386
|
+
},
|
|
387
|
+
output: {
|
|
388
|
+
type: 'string',
|
|
389
|
+
description: 'Output file path (default: <source>_prepared.<format>)',
|
|
390
|
+
},
|
|
391
|
+
system_prompt: {
|
|
392
|
+
type: 'string',
|
|
393
|
+
description: 'System prompt to prepend to each example (optional)',
|
|
394
|
+
},
|
|
395
|
+
},
|
|
396
|
+
tier: 'pro',
|
|
397
|
+
timeout: 300_000,
|
|
398
|
+
async execute(args) {
|
|
399
|
+
try {
|
|
400
|
+
const sourcePath = resolve(String(args.source));
|
|
401
|
+
const format = (String(args.format || 'jsonl').toLowerCase());
|
|
402
|
+
const systemPrompt = args.system_prompt ? String(args.system_prompt) : undefined;
|
|
403
|
+
if (!['jsonl', 'alpaca', 'sharegpt'].includes(format)) {
|
|
404
|
+
return `Error: Invalid format "${format}". Use: jsonl, alpaca, sharegpt`;
|
|
405
|
+
}
|
|
406
|
+
if (!existsSync(sourcePath)) {
|
|
407
|
+
return `Error: Source path not found: ${sourcePath}`;
|
|
408
|
+
}
|
|
409
|
+
// Collect all source files
|
|
410
|
+
const stat = statSync(sourcePath);
|
|
411
|
+
let sourceFiles;
|
|
412
|
+
if (stat.isFile()) {
|
|
413
|
+
sourceFiles = [sourcePath];
|
|
414
|
+
}
|
|
415
|
+
else {
|
|
416
|
+
sourceFiles = collectFiles(sourcePath, [
|
|
417
|
+
'.json', '.jsonl', '.md', '.markdown',
|
|
418
|
+
'.py', '.ts', '.tsx', '.js', '.jsx', '.rs', '.go',
|
|
419
|
+
]);
|
|
420
|
+
}
|
|
421
|
+
if (sourceFiles.length === 0) {
|
|
422
|
+
return `Error: No supported source files found in ${sourcePath}\n\nSupported: .json, .jsonl, .md, .py, .ts, .js, .rs, .go`;
|
|
423
|
+
}
|
|
424
|
+
// Extract pairs from all sources
|
|
425
|
+
const allPairs = [];
|
|
426
|
+
for (const file of sourceFiles) {
|
|
427
|
+
const content = readFileSync(file, 'utf-8');
|
|
428
|
+
const ext = extname(file).toLowerCase();
|
|
429
|
+
if (ext === '.json') {
|
|
430
|
+
allPairs.push(...parseConversationJson(content));
|
|
431
|
+
}
|
|
432
|
+
else if (ext === '.jsonl') {
|
|
433
|
+
// Each line is a JSON object
|
|
434
|
+
const lines = content.split('\n').filter(l => l.trim());
|
|
435
|
+
for (const line of lines) {
|
|
436
|
+
try {
|
|
437
|
+
const obj = JSON.parse(line);
|
|
438
|
+
if (obj.messages) {
|
|
439
|
+
allPairs.push(...parseConversationJson(JSON.stringify(obj.messages)));
|
|
440
|
+
}
|
|
441
|
+
else if (obj.instruction && obj.output) {
|
|
442
|
+
allPairs.push({ instruction: obj.instruction, response: obj.output });
|
|
443
|
+
}
|
|
444
|
+
else if (obj.conversations) {
|
|
445
|
+
allPairs.push(...parseConversationJson(JSON.stringify(obj.conversations)));
|
|
446
|
+
}
|
|
447
|
+
}
|
|
448
|
+
catch { /* skip malformed lines */ }
|
|
449
|
+
}
|
|
450
|
+
}
|
|
451
|
+
else if (ext === '.md' || ext === '.markdown') {
|
|
452
|
+
allPairs.push(...parseMarkdown(content));
|
|
453
|
+
}
|
|
454
|
+
else {
|
|
455
|
+
// Code files
|
|
456
|
+
allPairs.push(...parseCodeFile(content, ext));
|
|
457
|
+
}
|
|
458
|
+
}
|
|
459
|
+
if (allPairs.length === 0) {
|
|
460
|
+
return [
|
|
461
|
+
`No instruction/response pairs extracted from ${sourceFiles.length} files.`,
|
|
462
|
+
'',
|
|
463
|
+
'Tips:',
|
|
464
|
+
' - JSON files should have messages with role/content or from/value fields',
|
|
465
|
+
' - Markdown files should have headings followed by content paragraphs',
|
|
466
|
+
' - Code files should have functions with docstrings/JSDoc/doc comments',
|
|
467
|
+
].join('\n');
|
|
468
|
+
}
|
|
469
|
+
// Format all pairs
|
|
470
|
+
const outputLines = allPairs.map(pair => formatExample(pair, format, systemPrompt));
|
|
471
|
+
// Determine output path
|
|
472
|
+
const outputExt = format === 'jsonl' ? '.jsonl' : '.json';
|
|
473
|
+
let outputPath;
|
|
474
|
+
if (args.output) {
|
|
475
|
+
outputPath = resolve(String(args.output));
|
|
476
|
+
}
|
|
477
|
+
else {
|
|
478
|
+
const srcBase = stat.isFile()
|
|
479
|
+
? sourcePath.replace(extname(sourcePath), '')
|
|
480
|
+
: sourcePath.replace(/\/$/, '');
|
|
481
|
+
outputPath = `${srcBase}_prepared${outputExt}`;
|
|
482
|
+
}
|
|
483
|
+
// Write output
|
|
484
|
+
mkdirSync(dirname(outputPath), { recursive: true });
|
|
485
|
+
if (format === 'jsonl') {
|
|
486
|
+
writeFileSync(outputPath, outputLines.join('\n') + '\n', 'utf-8');
|
|
487
|
+
}
|
|
488
|
+
else {
|
|
489
|
+
// For alpaca and sharegpt, write as a JSON array
|
|
490
|
+
const parsed = outputLines.map(l => JSON.parse(l));
|
|
491
|
+
writeFileSync(outputPath, JSON.stringify(parsed, null, 2) + '\n', 'utf-8');
|
|
492
|
+
}
|
|
493
|
+
// Calculate stats
|
|
494
|
+
const totalChars = allPairs.reduce((sum, p) => sum + p.instruction.length + p.response.length, 0);
|
|
495
|
+
const estimatedTokens = estimateTokens(allPairs.reduce((s, p) => s + p.instruction + p.response, ''));
|
|
496
|
+
return [
|
|
497
|
+
`Training data prepared successfully.`,
|
|
498
|
+
'',
|
|
499
|
+
` Format: ${format}`,
|
|
500
|
+
` Source: ${sourceFiles.length} file(s)`,
|
|
501
|
+
` Examples: ${allPairs.length}`,
|
|
502
|
+
` Est. tokens: ~${estimatedTokens.toLocaleString()} (~${totalChars.toLocaleString()} chars)`,
|
|
503
|
+
` Output: ${outputPath}`,
|
|
504
|
+
'',
|
|
505
|
+
`File size: ${(statSync(outputPath).size / 1024).toFixed(1)} KB`,
|
|
506
|
+
].join('\n');
|
|
507
|
+
}
|
|
508
|
+
catch (err) {
|
|
509
|
+
return `Error preparing training data: ${err instanceof Error ? err.message : String(err)}`;
|
|
510
|
+
}
|
|
511
|
+
},
|
|
512
|
+
});
|
|
513
|
+
// ── train_validate ─────────────────────────────────────────────────
|
|
514
|
+
registerTool({
|
|
515
|
+
name: 'train_validate',
|
|
516
|
+
description: 'Validate a training dataset before starting a fine-tuning job. Checks format correctness, ' +
|
|
517
|
+
'duplicates, token length distribution, empty examples, and system prompt consistency. ' +
|
|
518
|
+
'Reports warnings and errors with line numbers.',
|
|
519
|
+
parameters: {
|
|
520
|
+
dataset: {
|
|
521
|
+
type: 'string',
|
|
522
|
+
description: 'Path to the training dataset file',
|
|
523
|
+
required: true,
|
|
524
|
+
},
|
|
525
|
+
format: {
|
|
526
|
+
type: 'string',
|
|
527
|
+
description: 'Dataset format: jsonl, alpaca, sharegpt (default: jsonl)',
|
|
528
|
+
},
|
|
529
|
+
},
|
|
530
|
+
tier: 'pro',
|
|
531
|
+
timeout: 120_000,
|
|
532
|
+
async execute(args) {
|
|
533
|
+
try {
|
|
534
|
+
const datasetPath = resolve(String(args.dataset));
|
|
535
|
+
const format = (String(args.format || 'jsonl').toLowerCase());
|
|
536
|
+
if (!existsSync(datasetPath)) {
|
|
537
|
+
return `Error: Dataset not found: ${datasetPath}`;
|
|
538
|
+
}
|
|
539
|
+
const content = readFileSync(datasetPath, 'utf-8');
|
|
540
|
+
const issues = [];
|
|
541
|
+
const tokenCounts = [];
|
|
542
|
+
const normalizedTexts = new Set();
|
|
543
|
+
const systemPrompts = new Map();
|
|
544
|
+
let totalExamples = 0;
|
|
545
|
+
let emptyCount = 0;
|
|
546
|
+
let duplicateCount = 0;
|
|
547
|
+
const lines = format === 'jsonl'
|
|
548
|
+
? content.split('\n').filter(l => l.trim())
|
|
549
|
+
: (() => {
|
|
550
|
+
try {
|
|
551
|
+
const parsed = JSON.parse(content);
|
|
552
|
+
return Array.isArray(parsed) ? parsed.map((item) => JSON.stringify(item)) : [content];
|
|
553
|
+
}
|
|
554
|
+
catch {
|
|
555
|
+
return content.split('\n').filter(l => l.trim());
|
|
556
|
+
}
|
|
557
|
+
})();
|
|
558
|
+
for (let i = 0; i < lines.length; i++) {
|
|
559
|
+
const lineNum = i + 1;
|
|
560
|
+
const line = typeof lines[i] === 'string' ? lines[i] : String(lines[i]);
|
|
561
|
+
// Parse JSON
|
|
562
|
+
let obj;
|
|
563
|
+
try {
|
|
564
|
+
obj = JSON.parse(line);
|
|
565
|
+
}
|
|
566
|
+
catch {
|
|
567
|
+
issues.push({ line: lineNum, severity: 'error', message: 'Invalid JSON' });
|
|
568
|
+
continue;
|
|
569
|
+
}
|
|
570
|
+
totalExamples++;
|
|
571
|
+
// Format-specific validation
|
|
572
|
+
if (format === 'jsonl') {
|
|
573
|
+
if (!obj.messages || !Array.isArray(obj.messages)) {
|
|
574
|
+
issues.push({ line: lineNum, severity: 'error', message: 'Missing "messages" array' });
|
|
575
|
+
continue;
|
|
576
|
+
}
|
|
577
|
+
const msgs = obj.messages;
|
|
578
|
+
if (msgs.length < 2) {
|
|
579
|
+
issues.push({ line: lineNum, severity: 'error', message: `Only ${msgs.length} message(s) — need at least user + assistant` });
|
|
580
|
+
continue;
|
|
581
|
+
}
|
|
582
|
+
// Check roles
|
|
583
|
+
const hasUser = msgs.some(m => m.role === 'user');
|
|
584
|
+
const hasAssistant = msgs.some(m => m.role === 'assistant');
|
|
585
|
+
if (!hasUser)
|
|
586
|
+
issues.push({ line: lineNum, severity: 'error', message: 'No "user" role message found' });
|
|
587
|
+
if (!hasAssistant)
|
|
588
|
+
issues.push({ line: lineNum, severity: 'error', message: 'No "assistant" role message found' });
|
|
589
|
+
for (let j = 0; j < msgs.length; j++) {
|
|
590
|
+
const m = msgs[j];
|
|
591
|
+
if (!m.role || !['system', 'user', 'assistant'].includes(m.role)) {
|
|
592
|
+
issues.push({ line: lineNum, severity: 'error', message: `Message ${j}: invalid role "${m.role}"` });
|
|
593
|
+
}
|
|
594
|
+
if (!m.content || typeof m.content !== 'string') {
|
|
595
|
+
issues.push({ line: lineNum, severity: 'error', message: `Message ${j}: missing or non-string content` });
|
|
596
|
+
}
|
|
597
|
+
}
|
|
598
|
+
// Track system prompts
|
|
599
|
+
const sysMsg = msgs.find(m => m.role === 'system');
|
|
600
|
+
if (sysMsg?.content) {
|
|
601
|
+
const key = normalizeText(sysMsg.content);
|
|
602
|
+
systemPrompts.set(key, (systemPrompts.get(key) || 0) + 1);
|
|
603
|
+
}
|
|
604
|
+
// Token count
|
|
605
|
+
const fullText = msgs.map(m => m.content || '').join(' ');
|
|
606
|
+
const tokens = estimateTokens(fullText);
|
|
607
|
+
tokenCounts.push(tokens);
|
|
608
|
+
// Empty check
|
|
609
|
+
const userMsg = msgs.find(m => m.role === 'user');
|
|
610
|
+
const assistantMsg = msgs.find(m => m.role === 'assistant');
|
|
611
|
+
if ((userMsg?.content?.length || 0) < 5 || (assistantMsg?.content?.length || 0) < 5) {
|
|
612
|
+
issues.push({ line: lineNum, severity: 'warning', message: 'Very short user or assistant message (< 5 chars)' });
|
|
613
|
+
emptyCount++;
|
|
614
|
+
}
|
|
615
|
+
// Duplicate check
|
|
616
|
+
const normalized = normalizeText(fullText);
|
|
617
|
+
if (normalizedTexts.has(normalized)) {
|
|
618
|
+
issues.push({ line: lineNum, severity: 'warning', message: 'Duplicate or near-duplicate example' });
|
|
619
|
+
duplicateCount++;
|
|
620
|
+
}
|
|
621
|
+
normalizedTexts.add(normalized);
|
|
622
|
+
}
|
|
623
|
+
else if (format === 'alpaca') {
|
|
624
|
+
if (!obj.instruction || typeof obj.instruction !== 'string') {
|
|
625
|
+
issues.push({ line: lineNum, severity: 'error', message: 'Missing or non-string "instruction" field' });
|
|
626
|
+
}
|
|
627
|
+
if (!obj.output || typeof obj.output !== 'string') {
|
|
628
|
+
issues.push({ line: lineNum, severity: 'error', message: 'Missing or non-string "output" field' });
|
|
629
|
+
}
|
|
630
|
+
if (obj.instruction && obj.instruction.length < 5) {
|
|
631
|
+
issues.push({ line: lineNum, severity: 'warning', message: 'Very short instruction (< 5 chars)' });
|
|
632
|
+
emptyCount++;
|
|
633
|
+
}
|
|
634
|
+
if (obj.output && obj.output.length < 5) {
|
|
635
|
+
issues.push({ line: lineNum, severity: 'warning', message: 'Very short output (< 5 chars)' });
|
|
636
|
+
}
|
|
637
|
+
const fullText = `${obj.instruction || ''} ${obj.input || ''} ${obj.output || ''}`;
|
|
638
|
+
tokenCounts.push(estimateTokens(fullText));
|
|
639
|
+
const normalized = normalizeText(fullText);
|
|
640
|
+
if (normalizedTexts.has(normalized)) {
|
|
641
|
+
issues.push({ line: lineNum, severity: 'warning', message: 'Duplicate or near-duplicate example' });
|
|
642
|
+
duplicateCount++;
|
|
643
|
+
}
|
|
644
|
+
normalizedTexts.add(normalized);
|
|
645
|
+
}
|
|
646
|
+
else if (format === 'sharegpt') {
|
|
647
|
+
if (!obj.conversations || !Array.isArray(obj.conversations)) {
|
|
648
|
+
issues.push({ line: lineNum, severity: 'error', message: 'Missing "conversations" array' });
|
|
649
|
+
continue;
|
|
650
|
+
}
|
|
651
|
+
const convs = obj.conversations;
|
|
652
|
+
if (convs.length < 2) {
|
|
653
|
+
issues.push({ line: lineNum, severity: 'error', message: `Only ${convs.length} turn(s) — need at least human + gpt` });
|
|
654
|
+
}
|
|
655
|
+
const hasHuman = convs.some(c => c.from === 'human');
|
|
656
|
+
const hasGpt = convs.some(c => c.from === 'gpt');
|
|
657
|
+
if (!hasHuman)
|
|
658
|
+
issues.push({ line: lineNum, severity: 'error', message: 'No "human" turn found' });
|
|
659
|
+
if (!hasGpt)
|
|
660
|
+
issues.push({ line: lineNum, severity: 'error', message: 'No "gpt" turn found' });
|
|
661
|
+
for (let j = 0; j < convs.length; j++) {
|
|
662
|
+
if (!convs[j].from || !['human', 'gpt', 'system'].includes(convs[j].from)) {
|
|
663
|
+
issues.push({ line: lineNum, severity: 'error', message: `Turn ${j}: invalid "from" value "${convs[j].from}"` });
|
|
664
|
+
}
|
|
665
|
+
if (!convs[j].value || typeof convs[j].value !== 'string') {
|
|
666
|
+
issues.push({ line: lineNum, severity: 'error', message: `Turn ${j}: missing or non-string "value"` });
|
|
667
|
+
}
|
|
668
|
+
}
|
|
669
|
+
// Track system prompts
|
|
670
|
+
const sysTurn = convs.find(c => c.from === 'system');
|
|
671
|
+
if (sysTurn?.value) {
|
|
672
|
+
const key = normalizeText(sysTurn.value);
|
|
673
|
+
systemPrompts.set(key, (systemPrompts.get(key) || 0) + 1);
|
|
674
|
+
}
|
|
675
|
+
const fullText = convs.map(c => c.value || '').join(' ');
|
|
676
|
+
tokenCounts.push(estimateTokens(fullText));
|
|
677
|
+
const normalized = normalizeText(fullText);
|
|
678
|
+
if (normalizedTexts.has(normalized)) {
|
|
679
|
+
issues.push({ line: lineNum, severity: 'warning', message: 'Duplicate or near-duplicate example' });
|
|
680
|
+
duplicateCount++;
|
|
681
|
+
}
|
|
682
|
+
normalizedTexts.add(normalized);
|
|
683
|
+
}
|
|
684
|
+
}
|
|
685
|
+
// Compute token stats
|
|
686
|
+
const sortedTokens = [...tokenCounts].sort((a, b) => a - b);
|
|
687
|
+
const minTokens = sortedTokens[0] || 0;
|
|
688
|
+
const maxTokens = sortedTokens[sortedTokens.length - 1] || 0;
|
|
689
|
+
const meanTokens = tokenCounts.length > 0
|
|
690
|
+
? Math.round(tokenCounts.reduce((a, b) => a + b, 0) / tokenCounts.length)
|
|
691
|
+
: 0;
|
|
692
|
+
const p95Index = Math.floor(tokenCounts.length * 0.95);
|
|
693
|
+
const p95Tokens = sortedTokens[p95Index] || maxTokens;
|
|
694
|
+
const errors = issues.filter(i => i.severity === 'error');
|
|
695
|
+
const warnings = issues.filter(i => i.severity === 'warning');
|
|
696
|
+
// System prompt consistency
|
|
697
|
+
let systemPromptNote = '';
|
|
698
|
+
if (systemPrompts.size > 1) {
|
|
699
|
+
systemPromptNote = `\n System prompts: ${systemPrompts.size} unique variants (inconsistent — consider standardizing)`;
|
|
700
|
+
}
|
|
701
|
+
else if (systemPrompts.size === 1) {
|
|
702
|
+
const [prompt, count] = [...systemPrompts.entries()][0];
|
|
703
|
+
systemPromptNote = `\n System prompt: consistent across ${count} examples ("${prompt.slice(0, 60)}${prompt.length > 60 ? '...' : ''}")`;
|
|
704
|
+
}
|
|
705
|
+
else {
|
|
706
|
+
systemPromptNote = '\n System prompt: none found';
|
|
707
|
+
}
|
|
708
|
+
// Build report
|
|
709
|
+
const report = [
|
|
710
|
+
`Dataset Validation Report`,
|
|
711
|
+
`${'='.repeat(50)}`,
|
|
712
|
+
'',
|
|
713
|
+
` File: ${datasetPath}`,
|
|
714
|
+
` Format: ${format}`,
|
|
715
|
+
` Examples: ${totalExamples}`,
|
|
716
|
+
` Errors: ${errors.length}`,
|
|
717
|
+
` Warnings: ${warnings.length}`,
|
|
718
|
+
` Duplicates: ${duplicateCount}`,
|
|
719
|
+
` Short/empty: ${emptyCount}`,
|
|
720
|
+
systemPromptNote,
|
|
721
|
+
'',
|
|
722
|
+
`Token Distribution`,
|
|
723
|
+
`${'─'.repeat(30)}`,
|
|
724
|
+
` Min: ${minTokens}`,
|
|
725
|
+
` Max: ${maxTokens}`,
|
|
726
|
+
` Mean: ${meanTokens}`,
|
|
727
|
+
` P95: ${p95Tokens}`,
|
|
728
|
+
` Total: ~${tokenCounts.reduce((a, b) => a + b, 0).toLocaleString()} tokens`,
|
|
729
|
+
];
|
|
730
|
+
if (errors.length > 0) {
|
|
731
|
+
report.push('', `Errors (${errors.length})`, '─'.repeat(30));
|
|
732
|
+
for (const issue of errors.slice(0, 25)) {
|
|
733
|
+
report.push(` Line ${issue.line}: ${issue.message}`);
|
|
734
|
+
}
|
|
735
|
+
if (errors.length > 25) {
|
|
736
|
+
report.push(` ... and ${errors.length - 25} more errors`);
|
|
737
|
+
}
|
|
738
|
+
}
|
|
739
|
+
if (warnings.length > 0) {
|
|
740
|
+
report.push('', `Warnings (${warnings.length})`, '─'.repeat(30));
|
|
741
|
+
for (const issue of warnings.slice(0, 15)) {
|
|
742
|
+
report.push(` Line ${issue.line}: ${issue.message}`);
|
|
743
|
+
}
|
|
744
|
+
if (warnings.length > 15) {
|
|
745
|
+
report.push(` ... and ${warnings.length - 15} more warnings`);
|
|
746
|
+
}
|
|
747
|
+
}
|
|
748
|
+
// Verdict
|
|
749
|
+
report.push('');
|
|
750
|
+
if (errors.length === 0 && warnings.length === 0) {
|
|
751
|
+
report.push('Verdict: PASS — Dataset is clean and ready for training.');
|
|
752
|
+
}
|
|
753
|
+
else if (errors.length === 0) {
|
|
754
|
+
report.push(`Verdict: PASS with warnings — ${warnings.length} warning(s) found but no blocking errors.`);
|
|
755
|
+
}
|
|
756
|
+
else {
|
|
757
|
+
report.push(`Verdict: FAIL — ${errors.length} error(s) must be fixed before training.`);
|
|
758
|
+
}
|
|
759
|
+
return report.join('\n');
|
|
760
|
+
}
|
|
761
|
+
catch (err) {
|
|
762
|
+
return `Validation error: ${err instanceof Error ? err.message : String(err)}`;
|
|
763
|
+
}
|
|
764
|
+
},
|
|
765
|
+
});
|
|
766
|
+
// ── train_start ────────────────────────────────────────────────────
//
// Launch a fine-tuning job.
//  - Cloud backends (openai / together / mistral): upload the dataset via the
//    provider's files API, create the job, and return its ID for polling.
//  - Local backends (mlx / unsloth / llama-cpp): verify the tool is installed,
//    write a launch script + log file into the output directory, and spawn it
//    detached so training survives this process exiting.
registerTool({
    name: 'train_start',
    description: 'Launch a fine-tuning job. Supports cloud backends (OpenAI, Together AI, Mistral) and ' +
        'local backends (MLX on Apple Silicon, Unsloth, llama.cpp). For cloud: uploads dataset and creates job. ' +
        'For local: detects tool installation and launches the training process.',
    parameters: {
        dataset: {
            type: 'string',
            description: 'Path to training dataset file',
            required: true,
        },
        backend: {
            type: 'string',
            description: 'Training backend: openai, together, mistral, mlx, unsloth, llama-cpp',
            required: true,
        },
        base_model: {
            type: 'string',
            description: 'Base model to fine-tune (e.g., gpt-4.1-mini, meta-llama/Llama-3-8B, mlx-community/Llama-3-8B-4bit)',
            required: true,
        },
        output: {
            type: 'string',
            description: 'Output path for local training (default: ./output/<model>-ft)',
        },
        epochs: {
            type: 'number',
            description: 'Number of training epochs (default: 3)',
        },
        learning_rate: {
            type: 'number',
            description: 'Learning rate (default: 1e-4 for local, auto for cloud)',
        },
        batch_size: {
            type: 'number',
            description: 'Batch size (default: 4)',
        },
        lora_rank: {
            type: 'number',
            description: 'LoRA rank for local training (default: 16)',
        },
        lora_alpha: {
            type: 'number',
            description: 'LoRA alpha for local training (default: 32)',
        },
        api_key: {
            type: 'string',
            description: 'API key for cloud backends (optional — reads from env or ~/.kbot/config.json)',
        },
    },
    tier: 'pro',
    timeout: 600_000,
    /**
     * Validate arguments, then dispatch to the selected backend.
     * Returns a human-readable status/summary string; never throws
     * (all failures are reported as "Error: ..." strings).
     */
    async execute(args) {
        try {
            // POSIX single-quote escaping for values interpolated into generated
            // shell scripts/commands — paths and model IDs may contain spaces or
            // shell metacharacters, and unquoted interpolation would break (or
            // allow injection into) the launch script.
            const shq = (s) => `'${String(s).replace(/'/g, `'\\''`)}'`;
            const datasetPath = resolve(String(args.dataset));
            const backend = String(args.backend).toLowerCase();
            const baseModel = String(args.base_model);
            // Numeric hyperparameters with documented defaults.
            const epochs = typeof args.epochs === 'number' ? args.epochs : 3;
            const batchSize = typeof args.batch_size === 'number' ? args.batch_size : 4;
            const loraRank = typeof args.lora_rank === 'number' ? args.lora_rank : 16;
            const loraAlpha = typeof args.lora_alpha === 'number' ? args.lora_alpha : 32;
            const learningRate = typeof args.learning_rate === 'number' ? args.learning_rate : 1e-4;
            if (!existsSync(datasetPath)) {
                return `Error: Dataset not found: ${datasetPath}`;
            }
            const validBackends = ['openai', 'together', 'mistral', 'mlx', 'unsloth', 'llama-cpp'];
            if (!validBackends.includes(backend)) {
                return `Error: Invalid backend "${backend}". Supported: ${validBackends.join(', ')}`;
            }
            // Determine output directory for local backends: sanitize the model
            // name into a filesystem-safe slug unless the caller supplied one.
            const modelSlug = baseModel.replace(/[^a-zA-Z0-9-]/g, '-').replace(/-+/g, '-').slice(0, 50);
            const outputDir = args.output
                ? resolve(String(args.output))
                : resolve(process.cwd(), 'output', `${modelSlug}-ft`);
            // ── Cloud backends ───────────────────────────────────────────
            if (backend === 'openai' || backend === 'together' || backend === 'mistral') {
                const apiKey = getApiKey(backend, args.api_key ? String(args.api_key) : undefined);
                if (!apiKey) {
                    const envVarHint = backend === 'openai' ? 'OPENAI_API_KEY'
                        : backend === 'together' ? 'TOGETHER_API_KEY'
                            : 'MISTRAL_API_KEY';
                    return `Error: No API key found for ${backend}. Set ${envVarHint} env var, add to ~/.kbot/config.json, or pass api_key parameter.`;
                }
                // All three providers require the dataset to be uploaded first.
                const datasetContent = readFileSync(datasetPath, 'utf-8');
                if (backend === 'openai') {
                    // Step 1: Upload file
                    const formData = new FormData();
                    formData.append('purpose', 'fine-tune');
                    formData.append('file', new Blob([datasetContent], { type: 'application/jsonl' }), basename(datasetPath));
                    const uploadRes = await fetch('https://api.openai.com/v1/files', {
                        method: 'POST',
                        headers: { 'Authorization': `Bearer ${apiKey}` },
                        body: formData,
                    });
                    if (!uploadRes.ok) {
                        const errBody = await uploadRes.text();
                        return `Error uploading file to OpenAI: ${uploadRes.status} ${errBody}`;
                    }
                    const uploadData = await uploadRes.json();
                    const fileId = uploadData.id;
                    // Step 2: Create fine-tuning job
                    const jobBody = {
                        training_file: fileId,
                        model: baseModel,
                        hyperparameters: {
                            n_epochs: epochs,
                            batch_size: batchSize,
                        },
                    };
                    // Only send a learning-rate multiplier if the caller set one;
                    // otherwise let OpenAI auto-select.
                    if (args.learning_rate !== undefined) {
                        jobBody.hyperparameters.learning_rate_multiplier = learningRate;
                    }
                    const jobRes = await fetch('https://api.openai.com/v1/fine_tuning/jobs', {
                        method: 'POST',
                        headers: {
                            'Authorization': `Bearer ${apiKey}`,
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify(jobBody),
                    });
                    if (!jobRes.ok) {
                        const errBody = await jobRes.text();
                        return `Error creating OpenAI fine-tuning job: ${jobRes.status} ${errBody}`;
                    }
                    const jobData = await jobRes.json();
                    return [
                        `OpenAI fine-tuning job created.`,
                        '',
                        `  Job ID: ${jobData.id}`,
                        `  Base model: ${baseModel}`,
                        `  Status: ${jobData.status}`,
                        `  File ID: ${fileId}`,
                        `  Epochs: ${epochs}`,
                        `  Batch size: ${batchSize}`,
                        '',
                        `Check status with: train_status --job_id ${jobData.id} --backend openai`,
                    ].join('\n');
                }
                else if (backend === 'together') {
                    // Together AI fine-tuning: upload the file, then create the
                    // job referencing the returned file ID.
                    const formData = new FormData();
                    formData.append('file', new Blob([datasetContent], { type: 'application/jsonl' }), basename(datasetPath));
                    formData.append('purpose', 'fine-tune');
                    const uploadRes = await fetch('https://api.together.xyz/v1/files', {
                        method: 'POST',
                        headers: { 'Authorization': `Bearer ${apiKey}` },
                        body: formData,
                    });
                    if (!uploadRes.ok) {
                        const errBody = await uploadRes.text();
                        return `Error uploading file to Together AI: ${uploadRes.status} ${errBody}`;
                    }
                    const uploadData = await uploadRes.json();
                    // Build the job body with the uploaded file ID directly
                    // (previously a local path placeholder was set and then
                    // overridden — the API only accepts uploaded file IDs).
                    const jobBody = {
                        training_file: uploadData.id,
                        model: baseModel,
                        n_epochs: epochs,
                        learning_rate: learningRate,
                        batch_size: batchSize,
                    };
                    const jobRes = await fetch('https://api.together.xyz/v1/fine-tunes', {
                        method: 'POST',
                        headers: {
                            'Authorization': `Bearer ${apiKey}`,
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify(jobBody),
                    });
                    if (!jobRes.ok) {
                        const errBody = await jobRes.text();
                        return `Error creating Together AI fine-tuning job: ${jobRes.status} ${errBody}`;
                    }
                    const jobData = await jobRes.json();
                    return [
                        `Together AI fine-tuning job created.`,
                        '',
                        `  Job ID: ${jobData.id}`,
                        `  Base model: ${baseModel}`,
                        `  Status: ${jobData.status}`,
                        `  Epochs: ${epochs}`,
                        `  Learning rate: ${learningRate}`,
                        `  Batch size: ${batchSize}`,
                        '',
                        `Check status with: train_status --job_id ${jobData.id} --backend together`,
                    ].join('\n');
                }
                else if (backend === 'mistral') {
                    // Mistral fine-tuning
                    const formData = new FormData();
                    formData.append('file', new Blob([datasetContent], { type: 'application/jsonl' }), basename(datasetPath));
                    formData.append('purpose', 'fine-tune');
                    const uploadRes = await fetch('https://api.mistral.ai/v1/files', {
                        method: 'POST',
                        headers: { 'Authorization': `Bearer ${apiKey}` },
                        body: formData,
                    });
                    if (!uploadRes.ok) {
                        const errBody = await uploadRes.text();
                        return `Error uploading file to Mistral: ${uploadRes.status} ${errBody}`;
                    }
                    const uploadData = await uploadRes.json();
                    const jobRes = await fetch('https://api.mistral.ai/v1/fine_tuning/jobs', {
                        method: 'POST',
                        headers: {
                            'Authorization': `Bearer ${apiKey}`,
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify({
                            model: baseModel,
                            training_files: [{ file_id: uploadData.id, weight: 1 }],
                            hyperparameters: {
                                training_steps: epochs * 100, // Mistral uses steps
                                learning_rate: learningRate,
                            },
                        }),
                    });
                    if (!jobRes.ok) {
                        const errBody = await jobRes.text();
                        return `Error creating Mistral fine-tuning job: ${jobRes.status} ${errBody}`;
                    }
                    const jobData = await jobRes.json();
                    return [
                        `Mistral fine-tuning job created.`,
                        '',
                        `  Job ID: ${jobData.id}`,
                        `  Base model: ${baseModel}`,
                        `  Status: ${jobData.status}`,
                        `  Epochs: ${epochs}`,
                        `  Learning rate: ${learningRate}`,
                        '',
                        `Check status with: train_status --job_id ${jobData.id} --backend mistral`,
                    ].join('\n');
                }
            }
            // ── Local backends ───────────────────────────────────────────
            mkdirSync(outputDir, { recursive: true });
            if (backend === 'mlx') {
                // Check if mlx_lm is available
                const hasMlx = shellSafe('python3 -c "import mlx_lm; print(mlx_lm.__version__)"');
                if (!hasMlx.ok) {
                    return [
                        'MLX LM is not installed. Install it with:',
                        '',
                        '  pip install mlx-lm',
                        '',
                        'Requirements:',
                        '  - Apple Silicon Mac (M1/M2/M3/M4)',
                        '  - macOS 14+ (Sonoma)',
                        '  - Python 3.10+',
                    ].join('\n');
                }
                const iters = epochs * 100; // Rough: iters = epochs * (dataset_size / batch_size)
                const adapterPath = join(outputDir, 'adapters');
                mkdirSync(adapterPath, { recursive: true });
                // NOTE(review): loraRank is passed as --lora-layers here; mlx_lm's
                // --lora-layers is the number of layers to adapt, not the LoRA
                // rank — confirm against mlx_lm.lora CLI docs.
                const cmd = [
                    'python3 -m mlx_lm.lora',
                    `--model ${shq(baseModel)}`,
                    `--data ${shq(datasetPath)}`,
                    '--train',
                    `--iters ${iters}`,
                    `--batch-size ${batchSize}`,
                    `--lora-layers ${loraRank}`,
                    `--adapter-path ${shq(adapterPath)}`,
                ].join(' ');
                // Write the command to a script file for background execution
                const scriptPath = join(outputDir, 'train.sh');
                const logPath = join(outputDir, 'train.log');
                writeFileSync(scriptPath, [
                    '#!/bin/bash',
                    `echo "Training started at $(date)" > ${shq(logPath)}`,
                    // printf keeps the logged command intact even though it
                    // contains quotes from shq().
                    `printf 'Command: %s\\n' ${shq(cmd)} >> ${shq(logPath)}`,
                    `${cmd} 2>&1 | tee -a ${shq(logPath)}`,
                    `echo "Training finished at $(date)" >> ${shq(logPath)}`,
                ].join('\n'), 'utf-8');
                shell(`chmod +x ${shq(scriptPath)}`);
                // Launch in background, fully detached from this process.
                const child = spawn('bash', [scriptPath], {
                    detached: true,
                    stdio: ['ignore', 'ignore', 'ignore'],
                });
                child.unref();
                return [
                    `MLX LoRA training launched in background.`,
                    '',
                    `  Base model: ${baseModel}`,
                    `  Dataset: ${datasetPath}`,
                    `  Iterations: ${iters}`,
                    `  Batch size: ${batchSize}`,
                    `  LoRA layers: ${loraRank}`,
                    `  Adapter path: ${adapterPath}`,
                    `  Log file: ${logPath}`,
                    `  PID: ${child.pid}`,
                    '',
                    `Monitor progress: train_status --backend mlx --log_path ${logPath}`,
                    `When done, merge adapter: train_export --model_path ${baseModel} --operation merge_lora --base_model ${baseModel}`,
                ].join('\n');
            }
            else if (backend === 'unsloth') {
                // Check if unsloth is available
                const hasUnsloth = shellSafe('python3 -c "import unsloth; print(unsloth.__version__)"');
                if (!hasUnsloth.ok) {
                    return [
                        'Unsloth is not installed. Install it with:',
                        '',
                        '  pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"',
                        '  pip install --no-deps "trl<0.9.0" peft accelerate bitsandbytes',
                        '',
                        'Requirements:',
                        '  - NVIDIA GPU with CUDA support',
                        '  - Python 3.10+',
                        '  - PyTorch 2.0+',
                    ].join('\n');
                }
                // Generate a Python training script. Config values are embedded
                // via JSON.stringify so quotes/backslashes in paths or model IDs
                // cannot break the generated Python string literals.
                const scriptContent = `#!/usr/bin/env python3
"""Unsloth fine-tuning script generated by K:BOT"""
import json
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset

# Configuration
BASE_MODEL = ${JSON.stringify(baseModel)}
DATASET_PATH = ${JSON.stringify(datasetPath)}
OUTPUT_DIR = ${JSON.stringify(outputDir)}
EPOCHS = ${epochs}
BATCH_SIZE = ${batchSize}
LEARNING_RATE = ${learningRate}
LORA_RANK = ${loraRank}
LORA_ALPHA = ${loraAlpha}
MAX_SEQ_LENGTH = 2048

print(f"Loading model: {BASE_MODEL}")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL,
    max_seq_length=MAX_SEQ_LENGTH,
    dtype=None,  # auto-detect
    load_in_4bit=True,
)

print("Applying LoRA adapters...")
model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_RANK,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=LORA_ALPHA,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=42,
)

print(f"Loading dataset: {DATASET_PATH}")
dataset = load_dataset("json", data_files=DATASET_PATH, split="train")

def formatting_func(examples):
    texts = []
    for msgs in examples.get("messages", [[]]):
        text_parts = []
        for msg in msgs:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                text_parts.append(f"### System:\\n{content}")
            elif role == "user":
                text_parts.append(f"### User:\\n{content}")
            elif role == "assistant":
                text_parts.append(f"### Assistant:\\n{content}")
        texts.append("\\n\\n".join(text_parts))
    return {"text": texts}

dataset = dataset.map(formatting_func, batched=True, remove_columns=dataset.column_names)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=MAX_SEQ_LENGTH,
    dataset_num_proc=2,
    packing=True,
    args=TrainingArguments(
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=4,
        warmup_steps=10,
        num_train_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=42,
        output_dir=OUTPUT_DIR,
        save_strategy="epoch",
    ),
)

print("Starting training...")
trainer_stats = trainer.train()
print(f"Training complete. Loss: {trainer_stats.training_loss:.4f}")

print(f"Saving model to {OUTPUT_DIR}")
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Done!")
`;
                const scriptPath = join(outputDir, 'train_unsloth.py');
                const logPath = join(outputDir, 'train.log');
                writeFileSync(scriptPath, scriptContent, 'utf-8');
                // Launch in background via a small bash wrapper that captures
                // stdout/stderr into the log file.
                const launchScript = join(outputDir, 'train.sh');
                writeFileSync(launchScript, [
                    '#!/bin/bash',
                    `echo "Training started at $(date)" > ${shq(logPath)}`,
                    `python3 ${shq(scriptPath)} 2>&1 | tee -a ${shq(logPath)}`,
                    `echo "Training finished at $(date)" >> ${shq(logPath)}`,
                ].join('\n'), 'utf-8');
                shell(`chmod +x ${shq(launchScript)}`);
                const child = spawn('bash', [launchScript], {
                    detached: true,
                    stdio: ['ignore', 'ignore', 'ignore'],
                });
                child.unref();
                return [
                    `Unsloth fine-tuning launched in background.`,
                    '',
                    `  Base model: ${baseModel}`,
                    `  Dataset: ${datasetPath}`,
                    `  Epochs: ${epochs}`,
                    `  Batch size: ${batchSize}`,
                    `  Learning rate: ${learningRate}`,
                    `  LoRA rank: ${loraRank}`,
                    `  LoRA alpha: ${loraAlpha}`,
                    `  Output: ${outputDir}`,
                    `  Log file: ${logPath}`,
                    `  Training script: ${scriptPath}`,
                    `  PID: ${child.pid}`,
                    '',
                    `Monitor progress: train_status --backend unsloth --log_path ${logPath}`,
                ].join('\n');
            }
            else if (backend === 'llama-cpp') {
                // Check if llama-finetune is available
                if (!isCommandAvailable('llama-finetune')) {
                    return [
                        'llama-finetune is not installed. Build it from llama.cpp:',
                        '',
                        '  git clone https://github.com/ggerganov/llama.cpp',
                        '  cd llama.cpp',
                        '  make llama-finetune',
                        '',
                        'Then add the build directory to your PATH.',
                        '',
                        'Alternative: use the MLX backend on Apple Silicon (faster, easier setup).',
                    ].join('\n');
                }
                // Leave two cores free for the rest of the system.
                const threads = Math.max(1, cpus().length - 2);
                const loraOutPath = join(outputDir, 'lora-adapter.bin');
                const logPath = join(outputDir, 'train.log');
                const cmd = [
                    'llama-finetune',
                    `--model-base ${shq(baseModel)}`,
                    `--lora-out ${shq(loraOutPath)}`,
                    `--train-data ${shq(datasetPath)}`,
                    `--threads ${threads}`,
                    `--epochs ${epochs}`,
                    `--batch ${batchSize}`,
                    `--lora-r ${loraRank}`,
                    `--lora-alpha ${loraAlpha}`,
                ].join(' ');
                const scriptPath = join(outputDir, 'train.sh');
                writeFileSync(scriptPath, [
                    '#!/bin/bash',
                    `echo "Training started at $(date)" > ${shq(logPath)}`,
                    `printf 'Command: %s\\n' ${shq(cmd)} >> ${shq(logPath)}`,
                    `${cmd} 2>&1 | tee -a ${shq(logPath)}`,
                    `echo "Training finished at $(date)" >> ${shq(logPath)}`,
                ].join('\n'), 'utf-8');
                shell(`chmod +x ${shq(scriptPath)}`);
                const child = spawn('bash', [scriptPath], {
                    detached: true,
                    stdio: ['ignore', 'ignore', 'ignore'],
                });
                child.unref();
                return [
                    `llama.cpp fine-tuning launched in background.`,
                    '',
                    `  Base model: ${baseModel}`,
                    `  Dataset: ${datasetPath}`,
                    `  Epochs: ${epochs}`,
                    `  Batch size: ${batchSize}`,
                    `  LoRA rank: ${loraRank}`,
                    `  LoRA alpha: ${loraAlpha}`,
                    `  Threads: ${threads}`,
                    `  LoRA output: ${loraOutPath}`,
                    `  Log file: ${logPath}`,
                    `  PID: ${child.pid}`,
                    '',
                    `Monitor progress: train_status --backend llama-cpp --log_path ${logPath}`,
                ].join('\n');
            }
            return `Error: Backend "${backend}" not handled.`;
        }
        catch (err) {
            return `Error starting training: ${err instanceof Error ? err.message : String(err)}`;
        }
    },
});
|
|
1282
|
+
// ── train_status ───────────────────────────────────────────────────
|
|
1283
|
+
registerTool({
|
|
1284
|
+
name: 'train_status',
|
|
1285
|
+
description: 'Check the status of a fine-tuning job. For cloud backends, polls the provider API. ' +
|
|
1286
|
+
'For local backends, checks process state and reads the latest loss from the log file.',
|
|
1287
|
+
parameters: {
|
|
1288
|
+
job_id: {
|
|
1289
|
+
type: 'string',
|
|
1290
|
+
description: 'Job ID for cloud backends (required for openai/together/mistral)',
|
|
1291
|
+
},
|
|
1292
|
+
backend: {
|
|
1293
|
+
type: 'string',
|
|
1294
|
+
description: 'Backend: openai, together, mistral, mlx, unsloth, llama-cpp',
|
|
1295
|
+
required: true,
|
|
1296
|
+
},
|
|
1297
|
+
log_path: {
|
|
1298
|
+
type: 'string',
|
|
1299
|
+
description: 'Path to training log file (for local backends)',
|
|
1300
|
+
},
|
|
1301
|
+
api_key: {
|
|
1302
|
+
type: 'string',
|
|
1303
|
+
description: 'API key for cloud backends (optional)',
|
|
1304
|
+
},
|
|
1305
|
+
},
|
|
1306
|
+
tier: 'pro',
|
|
1307
|
+
timeout: 30_000,
|
|
1308
|
+
async execute(args) {
|
|
1309
|
+
try {
|
|
1310
|
+
const backend = String(args.backend).toLowerCase();
|
|
1311
|
+
const jobId = args.job_id ? String(args.job_id) : '';
|
|
1312
|
+
const logPath = args.log_path ? resolve(String(args.log_path)) : '';
|
|
1313
|
+
// ── Cloud backends ─────────────────────────────────────────
|
|
1314
|
+
if (backend === 'openai' || backend === 'together' || backend === 'mistral') {
|
|
1315
|
+
if (!jobId) {
|
|
1316
|
+
return `Error: job_id is required for cloud backend "${backend}".`;
|
|
1317
|
+
}
|
|
1318
|
+
const apiKey = getApiKey(backend, args.api_key ? String(args.api_key) : undefined);
|
|
1319
|
+
if (!apiKey) {
|
|
1320
|
+
return `Error: No API key found for ${backend}. Set the appropriate environment variable or pass api_key.`;
|
|
1321
|
+
}
|
|
1322
|
+
let url;
|
|
1323
|
+
if (backend === 'openai') {
|
|
1324
|
+
url = `https://api.openai.com/v1/fine_tuning/jobs/${jobId}`;
|
|
1325
|
+
}
|
|
1326
|
+
else if (backend === 'together') {
|
|
1327
|
+
url = `https://api.together.xyz/v1/fine-tunes/${jobId}`;
|
|
1328
|
+
}
|
|
1329
|
+
else {
|
|
1330
|
+
url = `https://api.mistral.ai/v1/fine_tuning/jobs/${jobId}`;
|
|
1331
|
+
}
|
|
1332
|
+
const res = await fetch(url, {
|
|
1333
|
+
headers: { 'Authorization': `Bearer ${apiKey}` },
|
|
1334
|
+
});
|
|
1335
|
+
if (!res.ok) {
|
|
1336
|
+
const errBody = await res.text();
|
|
1337
|
+
return `Error fetching job status from ${backend}: ${res.status} ${errBody}`;
|
|
1338
|
+
}
|
|
1339
|
+
const data = await res.json();
|
|
1340
|
+
const status = data.status || data.state || 'unknown';
|
|
1341
|
+
const model = data.model || data.fine_tuned_model || data.output_name || 'N/A';
|
|
1342
|
+
const createdAt = data.created_at ? new Date(data.created_at * 1000).toISOString() : 'N/A';
|
|
1343
|
+
const finishedAt = data.finished_at ? new Date(data.finished_at * 1000).toISOString() : 'N/A';
|
|
1344
|
+
// Extract training metrics if available
|
|
1345
|
+
let metricsInfo = '';
|
|
1346
|
+
const trainedTokens = data.trained_tokens || data.training_tokens;
|
|
1347
|
+
if (trainedTokens)
|
|
1348
|
+
metricsInfo += `\n Trained tokens: ${trainedTokens.toLocaleString()}`;
|
|
1349
|
+
const resultFiles = data.result_files;
|
|
1350
|
+
if (Array.isArray(resultFiles) && resultFiles.length > 0) {
|
|
1351
|
+
metricsInfo += `\n Result files: ${resultFiles.length}`;
|
|
1352
|
+
}
|
|
1353
|
+
const error = data.error;
|
|
1354
|
+
if (error) {
|
|
1355
|
+
metricsInfo += `\n Error: ${error.message || JSON.stringify(error)}`;
|
|
1356
|
+
}
|
|
1357
|
+
// OpenAI specific: get events
|
|
1358
|
+
let eventsInfo = '';
|
|
1359
|
+
if (backend === 'openai') {
|
|
1360
|
+
try {
|
|
1361
|
+
const eventsRes = await fetch(`${url}/events?limit=5`, {
|
|
1362
|
+
headers: { 'Authorization': `Bearer ${apiKey}` },
|
|
1363
|
+
});
|
|
1364
|
+
if (eventsRes.ok) {
|
|
1365
|
+
const eventsData = await eventsRes.json();
|
|
1366
|
+
if (eventsData.data && eventsData.data.length > 0) {
|
|
1367
|
+
eventsInfo = '\n\nRecent events:\n' + eventsData.data
|
|
1368
|
+
.map(e => ` [${new Date(e.created_at * 1000).toLocaleTimeString()}] ${e.message}`)
|
|
1369
|
+
.join('\n');
|
|
1370
|
+
}
|
|
1371
|
+
}
|
|
1372
|
+
}
|
|
1373
|
+
catch { /* ignore events fetch errors */ }
|
|
1374
|
+
}
|
|
1375
|
+
return [
|
|
1376
|
+
`${backend.charAt(0).toUpperCase() + backend.slice(1)} Fine-Tuning Job Status`,
|
|
1377
|
+
'='.repeat(40),
|
|
1378
|
+
'',
|
|
1379
|
+
` Job ID: ${jobId}`,
|
|
1380
|
+
` Status: ${status}`,
|
|
1381
|
+
` Base model: ${model}`,
|
|
1382
|
+
` Created: ${createdAt}`,
|
|
1383
|
+
` Finished: ${finishedAt}`,
|
|
1384
|
+
metricsInfo,
|
|
1385
|
+
eventsInfo,
|
|
1386
|
+
].filter(Boolean).join('\n');
|
|
1387
|
+
}
|
|
1388
|
+
// ── Local backends ─────────────────────────────────────────
|
|
1389
|
+
if (!logPath) {
|
|
1390
|
+
return `Error: log_path is required for local backend "${backend}". Pass the path to the training log file.`;
|
|
1391
|
+
}
|
|
1392
|
+
if (!existsSync(logPath)) {
|
|
1393
|
+
return `Error: Log file not found: ${logPath}`;
|
|
1394
|
+
}
|
|
1395
|
+
const logContent = readFileSync(logPath, 'utf-8');
|
|
1396
|
+
const logLines = logContent.split('\n').filter(l => l.trim());
|
|
1397
|
+
// Determine status
|
|
1398
|
+
let status = 'running';
|
|
1399
|
+
const lastLine = logLines[logLines.length - 1] || '';
|
|
1400
|
+
if (lastLine.includes('Training finished') || lastLine.includes('Done!') || lastLine.includes('Saving model')) {
|
|
1401
|
+
status = 'completed';
|
|
1402
|
+
}
|
|
1403
|
+
if (lastLine.includes('Error') || lastLine.includes('Traceback') || lastLine.includes('FAILED')) {
|
|
1404
|
+
status = 'failed';
|
|
1405
|
+
}
|
|
1406
|
+
// Extract training metrics from log
|
|
1407
|
+
let currentLoss = 'N/A';
|
|
1408
|
+
let currentStep = 'N/A';
|
|
1409
|
+
let totalSteps = 'N/A';
|
|
1410
|
+
// Search for loss values (common patterns across frameworks)
|
|
1411
|
+
const lossPatterns = [
|
|
1412
|
+
/loss[:\s=]+(\d+\.?\d*)/i,
|
|
1413
|
+
/train_loss[:\s=]+(\d+\.?\d*)/i,
|
|
1414
|
+
/training_loss[:\s=]+(\d+\.?\d*)/i,
|
|
1415
|
+
];
|
|
1416
|
+
const stepPatterns = [
|
|
1417
|
+
/(?:step|iter|iteration)[:\s=]+(\d+)/i,
|
|
1418
|
+
/(\d+)\/(\d+)/, // step/total
|
|
1419
|
+
];
|
|
1420
|
+
// Read from end of file for latest values
|
|
1421
|
+
for (let i = logLines.length - 1; i >= Math.max(0, logLines.length - 50); i--) {
|
|
1422
|
+
const line = logLines[i];
|
|
1423
|
+
if (currentLoss === 'N/A') {
|
|
1424
|
+
for (const pattern of lossPatterns) {
|
|
1425
|
+
const match = line.match(pattern);
|
|
1426
|
+
if (match) {
|
|
1427
|
+
currentLoss = match[1];
|
|
1428
|
+
break;
|
|
1429
|
+
}
|
|
1430
|
+
}
|
|
1431
|
+
}
|
|
1432
|
+
if (currentStep === 'N/A') {
|
|
1433
|
+
for (const pattern of stepPatterns) {
|
|
1434
|
+
const match = line.match(pattern);
|
|
1435
|
+
if (match) {
|
|
1436
|
+
currentStep = match[1];
|
|
1437
|
+
if (match[2])
|
|
1438
|
+
totalSteps = match[2];
|
|
1439
|
+
break;
|
|
1440
|
+
}
|
|
1441
|
+
}
|
|
1442
|
+
}
|
|
1443
|
+
if (currentLoss !== 'N/A' && currentStep !== 'N/A')
|
|
1444
|
+
break;
|
|
1445
|
+
}
|
|
1446
|
+
// Extract start time
|
|
1447
|
+
let elapsed = 'N/A';
|
|
1448
|
+
let eta = 'N/A';
|
|
1449
|
+
const startMatch = logContent.match(/Training started at (.+)/);
|
|
1450
|
+
if (startMatch) {
|
|
1451
|
+
const startTime = new Date(startMatch[1]);
|
|
1452
|
+
const elapsedMs = Date.now() - startTime.getTime();
|
|
1453
|
+
const elapsedMin = (elapsedMs / 60_000).toFixed(1);
|
|
1454
|
+
elapsed = `${elapsedMin} min`;
|
|
1455
|
+
// Estimate ETA
|
|
1456
|
+
if (currentStep !== 'N/A' && totalSteps !== 'N/A') {
|
|
1457
|
+
const step = parseInt(currentStep, 10);
|
|
1458
|
+
const total = parseInt(totalSteps, 10);
|
|
1459
|
+
if (step > 0 && total > step) {
|
|
1460
|
+
const msPerStep = elapsedMs / step;
|
|
1461
|
+
const remainingMs = msPerStep * (total - step);
|
|
1462
|
+
const remainingMin = (remainingMs / 60_000).toFixed(1);
|
|
1463
|
+
eta = `~${remainingMin} min remaining`;
|
|
1464
|
+
}
|
|
1465
|
+
}
|
|
1466
|
+
}
|
|
1467
|
+
// Last few log lines for context
|
|
1468
|
+
const recentLines = logLines.slice(-8).map(l => ` ${l}`).join('\n');
|
|
1469
|
+
return [
|
|
1470
|
+
`Local Training Status (${backend})`,
|
|
1471
|
+
'='.repeat(40),
|
|
1472
|
+
'',
|
|
1473
|
+
` Status: ${status}`,
|
|
1474
|
+
` Current step: ${currentStep}${totalSteps !== 'N/A' ? ` / ${totalSteps}` : ''}`,
|
|
1475
|
+
` Current loss: ${currentLoss}`,
|
|
1476
|
+
` Elapsed: ${elapsed}`,
|
|
1477
|
+
` ETA: ${eta}`,
|
|
1478
|
+
` Log file: ${logPath}`,
|
|
1479
|
+
'',
|
|
1480
|
+
'Recent log output:',
|
|
1481
|
+
recentLines,
|
|
1482
|
+
].join('\n');
|
|
1483
|
+
}
|
|
1484
|
+
catch (err) {
|
|
1485
|
+
return `Error checking training status: ${err instanceof Error ? err.message : String(err)}`;
|
|
1486
|
+
}
|
|
1487
|
+
},
|
|
1488
|
+
});
|
|
1489
|
+
// ── train_evaluate ─────────────────────────────────────────────────
registerTool({
    name: 'train_evaluate',
    description: 'Evaluate a fine-tuned model against a test dataset. Runs test prompts through the model ' +
        'and measures BLEU score, exact match rate, and average response length. ' +
        'Supports local inference via Ollama or llama.cpp, and cloud via provider API.',
    parameters: {
        model: {
            type: 'string',
            description: 'Model name or path to evaluate',
            required: true,
        },
        test_data: {
            type: 'string',
            description: 'Path to test dataset (same format as training data)',
            required: true,
        },
        backend: {
            type: 'string',
            description: 'Inference backend: ollama, llama-cpp, openai, together, mistral (default: ollama)',
        },
        samples: {
            type: 'number',
            description: 'Maximum number of test samples to evaluate (default: 50)',
        },
        api_key: {
            type: 'string',
            description: 'API key for cloud backends (optional)',
        },
    },
    tier: 'pro',
    timeout: 600_000,
    /**
     * Parse the test set, generate one completion per prompt via the chosen
     * backend, and aggregate BLEU / exact-match / length statistics.
     * Always returns a human-readable string (a report, or an error message).
     */
    async execute(args) {
        try {
            const model = String(args.model);
            const testDataPath = resolve(String(args.test_data));
            const backend = String(args.backend || 'ollama').toLowerCase();
            const maxSamples = typeof args.samples === 'number' ? args.samples : 50;
            // Fail fast on an unknown backend. Previously an unrecognized value
            // fell through every inference branch and produced zero results
            // with a misleading "inference calls failed" error.
            const validBackends = ['ollama', 'llama-cpp', 'openai', 'together', 'mistral'];
            if (!validBackends.includes(backend)) {
                return `Error: Invalid backend "${backend}". Supported: ${validBackends.join(', ')}`;
            }
            if (!existsSync(testDataPath)) {
                return `Error: Test data not found: ${testDataPath}`;
            }
            // Check backend prerequisites ONCE up front (the original re-checked
            // command availability / API keys on every loop iteration).
            let apiKey;
            let apiUrl = '';
            if (backend === 'ollama') {
                if (!isCommandAvailable('ollama')) {
                    return 'Error: Ollama is not installed. Install from https://ollama.ai';
                }
            }
            else if (backend === 'llama-cpp') {
                if (!isCommandAvailable('llama-cli')) {
                    return 'Error: llama-cli is not installed. Build from https://github.com/ggerganov/llama.cpp';
                }
            }
            else {
                apiKey = getApiKey(backend, args.api_key ? String(args.api_key) : undefined);
                if (!apiKey) {
                    return `Error: No API key for ${backend}. Set the appropriate env var or pass api_key.`;
                }
                apiUrl = backend === 'openai' ? 'https://api.openai.com/v1/chat/completions'
                    : backend === 'together' ? 'https://api.together.xyz/v1/chat/completions'
                        : 'https://api.mistral.ai/v1/chat/completions';
            }
            // ── Parse test data: extract { prompt, expected, system? } pairs ──
            const content = readFileSync(testDataPath, 'utf-8');
            let testCases = [];
            // Try JSONL first — one example per line in OpenAI chat, Alpaca
            // (instruction/output), or ShareGPT (conversations) row format.
            const lines = content.split('\n').filter(l => l.trim());
            for (const line of lines) {
                try {
                    const obj = JSON.parse(line);
                    if (obj.messages && Array.isArray(obj.messages)) {
                        const msgs = obj.messages;
                        const systemMsg = msgs.find(m => m.role === 'system');
                        const userMsg = msgs.find(m => m.role === 'user');
                        const assistantMsg = msgs.find(m => m.role === 'assistant');
                        if (userMsg && assistantMsg) {
                            testCases.push({
                                prompt: userMsg.content,
                                expected: assistantMsg.content,
                                system: systemMsg?.content,
                            });
                        }
                    }
                    else if (obj.instruction && obj.output) {
                        testCases.push({ prompt: obj.instruction, expected: obj.output });
                    }
                    else if (obj.conversations) {
                        const convs = obj.conversations;
                        const human = convs.find(c => c.from === 'human');
                        const gpt = convs.find(c => c.from === 'gpt');
                        const sys = convs.find(c => c.from === 'system');
                        if (human && gpt) {
                            testCases.push({ prompt: human.value, expected: gpt.value, system: sys?.value });
                        }
                    }
                }
                catch { /* skip malformed lines */ }
            }
            // Fall back to a whole-file JSON array if no JSONL rows parsed.
            if (testCases.length === 0) {
                try {
                    const parsed = JSON.parse(content);
                    if (Array.isArray(parsed)) {
                        for (const item of parsed) {
                            if (item.messages) {
                                const msgs = item.messages;
                                const userMsg = msgs.find((m) => m.role === 'user');
                                const assistantMsg = msgs.find((m) => m.role === 'assistant');
                                if (userMsg && assistantMsg) {
                                    testCases.push({ prompt: userMsg.content, expected: assistantMsg.content });
                                }
                            }
                            else if (item.instruction && item.output) {
                                testCases.push({ prompt: item.instruction, expected: item.output });
                            }
                        }
                    }
                }
                catch { /* not a JSON array */ }
            }
            if (testCases.length === 0) {
                return `Error: No test cases extracted from ${testDataPath}. Ensure the file contains valid JSONL/Alpaca/ShareGPT examples.`;
            }
            // Cap the number of evaluated samples.
            testCases = testCases.slice(0, maxSamples);
            // ── Run inference for each test case (sequential by design: local
            // backends serialize on the GPU; cloud backends avoid rate limits) ──
            const results = [];
            let errorCount = 0;
            for (const tc of testCases) {
                let actual = '';
                try {
                    if (backend === 'ollama') {
                        const prompt = tc.system
                            ? `System: ${tc.system}\n\nUser: ${tc.prompt}\n\nAssistant:`
                            : `User: ${tc.prompt}\n\nAssistant:`;
                        // JSON.stringify shell-quotes both arguments, so model
                        // names/paths with spaces or metacharacters are safe.
                        const result = shellSafe(`ollama run ${JSON.stringify(model)} ${JSON.stringify(prompt)}`, { timeout: 60_000 });
                        actual = result.ok ? result.output : '';
                    }
                    else if (backend === 'llama-cpp') {
                        const result = shellSafe(`llama-cli -m ${JSON.stringify(model)} -p ${JSON.stringify(tc.prompt)} -n 512 --temp 0.1`, { timeout: 120_000 });
                        actual = result.ok ? result.output : '';
                    }
                    else {
                        // Cloud backends: OpenAI-compatible chat completions API.
                        const messages = [];
                        if (tc.system)
                            messages.push({ role: 'system', content: tc.system });
                        messages.push({ role: 'user', content: tc.prompt });
                        const res = await fetch(apiUrl, {
                            method: 'POST',
                            headers: {
                                'Authorization': `Bearer ${apiKey}`,
                                'Content-Type': 'application/json',
                            },
                            body: JSON.stringify({
                                model,
                                messages,
                                max_tokens: 1024,
                                temperature: 0.1, // near-deterministic output for scoring
                            }),
                        });
                        if (res.ok) {
                            const data = await res.json();
                            actual = data.choices?.[0]?.message?.content || '';
                        }
                    }
                }
                catch {
                    errorCount++;
                    continue;
                }
                if (!actual) {
                    errorCount++;
                    continue;
                }
                const bleu = bleuScore(tc.expected, actual);
                const exactMatch = normalizeText(actual) === normalizeText(tc.expected);
                // Only the first 80 chars of each string are retained; the
                // length metrics below are therefore computed on truncated text.
                results.push({
                    prompt: tc.prompt.slice(0, 80),
                    expected: tc.expected.slice(0, 80),
                    actual: actual.slice(0, 80),
                    bleu,
                    exactMatch,
                });
            }
            if (results.length === 0) {
                return `Error: No successful evaluations. ${errorCount} inference calls failed. Check model name and backend.`;
            }
            // ── Aggregate metrics ──────────────────────────────────────────
            const avgBleu = results.reduce((s, r) => s + r.bleu, 0) / results.length;
            const exactMatchRate = results.filter(r => r.exactMatch).length / results.length;
            const avgResponseLen = results.reduce((s, r) => s + r.actual.length, 0) / results.length;
            const avgExpectedLen = results.reduce((s, r) => s + r.expected.length, 0) / results.length;
            // ── Build the report ───────────────────────────────────────────
            const report = [
                `Model Evaluation Report`,
                '='.repeat(50),
                '',
                ` Model: ${model}`,
                ` Backend: ${backend}`,
                ` Test samples: ${results.length} / ${testCases.length}`,
                ` Errors: ${errorCount}`,
                '',
                `Metrics`,
                '─'.repeat(30),
                ` BLEU score: ${(avgBleu * 100).toFixed(1)}%`,
                ` Exact match: ${(exactMatchRate * 100).toFixed(1)}%`,
                ` Avg response: ${Math.round(avgResponseLen)} chars`,
                ` Avg expected: ${Math.round(avgExpectedLen)} chars`,
                ` Length ratio: ${(avgResponseLen / Math.max(avgExpectedLen, 1)).toFixed(2)}x`,
            ];
            // Show a few example comparisons.
            report.push('', 'Sample Results', '─'.repeat(30));
            for (const r of results.slice(0, 5)) {
                report.push(` Prompt: ${r.prompt}...`, ` Expected: ${r.expected}...`, ` Actual: ${r.actual}...`, ` BLEU: ${(r.bleu * 100).toFixed(1)}% | Match: ${r.exactMatch ? 'YES' : 'NO'}`, '');
            }
            return report.join('\n');
        }
        catch (err) {
            return `Evaluation error: ${err instanceof Error ? err.message : String(err)}`;
        }
    },
});
|
|
1704
|
+
// ── train_export ───────────────────────────────────────────────────
registerTool({
    name: 'train_export',
    description: 'Convert and export models between formats. Merge LoRA adapters back into base models, ' +
        'convert HuggingFace models to GGUF, or quantize GGUF files for efficient deployment.',
    parameters: {
        model_path: {
            type: 'string',
            description: 'Path to model or adapter directory',
            required: true,
        },
        operation: {
            type: 'string',
            description: 'Operation: merge_lora, to_gguf, quantize',
            required: true,
        },
        output: {
            type: 'string',
            description: 'Output path (default: auto-generated)',
        },
        base_model: {
            type: 'string',
            description: 'Base model name/path (required for merge_lora)',
        },
        quantization: {
            type: 'string',
            description: 'Quantization type for to_gguf and quantize: q4_K_M, q5_K_M, q8_0, f16 (default: q4_K_M)',
        },
    },
    tier: 'pro',
    timeout: 600_000,
    /**
     * Dispatch to one of three export operations:
     *  - merge_lora: fold a LoRA adapter into its base model (MLX, then PEFT fallback)
     *  - to_gguf:    convert a HuggingFace checkpoint to GGUF via llama.cpp
     *  - quantize:   re-quantize a GGUF file with llama-quantize
     * Always returns a human-readable status string.
     */
    async execute(args) {
        try {
            const modelPath = resolve(String(args.model_path));
            const operation = String(args.operation).toLowerCase();
            const quantType = String(args.quantization || 'q4_K_M');
            // Shell-quote helper: JSON.stringify wraps a string in double quotes
            // and escapes embedded quotes/backslashes, so paths with spaces or
            // metacharacters are passed safely to shellSafe/execSync.
            const q = (s) => JSON.stringify(s);
            if (!existsSync(modelPath)) {
                return `Error: Model path not found: ${modelPath}`;
            }
            if (!['merge_lora', 'to_gguf', 'quantize'].includes(operation)) {
                return `Error: Invalid operation "${operation}". Use: merge_lora, to_gguf, quantize`;
            }
            if (operation === 'merge_lora') {
                const baseModel = args.base_model ? String(args.base_model) : '';
                if (!baseModel) {
                    return 'Error: base_model is required for merge_lora operation.';
                }
                const outputPath = args.output
                    ? resolve(String(args.output))
                    : resolve(dirname(modelPath), `${basename(modelPath)}-merged`);
                // Try MLX merge first (Apple Silicon unified memory).
                const hasMlx = shellSafe('python3 -c "import mlx_lm"');
                if (hasMlx.ok) {
                    const cmd = `python3 -m mlx_lm.fuse --model ${q(baseModel)} --adapter-path ${q(modelPath)} --save-path ${q(outputPath)}`;
                    const result = shellSafe(cmd, { timeout: 300_000 });
                    if (result.ok) {
                        return [
                            `LoRA merge completed (MLX).`,
                            '',
                            ` Base model: ${baseModel}`,
                            ` Adapter: ${modelPath}`,
                            ` Merged output: ${outputPath}`,
                            '',
                            result.output,
                        ].join('\n');
                    }
                    // Fall through to try other methods
                }
                // Fall back to PEFT (PyTorch). The paths are embedded via
                // JSON.stringify — a JSON string literal is also a valid Python
                // string literal, so quotes/backslashes in paths cannot break or
                // inject into the generated script.
                const mergeScript = `
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

print("Loading base model:", ${q(baseModel)})
model = AutoModelForCausalLM.from_pretrained(${q(baseModel)}, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(${q(baseModel)})

print("Loading LoRA adapter:", ${q(modelPath)})
model = PeftModel.from_pretrained(model, ${q(modelPath)})

print("Merging LoRA weights into base model...")
model = model.merge_and_unload()

print("Saving merged model to:", ${q(outputPath)})
model.save_pretrained(${q(outputPath)})
tokenizer.save_pretrained(${q(outputPath)})
print("Done!")
`;
                const scriptPath = join(dirname(modelPath), '_merge_lora.py');
                writeFileSync(scriptPath, mergeScript, 'utf-8');
                const result = shellSafe(`python3 ${q(scriptPath)}`, { timeout: 600_000 });
                // Clean up the temporary script regardless of outcome.
                try {
                    execSync(`rm -f ${q(scriptPath)}`, { stdio: 'pipe' });
                }
                catch { /* ignore */ }
                if (!result.ok) {
                    return `Error merging LoRA:\n${result.output}\n\nEnsure transformers and peft are installed: pip install transformers peft torch`;
                }
                return [
                    `LoRA merge completed (PEFT).`,
                    '',
                    ` Base model: ${baseModel}`,
                    ` Adapter: ${modelPath}`,
                    ` Merged output: ${outputPath}`,
                    '',
                    result.output,
                ].join('\n');
            }
            else if (operation === 'to_gguf') {
                const outputPath = args.output
                    ? resolve(String(args.output))
                    : resolve(dirname(modelPath), `${basename(modelPath)}.${quantType}.gguf`);
                // Locate llama.cpp's conversion script (name changed from
                // convert-hf-to-gguf.py to convert_hf_to_gguf.py upstream).
                const convertScript = shellSafe('which convert_hf_to_gguf.py || which convert-hf-to-gguf.py');
                let convertCmd;
                if (convertScript.ok && convertScript.output) {
                    convertCmd = `python3 ${q(convertScript.output)} ${q(modelPath)} --outfile ${q(outputPath)} --outtype ${q(quantType)}`;
                }
                else {
                    // Not on PATH — try common install locations.
                    const commonPaths = [
                        join(homedir(), 'llama.cpp/convert_hf_to_gguf.py'),
                        join(homedir(), 'llama.cpp/convert-hf-to-gguf.py'),
                        '/opt/llama.cpp/convert_hf_to_gguf.py',
                        '/usr/local/share/llama.cpp/convert_hf_to_gguf.py',
                    ];
                    const found = commonPaths.find(p => existsSync(p));
                    if (!found) {
                        return [
                            'GGUF conversion script not found. Ensure llama.cpp is built:',
                            '',
                            ' git clone https://github.com/ggerganov/llama.cpp',
                            ' cd llama.cpp',
                            ' pip install -r requirements.txt',
                            '',
                            'Then run:',
                            ` python3 convert_hf_to_gguf.py ${modelPath} --outfile ${outputPath} --outtype ${quantType}`,
                        ].join('\n');
                    }
                    convertCmd = `python3 ${q(found)} ${q(modelPath)} --outfile ${q(outputPath)} --outtype ${q(quantType)}`;
                }
                const result = shellSafe(convertCmd, { timeout: 600_000 });
                if (!result.ok) {
                    return `Error converting to GGUF:\n${result.output}`;
                }
                const fileSize = existsSync(outputPath) ? (statSync(outputPath).size / (1024 * 1024 * 1024)).toFixed(2) : '?';
                return [
                    `GGUF conversion completed.`,
                    '',
                    ` Input: ${modelPath}`,
                    ` Output: ${outputPath}`,
                    ` Quantization: ${quantType}`,
                    ` File size: ${fileSize} GB`,
                    '',
                    result.output,
                ].join('\n');
            }
            else if (operation === 'quantize') {
                if (!isCommandAvailable('llama-quantize')) {
                    return [
                        'llama-quantize is not installed. Build from llama.cpp:',
                        '',
                        ' git clone https://github.com/ggerganov/llama.cpp',
                        ' cd llama.cpp',
                        ' make llama-quantize',
                    ].join('\n');
                }
                const outputPath = args.output
                    ? resolve(String(args.output))
                    : modelPath.replace(/\.gguf$/, '') + `.${quantType}.gguf`;
                const result = shellSafe(`llama-quantize ${q(modelPath)} ${q(outputPath)} ${q(quantType)}`, { timeout: 600_000 });
                if (!result.ok) {
                    return `Error quantizing model:\n${result.output}`;
                }
                const inputSize = (statSync(modelPath).size / (1024 * 1024 * 1024)).toFixed(2);
                const outputSize = existsSync(outputPath) ? (statSync(outputPath).size / (1024 * 1024 * 1024)).toFixed(2) : '?';
                return [
                    `Quantization completed.`,
                    '',
                    ` Input: ${modelPath} (${inputSize} GB)`,
                    ` Output: ${outputPath} (${outputSize} GB)`,
                    ` Type: ${quantType}`,
                    ` Compression: ${inputSize !== '?' && outputSize !== '?'
                        ? ((1 - parseFloat(outputSize) / parseFloat(inputSize)) * 100).toFixed(1) + '%'
                        : 'N/A'}`,
                    '',
                    result.output,
                ].join('\n');
            }
            // Unreachable given the validation above; kept as a safety net.
            return `Error: Operation "${operation}" not handled.`;
        }
        catch (err) {
            return `Export error: ${err instanceof Error ? err.message : String(err)}`;
        }
    },
});
|
|
1902
|
+
// ── train_deploy ───────────────────────────────────────────────────
registerTool({
    name: 'train_deploy',
    description: 'Deploy a fine-tuned model. Targets: Ollama (local serving), HuggingFace Hub (public/private repo), ' +
        'or K:BOT local models directory (~/.kbot/models/).',
    parameters: {
        model_path: {
            type: 'string',
            description: 'Path to the model file (GGUF) or directory',
            required: true,
        },
        target: {
            type: 'string',
            description: 'Deployment target: ollama, huggingface, kbot-local',
            required: true,
        },
        name: {
            type: 'string',
            description: 'Model name for the deployment',
            required: true,
        },
        description: {
            type: 'string',
            description: 'Model description (optional)',
        },
    },
    tier: 'pro',
    timeout: 600_000,
    /**
     * Deploy a trained model to one of three targets:
     *  - ollama:      write a Modelfile and register via `ollama create`
     *  - huggingface: upload with huggingface-cli (requires auth)
     *  - kbot-local:  copy into ~/.kbot/models/ and register in config.json
     * Always returns a human-readable status string.
     */
    async execute(args) {
        try {
            const modelPath = resolve(String(args.model_path));
            const target = String(args.target).toLowerCase();
            const name = String(args.name);
            const description = args.description ? String(args.description) : `Fine-tuned model: ${name}`;
            // Shell-quote helper so user-supplied names/paths with spaces or
            // metacharacters cannot break (or inject into) shell commands.
            const q = (s) => JSON.stringify(s);
            if (!existsSync(modelPath)) {
                return `Error: Model path not found: ${modelPath}`;
            }
            if (!['ollama', 'huggingface', 'kbot-local'].includes(target)) {
                return `Error: Invalid target "${target}". Use: ollama, huggingface, kbot-local`;
            }
            if (target === 'ollama') {
                if (!isCommandAvailable('ollama')) {
                    return 'Error: Ollama is not installed. Download from https://ollama.ai';
                }
                // Ollama's FROM directive accepts both a GGUF file and a model
                // directory, so a single form covers both cases. (The original
                // had a ternary on .gguf whose branches were identical.)
                const fromLine = `FROM ${modelPath}`;
                // Write a Modelfile next to the model.
                // NOTE(review): a description containing `"""` would terminate
                // the SYSTEM block early — assumed not to occur in practice.
                const modelfileContent = [
                    fromLine,
                    '',
                    `PARAMETER temperature 0.7`,
                    `PARAMETER top_p 0.9`,
                    `PARAMETER top_k 40`,
                    '',
                    `SYSTEM """${description}"""`,
                ].join('\n');
                const modelfilePath = join(dirname(modelPath), 'Modelfile');
                writeFileSync(modelfilePath, modelfileContent, 'utf-8');
                // Register the model with Ollama.
                const result = shellSafe(`ollama create ${q(name)} -f ${q(modelfilePath)}`, { timeout: 300_000 });
                if (!result.ok) {
                    return `Error creating Ollama model:\n${result.output}\n\nModelfile written to: ${modelfilePath}`;
                }
                return [
                    `Model deployed to Ollama.`,
                    '',
                    ` Name: ${name}`,
                    ` Source: ${modelPath}`,
                    ` Modelfile: ${modelfilePath}`,
                    '',
                    `Run it with: ollama run ${name}`,
                    `Use in K:BOT: kbot --model ${name}`,
                    '',
                    result.output,
                ].join('\n');
            }
            else if (target === 'huggingface') {
                if (!isCommandAvailable('huggingface-cli')) {
                    return [
                        'HuggingFace CLI is not installed. Install with:',
                        '',
                        ' pip install huggingface_hub[cli]',
                        '',
                        'Then authenticate:',
                        ' huggingface-cli login',
                    ].join('\n');
                }
                // Auth via env token or a previous `huggingface-cli login`.
                const hasToken = process.env.HF_TOKEN || process.env.HUGGING_FACE_HUB_TOKEN;
                if (!hasToken) {
                    const loginCheck = shellSafe('huggingface-cli whoami');
                    if (!loginCheck.ok) {
                        return 'Error: Not authenticated with HuggingFace. Run: huggingface-cli login\nOr set HF_TOKEN environment variable.';
                    }
                }
                const result = shellSafe(`huggingface-cli upload ${q(name)} ${q(modelPath)}`, { timeout: 600_000 });
                if (!result.ok) {
                    return `Error uploading to HuggingFace:\n${result.output}`;
                }
                return [
                    `Model uploaded to HuggingFace Hub.`,
                    '',
                    ` Repository: ${name}`,
                    ` Source: ${modelPath}`,
                    ` URL: https://huggingface.co/${name}`,
                    '',
                    result.output,
                ].join('\n');
            }
            else if (target === 'kbot-local') {
                const modelsDir = join(homedir(), '.kbot', 'models');
                mkdirSync(modelsDir, { recursive: true });
                const isGguf = modelPath.endsWith('.gguf');
                const destFilename = isGguf ? `${name}.gguf` : name;
                const destPath = join(modelsDir, destFilename);
                // Copy the model file or directory into the models dir.
                const stat = statSync(modelPath);
                if (stat.isFile()) {
                    const result = shellSafe(`cp ${q(modelPath)} ${q(destPath)}`, { timeout: 120_000 });
                    if (!result.ok) {
                        return `Error copying model: ${result.output}`;
                    }
                }
                else if (stat.isDirectory()) {
                    const destDir = join(modelsDir, name);
                    mkdirSync(destDir, { recursive: true });
                    // `<src>/.` copies the directory's contents (including
                    // dotfiles) rather than relying on an unquoted `/*` glob.
                    const result = shellSafe(`cp -R ${q(modelPath)}/. ${q(destDir)}/`, { timeout: 300_000 });
                    if (!result.ok) {
                        return `Error copying model directory: ${result.output}`;
                    }
                }
                // Register the model in the K:BOT config file.
                const configPath = join(homedir(), '.kbot', 'config.json');
                let config = {};
                if (existsSync(configPath)) {
                    try {
                        config = JSON.parse(readFileSync(configPath, 'utf-8'));
                    }
                    catch { /* start fresh */ }
                }
                if (!config.local_models || !Array.isArray(config.local_models)) {
                    config.local_models = [];
                }
                const modelEntry = {
                    name,
                    path: stat.isDirectory() ? join(modelsDir, name) : destPath,
                    description,
                    added: new Date().toISOString(),
                    type: isGguf ? 'gguf' : 'directory',
                };
                // Replace any existing entry with the same name (idempotent).
                config.local_models = config.local_models.filter(m => m.name !== name);
                config.local_models.push(modelEntry);
                writeFileSync(configPath, JSON.stringify(config, null, 2), 'utf-8');
                const fileSize = stat.isFile()
                    ? `${(stat.size / (1024 * 1024 * 1024)).toFixed(2)} GB`
                    : 'directory';
                return [
                    `Model registered in K:BOT local models.`,
                    '',
                    ` Name: ${name}`,
                    ` Path: ${modelEntry.path}`,
                    ` Size: ${fileSize}`,
                    ` Description: ${description}`,
                    '',
                    `Use in K:BOT: kbot --model ${name}`,
                    `List models: kbot models`,
                ].join('\n');
            }
            // Unreachable given the validation above; kept as a safety net.
            return `Error: Target "${target}" not handled.`;
        }
        catch (err) {
            return `Deploy error: ${err instanceof Error ? err.message : String(err)}`;
        }
    },
});
|
|
2079
|
+
// ── train_cost ─────────────────────────────────────────────────────
|
|
2080
|
+
registerTool({
|
|
2081
|
+
name: 'train_cost',
|
|
2082
|
+
description: 'Estimate the cost, time, and VRAM requirements for a fine-tuning job before starting. ' +
|
|
2083
|
+
'For cloud backends, calculates token-based pricing. For local, estimates GPU hours and VRAM usage.',
|
|
2084
|
+
parameters: {
|
|
2085
|
+
dataset: {
|
|
2086
|
+
type: 'string',
|
|
2087
|
+
description: 'Path to training dataset file',
|
|
2088
|
+
required: true,
|
|
2089
|
+
},
|
|
2090
|
+
base_model: {
|
|
2091
|
+
type: 'string',
|
|
2092
|
+
description: 'Base model to fine-tune (e.g., gpt-4.1-mini, llama-3-8b)',
|
|
2093
|
+
required: true,
|
|
2094
|
+
},
|
|
2095
|
+
backend: {
|
|
2096
|
+
type: 'string',
|
|
2097
|
+
description: 'Training backend: openai, together, mistral, mlx, unsloth, llama-cpp',
|
|
2098
|
+
required: true,
|
|
2099
|
+
},
|
|
2100
|
+
epochs: {
|
|
2101
|
+
type: 'number',
|
|
2102
|
+
description: 'Number of training epochs (default: 3)',
|
|
2103
|
+
},
|
|
2104
|
+
},
|
|
2105
|
+
tier: 'free',
|
|
2106
|
+
timeout: 30_000,
|
|
2107
|
+
async execute(args) {
|
|
2108
|
+
try {
|
|
2109
|
+
const datasetPath = resolve(String(args.dataset));
|
|
2110
|
+
const baseModel = String(args.base_model);
|
|
2111
|
+
const backend = String(args.backend).toLowerCase();
|
|
2112
|
+
const epochs = typeof args.epochs === 'number' ? args.epochs : 3;
|
|
2113
|
+
if (!existsSync(datasetPath)) {
|
|
2114
|
+
return `Error: Dataset not found: ${datasetPath}`;
|
|
2115
|
+
}
|
|
2116
|
+
const validBackends = ['openai', 'together', 'mistral', 'mlx', 'unsloth', 'llama-cpp'];
|
|
2117
|
+
if (!validBackends.includes(backend)) {
|
|
2118
|
+
return `Error: Invalid backend "${backend}". Supported: ${validBackends.join(', ')}`;
|
|
2119
|
+
}
|
|
2120
|
+
// Calculate dataset size
|
|
2121
|
+
const content = readFileSync(datasetPath, 'utf-8');
|
|
2122
|
+
const fileSize = statSync(datasetPath).size;
|
|
2123
|
+
const totalTokens = estimateTokens(content);
|
|
2124
|
+
const trainingTokens = totalTokens * epochs;
|
|
2125
|
+
// Estimate model parameters
|
|
2126
|
+
const modelParams = estimateModelParams(baseModel);
|
|
2127
|
+
const modelParamsB = (modelParams / 1e9).toFixed(1);
|
|
2128
|
+
// ── Cloud pricing ────────────────────────────────────────────
|
|
2129
|
+
const cloudPricing = {
|
|
2130
|
+
openai: {
|
|
2131
|
+
perKToken: baseModel.includes('gpt-4.1-mini') ? 0.008 : baseModel.includes('gpt-4.1') ? 0.025 : 0.008,
|
|
2132
|
+
label: baseModel.includes('gpt-4.1') && !baseModel.includes('mini') ? 'GPT-4.1' : 'GPT-4.1 Mini',
|
|
2133
|
+
},
|
|
2134
|
+
together: {
|
|
2135
|
+
perKToken: 0.004,
|
|
2136
|
+
label: 'Together AI',
|
|
2137
|
+
},
|
|
2138
|
+
mistral: {
|
|
2139
|
+
perKToken: 0.008,
|
|
2140
|
+
label: 'Mistral',
|
|
2141
|
+
},
|
|
2142
|
+
};
|
|
2143
|
+
// ── VRAM estimation ──────────────────────────────────────────
|
|
2144
|
+
// Full fine-tuning: ~18-20 bytes per param (fp16 model + optimizer states + gradients)
|
|
2145
|
+
// QLoRA: ~4-6 bytes per param (4-bit model + small LoRA overhead)
|
|
2146
|
+
// LoRA (fp16): ~10-12 bytes per param (fp16 model + small LoRA)
|
|
2147
|
+
let vramGB;
|
|
2148
|
+
let vramMethod;
|
|
2149
|
+
if (backend === 'mlx') {
|
|
2150
|
+
// MLX uses unified memory, fp16 model + LoRA
|
|
2151
|
+
vramGB = (modelParams * 4) / (1024 * 1024 * 1024); // ~4 bytes for 4-bit + LoRA overhead
|
|
2152
|
+
vramMethod = '4-bit + LoRA (MLX unified memory)';
|
|
2153
|
+
}
|
|
2154
|
+
else if (backend === 'unsloth') {
|
|
2155
|
+
// Unsloth uses 4-bit QLoRA
|
|
2156
|
+
vramGB = (modelParams * 5) / (1024 * 1024 * 1024);
|
|
2157
|
+
vramMethod = '4-bit QLoRA (Unsloth)';
|
|
2158
|
+
}
|
|
2159
|
+
else if (backend === 'llama-cpp') {
|
|
2160
|
+
// llama.cpp LoRA fine-tuning
|
|
2161
|
+
vramGB = (modelParams * 6) / (1024 * 1024 * 1024);
|
|
2162
|
+
vramMethod = 'LoRA (llama.cpp, CPU/GPU mixed)';
|
|
2163
|
+
}
|
|
2164
|
+
else {
|
|
2165
|
+
// Cloud — user doesn't need to worry about VRAM
|
|
2166
|
+
vramGB = 0;
|
|
2167
|
+
vramMethod = 'Cloud-managed (no local VRAM needed)';
|
|
2168
|
+
}
|
|
2169
|
+
// ── Time estimation ──────────────────────────────────────────
|
|
2170
|
+
// Rough estimates based on model size and training tokens
|
|
2171
|
+
let estimatedMinutes;
|
|
2172
|
+
let timeNote;
|
|
2173
|
+
if (backend === 'openai') {
|
|
2174
|
+
// OpenAI typically processes ~1M tokens/hour for fine-tuning
|
|
2175
|
+
estimatedMinutes = (trainingTokens / 1_000_000) * 60;
|
|
2176
|
+
timeNote = 'OpenAI job queue + training';
|
|
2177
|
+
}
|
|
2178
|
+
else if (backend === 'together') {
|
|
2179
|
+
estimatedMinutes = (trainingTokens / 800_000) * 60;
|
|
2180
|
+
timeNote = 'Together AI queue + training';
|
|
2181
|
+
}
|
|
2182
|
+
else if (backend === 'mistral') {
|
|
2183
|
+
estimatedMinutes = (trainingTokens / 700_000) * 60;
|
|
2184
|
+
timeNote = 'Mistral queue + training';
|
|
2185
|
+
}
|
|
2186
|
+
else if (backend === 'mlx') {
|
|
2187
|
+
// Apple Silicon: ~200-500 tokens/sec for 7B model LoRA
|
|
2188
|
+
const tokPerSec = modelParams <= 8e9 ? 400 : modelParams <= 14e9 ? 150 : 50;
|
|
2189
|
+
estimatedMinutes = (trainingTokens / tokPerSec) / 60;
|
|
2190
|
+
timeNote = `~${tokPerSec} tok/s estimated for ${modelParamsB}B on Apple Silicon`;
|
|
2191
|
+
}
|
|
2192
|
+
else if (backend === 'unsloth') {
|
|
2193
|
+
// Unsloth with consumer GPU: ~500-1000 tokens/sec for 7B
|
|
2194
|
+
const tokPerSec = modelParams <= 8e9 ? 800 : modelParams <= 14e9 ? 300 : 100;
|
|
2195
|
+
estimatedMinutes = (trainingTokens / tokPerSec) / 60;
|
|
2196
|
+
timeNote = `~${tokPerSec} tok/s estimated for ${modelParamsB}B with Unsloth`;
|
|
2197
|
+
}
|
|
2198
|
+
else {
|
|
2199
|
+
// llama.cpp CPU: ~50-200 tokens/sec
|
|
2200
|
+
const tokPerSec = modelParams <= 8e9 ? 150 : modelParams <= 14e9 ? 50 : 15;
|
|
2201
|
+
estimatedMinutes = (trainingTokens / tokPerSec) / 60;
|
|
2202
|
+
timeNote = `~${tokPerSec} tok/s estimated for ${modelParamsB}B on CPU`;
|
|
2203
|
+
}
|
|
2204
|
+
// Format time
|
|
2205
|
+
let timeStr;
|
|
2206
|
+
if (estimatedMinutes < 60) {
|
|
2207
|
+
timeStr = `${Math.ceil(estimatedMinutes)} minutes`;
|
|
2208
|
+
}
|
|
2209
|
+
else if (estimatedMinutes < 1440) {
|
|
2210
|
+
timeStr = `${(estimatedMinutes / 60).toFixed(1)} hours`;
|
|
2211
|
+
}
|
|
2212
|
+
else {
|
|
2213
|
+
timeStr = `${(estimatedMinutes / 1440).toFixed(1)} days`;
|
|
2214
|
+
}
|
|
2215
|
+
// ── Cost for cloud ───────────────────────────────────────────
|
|
2216
|
+
let costStr;
|
|
2217
|
+
if (backend === 'openai' || backend === 'together' || backend === 'mistral') {
|
|
2218
|
+
const pricing = cloudPricing[backend];
|
|
2219
|
+
const cost = (trainingTokens / 1000) * pricing.perKToken;
|
|
2220
|
+
costStr = `$${cost.toFixed(2)} (${pricing.label} @ $${pricing.perKToken}/1K tokens)`;
|
|
2221
|
+
}
|
|
2222
|
+
else {
|
|
2223
|
+
// Local: estimate electricity cost
|
|
2224
|
+
let wattage;
|
|
2225
|
+
if (backend === 'mlx')
|
|
2226
|
+
wattage = 30; // Apple Silicon TDP
|
|
2227
|
+
else if (backend === 'unsloth')
|
|
2228
|
+
wattage = 300; // GPU TDP
|
|
2229
|
+
else
|
|
2230
|
+
wattage = 150; // CPU
|
|
2231
|
+
const kWh = (wattage * estimatedMinutes / 60) / 1000;
|
|
2232
|
+
const electricityCost = kWh * 0.15; // $0.15/kWh average
|
|
2233
|
+
costStr = `~$${electricityCost.toFixed(2)} electricity (${kWh.toFixed(2)} kWh @ $0.15/kWh)`;
|
|
2234
|
+
}
|
|
2235
|
+
// ── GPU recommendations ──────────────────────────────────────
|
|
2236
|
+
let gpuRec;
|
|
2237
|
+
if (backend === 'mlx') {
|
|
2238
|
+
if (vramGB <= 8)
|
|
2239
|
+
gpuRec = 'M1/M2/M3 with 8GB+ unified memory';
|
|
2240
|
+
else if (vramGB <= 16)
|
|
2241
|
+
gpuRec = 'M1 Pro/Max/M2 Pro/Max with 16GB+ unified memory';
|
|
2242
|
+
else if (vramGB <= 36)
|
|
2243
|
+
gpuRec = 'M2 Max/M3 Max with 36GB+ unified memory';
|
|
2244
|
+
else if (vramGB <= 64)
|
|
2245
|
+
gpuRec = 'M2 Ultra/M3 Ultra with 64GB+ unified memory';
|
|
2246
|
+
else
|
|
2247
|
+
gpuRec = 'M2 Ultra/M3 Ultra with 128GB+ unified memory (or use cloud)';
|
|
2248
|
+
}
|
|
2249
|
+
else if (backend === 'unsloth') {
|
|
2250
|
+
if (vramGB <= 8)
|
|
2251
|
+
gpuRec = 'RTX 3060 12GB / RTX 4060 Ti';
|
|
2252
|
+
else if (vramGB <= 16)
|
|
2253
|
+
gpuRec = 'RTX 3090 / RTX 4080 / A5000';
|
|
2254
|
+
else if (vramGB <= 24)
|
|
2255
|
+
gpuRec = 'RTX 3090 Ti / RTX 4090 / A5000';
|
|
2256
|
+
else if (vramGB <= 48)
|
|
2257
|
+
gpuRec = 'A6000 / 2x RTX 4090';
|
|
2258
|
+
else
|
|
2259
|
+
gpuRec = 'A100 80GB / H100 (or use cloud)';
|
|
2260
|
+
}
|
|
2261
|
+
else if (backend === 'llama-cpp') {
|
|
2262
|
+
const ramGB = Math.ceil(vramGB * 1.5);
|
|
2263
|
+
gpuRec = `${ramGB}GB+ system RAM (llama.cpp uses CPU + optional GPU offload)`;
|
|
2264
|
+
}
|
|
2265
|
+
else {
|
|
2266
|
+
gpuRec = 'N/A (cloud-managed)';
|
|
2267
|
+
}
|
|
2268
|
+
// ── Build report ─────────────────────────────────────────────
|
|
2269
|
+
const lines = content.split('\n').filter(l => l.trim()).length;
|
|
2270
|
+
return [
|
|
2271
|
+
`Training Cost Estimate`,
|
|
2272
|
+
'='.repeat(50),
|
|
2273
|
+
'',
|
|
2274
|
+
`Dataset`,
|
|
2275
|
+
'─'.repeat(30),
|
|
2276
|
+
` File: ${datasetPath}`,
|
|
2277
|
+
` File size: ${(fileSize / 1024).toFixed(1)} KB`,
|
|
2278
|
+
` Examples: ~${lines.toLocaleString()}`,
|
|
2279
|
+
` Tokens: ~${totalTokens.toLocaleString()}`,
|
|
2280
|
+
` Training tokens: ~${trainingTokens.toLocaleString()} (${epochs} epochs)`,
|
|
2281
|
+
'',
|
|
2282
|
+
`Model`,
|
|
2283
|
+
'─'.repeat(30),
|
|
2284
|
+
` Base model: ${baseModel}`,
|
|
2285
|
+
` Est. params: ${modelParamsB}B`,
|
|
2286
|
+
` Backend: ${backend}`,
|
|
2287
|
+
'',
|
|
2288
|
+
`Cost`,
|
|
2289
|
+
'─'.repeat(30),
|
|
2290
|
+
` Estimated: ${costStr}`,
|
|
2291
|
+
'',
|
|
2292
|
+
`Time`,
|
|
2293
|
+
'─'.repeat(30),
|
|
2294
|
+
` Estimated: ${timeStr}`,
|
|
2295
|
+
` Note: ${timeNote}`,
|
|
2296
|
+
'',
|
|
2297
|
+
`Resources`,
|
|
2298
|
+
'─'.repeat(30),
|
|
2299
|
+
` VRAM needed: ${vramGB > 0 ? `${vramGB.toFixed(1)} GB` : 'N/A (cloud)'}`,
|
|
2300
|
+
` Method: ${vramMethod}`,
|
|
2301
|
+
` Recommended: ${gpuRec}`,
|
|
2302
|
+
'',
|
|
2303
|
+
`Note: These are rough estimates. Actual cost and time depend on dataset complexity,`,
|
|
2304
|
+
`hardware configuration, queue times (cloud), and hyperparameters.`,
|
|
2305
|
+
].join('\n');
|
|
2306
|
+
}
|
|
2307
|
+
catch (err) {
|
|
2308
|
+
return `Cost estimation error: ${err instanceof Error ? err.message : String(err)}`;
|
|
2309
|
+
}
|
|
2310
|
+
},
|
|
2311
|
+
});
|
|
2312
|
+
}
|
|
2313
|
+
//# sourceMappingURL=training.js.map
|