nanogpt 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile.lock +30 -1
- data/docs/ARCHITECTURE.md +429 -0
- data/exe/nanogpt +210 -233
- data/lib/nano_gpt/bpe_textfile_preparer.rb +105 -0
- data/lib/nano_gpt/data_loader.rb +5 -20
- data/lib/nano_gpt/layers/block.rb +6 -1
- data/lib/nano_gpt/layers/causal_self_attention.rb +11 -1
- data/lib/nano_gpt/model.rb +1 -7
- data/lib/nano_gpt/textfile_preparer.rb +189 -0
- data/lib/nano_gpt/train_config.rb +80 -146
- data/lib/nano_gpt/trainer.rb +21 -48
- data/lib/nano_gpt/version.rb +1 -1
- data/lib/nano_gpt/web/metrics_store.rb +136 -0
- data/lib/nano_gpt/web/server.rb +294 -0
- data/lib/nano_gpt/web/sse_notifier.rb +37 -0
- data/lib/nano_gpt/web/training_state.rb +56 -0
- data/lib/nano_gpt/web/training_worker.rb +153 -0
- data/lib/nano_gpt/web/views/layout.erb +78 -0
- data/lib/nano_gpt/web/views/run_detail.erb +432 -0
- data/lib/nano_gpt/web/views/runs.erb +434 -0
- data/lib/nano_gpt/web/web_trainer.rb +210 -0
- data/lib/nano_gpt/web.rb +9 -0
- data/lib/nano_gpt.rb +1 -0
- data/nanogpt.gemspec +4 -0
- metadata +71 -2
|
@@ -0,0 +1,434 @@
|
|
|
1
|
+
<!-- Header -->
|
|
2
|
+
<div class="flex items-center justify-between">
|
|
3
|
+
<h1 class="text-lg font-semibold">Runs</h1>
|
|
4
|
+
<div class="flex gap-2">
|
|
5
|
+
<button onclick="toggleImportModal()" class="px-3 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg text-sm transition">
|
|
6
|
+
Import Dataset
|
|
7
|
+
</button>
|
|
8
|
+
<button onclick="openNewRunModal()" class="px-4 py-2 bg-green-600 hover:bg-green-700 rounded-lg font-medium text-sm transition">
|
|
9
|
+
+ New Run
|
|
10
|
+
</button>
|
|
11
|
+
</div>
|
|
12
|
+
</div>
|
|
13
|
+
|
|
14
|
+
<!-- Runs List -->
|
|
15
|
+
<div id="runs-list" class="space-y-2">
|
|
16
|
+
<div class="text-center text-gray-500 py-8">Loading runs...</div>
|
|
17
|
+
</div>
|
|
18
|
+
|
|
19
|
+
<!-- New Run Modal -->
|
|
20
|
+
<div id="new-run-modal" class="fixed inset-0 bg-black/60 z-50 hidden flex items-start justify-center pt-16 px-4">
|
|
21
|
+
<div class="bg-gray-800 rounded-xl shadow-2xl w-full max-w-2xl max-h-[80vh] overflow-y-auto">
|
|
22
|
+
<div class="flex items-center justify-between px-6 py-4 border-b border-gray-700">
|
|
23
|
+
<h2 class="text-base font-semibold">New Training Run</h2>
|
|
24
|
+
<button onclick="closeNewRunModal()" class="text-gray-400 hover:text-white text-xl leading-none">×</button>
|
|
25
|
+
</div>
|
|
26
|
+
|
|
27
|
+
<div class="px-6 py-5 space-y-5">
|
|
28
|
+
<!-- Dataset -->
|
|
29
|
+
<div>
|
|
30
|
+
<label class="block text-xs text-gray-400 mb-1.5">Dataset</label>
|
|
31
|
+
<select id="cfg-dataset" class="w-full bg-gray-700 rounded px-3 py-2 text-sm"></select>
|
|
32
|
+
</div>
|
|
33
|
+
|
|
34
|
+
<!-- Basic -->
|
|
35
|
+
<div>
|
|
36
|
+
<div class="text-xs text-gray-400 uppercase tracking-wide mb-2">Training</div>
|
|
37
|
+
<div class="grid grid-cols-3 gap-3">
|
|
38
|
+
<div>
|
|
39
|
+
<label class="block text-xs text-gray-400 mb-1">Max Iters</label>
|
|
40
|
+
<input id="cfg-max_iters" type="number" value="5000" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
41
|
+
</div>
|
|
42
|
+
<div>
|
|
43
|
+
<label class="block text-xs text-gray-400 mb-1">Learning Rate</label>
|
|
44
|
+
<input id="cfg-learning_rate" type="number" value="0.001" step="0.0001" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
45
|
+
</div>
|
|
46
|
+
<div>
|
|
47
|
+
<label class="block text-xs text-gray-400 mb-1">Batch Size</label>
|
|
48
|
+
<input id="cfg-batch_size" type="number" value="64" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
49
|
+
</div>
|
|
50
|
+
</div>
|
|
51
|
+
</div>
|
|
52
|
+
|
|
53
|
+
<!-- Model -->
|
|
54
|
+
<div>
|
|
55
|
+
<div class="text-xs text-gray-400 uppercase tracking-wide mb-2">Model</div>
|
|
56
|
+
<div class="grid grid-cols-3 gap-3">
|
|
57
|
+
<div>
|
|
58
|
+
<label class="block text-xs text-gray-400 mb-1">Block Size</label>
|
|
59
|
+
<input id="cfg-block_size" type="number" value="256" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
60
|
+
</div>
|
|
61
|
+
<div>
|
|
62
|
+
<label class="block text-xs text-gray-400 mb-1">Layers</label>
|
|
63
|
+
<input id="cfg-n_layer" type="number" value="6" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
64
|
+
</div>
|
|
65
|
+
<div>
|
|
66
|
+
<label class="block text-xs text-gray-400 mb-1">Heads</label>
|
|
67
|
+
<input id="cfg-n_head" type="number" value="6" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
68
|
+
</div>
|
|
69
|
+
<div>
|
|
70
|
+
<label class="block text-xs text-gray-400 mb-1">Embedding Dim</label>
|
|
71
|
+
<input id="cfg-n_embd" type="number" value="384" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
72
|
+
</div>
|
|
73
|
+
<div>
|
|
74
|
+
<label class="block text-xs text-gray-400 mb-1">Dropout</label>
|
|
75
|
+
<input id="cfg-dropout" type="number" value="0.2" step="0.05" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
76
|
+
</div>
|
|
77
|
+
</div>
|
|
78
|
+
</div>
|
|
79
|
+
|
|
80
|
+
<!-- Advanced (collapsed) -->
|
|
81
|
+
<details>
|
|
82
|
+
<summary class="text-xs text-gray-400 uppercase tracking-wide cursor-pointer hover:text-gray-300">Advanced</summary>
|
|
83
|
+
<div class="grid grid-cols-3 gap-3 mt-3">
|
|
84
|
+
<div>
|
|
85
|
+
<label class="block text-xs text-gray-400 mb-1">Eval Interval</label>
|
|
86
|
+
<input id="cfg-eval_interval" type="number" value="250" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
87
|
+
</div>
|
|
88
|
+
<div>
|
|
89
|
+
<label class="block text-xs text-gray-400 mb-1">Log Interval</label>
|
|
90
|
+
<input id="cfg-log_interval" type="number" value="10" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
91
|
+
</div>
|
|
92
|
+
<div>
|
|
93
|
+
<label class="block text-xs text-gray-400 mb-1">Eval Iters</label>
|
|
94
|
+
<input id="cfg-eval_iters" type="number" value="20" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
95
|
+
</div>
|
|
96
|
+
<div>
|
|
97
|
+
<label class="block text-xs text-gray-400 mb-1">Grad Accumulation</label>
|
|
98
|
+
<input id="cfg-gradient_accumulation_steps" type="number" value="1" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
99
|
+
</div>
|
|
100
|
+
<div>
|
|
101
|
+
<label class="block text-xs text-gray-400 mb-1">Weight Decay</label>
|
|
102
|
+
<input id="cfg-weight_decay" type="number" value="0.1" step="0.01" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
103
|
+
</div>
|
|
104
|
+
<div>
|
|
105
|
+
<label class="block text-xs text-gray-400 mb-1">Grad Clip</label>
|
|
106
|
+
<input id="cfg-grad_clip" type="number" value="1.0" step="0.1" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
107
|
+
</div>
|
|
108
|
+
<div>
|
|
109
|
+
<label class="block text-xs text-gray-400 mb-1">Warmup Iters</label>
|
|
110
|
+
<input id="cfg-warmup_iters" type="number" value="100" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
111
|
+
</div>
|
|
112
|
+
<div>
|
|
113
|
+
<label class="block text-xs text-gray-400 mb-1">LR Decay Iters</label>
|
|
114
|
+
<input id="cfg-lr_decay_iters" type="number" value="5000" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
115
|
+
</div>
|
|
116
|
+
<div>
|
|
117
|
+
<label class="block text-xs text-gray-400 mb-1">Min LR</label>
|
|
118
|
+
<input id="cfg-min_lr" type="number" value="0.0001" step="0.0001" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
119
|
+
</div>
|
|
120
|
+
<div>
|
|
121
|
+
<label class="block text-xs text-gray-400 mb-1">Bias</label>
|
|
122
|
+
<select id="cfg-bias" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
123
|
+
<option value="false" selected>false</option>
|
|
124
|
+
<option value="true">true</option>
|
|
125
|
+
</select>
|
|
126
|
+
</div>
|
|
127
|
+
<div>
|
|
128
|
+
<label class="block text-xs text-gray-400 mb-1">Always Save Ckpt</label>
|
|
129
|
+
<select id="cfg-always_save_checkpoint" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
130
|
+
<option value="false" selected>false</option>
|
|
131
|
+
<option value="true">true</option>
|
|
132
|
+
</select>
|
|
133
|
+
</div>
|
|
134
|
+
</div>
|
|
135
|
+
</details>
|
|
136
|
+
</div>
|
|
137
|
+
|
|
138
|
+
<div class="px-6 py-4 border-t border-gray-700 flex justify-end gap-3">
|
|
139
|
+
<button onclick="closeNewRunModal()" class="px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg text-sm transition">Cancel</button>
|
|
140
|
+
<button id="btn-start-run" onclick="startRun()" class="px-5 py-2 bg-green-600 hover:bg-green-700 rounded-lg font-medium text-sm transition">
|
|
141
|
+
Start Training
|
|
142
|
+
</button>
|
|
143
|
+
</div>
|
|
144
|
+
</div>
|
|
145
|
+
</div>
|
|
146
|
+
|
|
147
|
+
<!-- Import Dataset Modal -->
|
|
148
|
+
<div id="import-modal" class="fixed inset-0 bg-black/60 z-50 hidden flex items-start justify-center pt-16 px-4">
|
|
149
|
+
<div class="bg-gray-800 rounded-xl shadow-2xl w-full max-w-lg">
|
|
150
|
+
<div class="flex items-center justify-between px-6 py-4 border-b border-gray-700">
|
|
151
|
+
<h2 class="text-base font-semibold">Import Dataset</h2>
|
|
152
|
+
<button onclick="toggleImportModal()" class="text-gray-400 hover:text-white text-xl leading-none">×</button>
|
|
153
|
+
</div>
|
|
154
|
+
<form id="upload-form" enctype="multipart/form-data" class="px-6 py-5 space-y-4">
|
|
155
|
+
<div>
|
|
156
|
+
<label class="block text-xs text-gray-400 mb-1">Text File</label>
|
|
157
|
+
<input id="import-file" name="file" type="file" accept=".txt,.text" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm file:mr-3 file:py-1 file:px-3 file:rounded file:border-0 file:text-sm file:bg-gray-600 file:text-gray-200">
|
|
158
|
+
</div>
|
|
159
|
+
<div>
|
|
160
|
+
<label class="block text-xs text-gray-400 mb-1">Dataset Name</label>
|
|
161
|
+
<input id="import-name" name="name" type="text" placeholder="auto from filename" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
162
|
+
</div>
|
|
163
|
+
<div>
|
|
164
|
+
<label class="block text-xs text-gray-400 mb-1">Tokenizer</label>
|
|
165
|
+
<div class="flex gap-4 mt-1">
|
|
166
|
+
<label class="flex items-center gap-1.5 text-sm"><input type="radio" name="tokenizer" value="char" checked class="accent-green-500"> Char</label>
|
|
167
|
+
<label class="flex items-center gap-1.5 text-sm"><input type="radio" name="tokenizer" value="bpe" class="accent-green-500"> BPE (GPT-2)</label>
|
|
168
|
+
</div>
|
|
169
|
+
</div>
|
|
170
|
+
<div>
|
|
171
|
+
<label class="block text-xs text-gray-400 mb-1">Val Split</label>
|
|
172
|
+
<input id="import-val-ratio" name="val_ratio" type="number" value="0.1" step="0.05" min="0" max="0.5" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
173
|
+
</div>
|
|
174
|
+
</form>
|
|
175
|
+
<div class="px-6 py-4 border-t border-gray-700 flex items-center justify-between">
|
|
176
|
+
<span id="import-status" class="text-xs text-gray-400"></span>
|
|
177
|
+
<button id="btn-upload" onclick="uploadDataset()" class="px-4 py-2 bg-blue-600 hover:bg-blue-700 rounded-lg font-medium text-sm transition">
|
|
178
|
+
Upload & Prepare
|
|
179
|
+
</button>
|
|
180
|
+
</div>
|
|
181
|
+
</div>
|
|
182
|
+
</div>
|
|
183
|
+
|
|
184
|
+
<script>
|
|
185
|
+
let runsData = [];
|
|
186
|
+
let currentTrainingRunId = null;
|
|
187
|
+
|
|
188
|
+
// ---- Load Runs ----
|
|
189
|
+
async function loadRuns() {
|
|
190
|
+
try {
|
|
191
|
+
const res = await fetch('/api/runs');
|
|
192
|
+
const data = await res.json();
|
|
193
|
+
runsData = data.runs || [];
|
|
194
|
+
renderRuns();
|
|
195
|
+
} catch (e) {
|
|
196
|
+
document.getElementById('runs-list').innerHTML =
|
|
197
|
+
'<div class="text-center text-red-400 py-8">Failed to load runs</div>';
|
|
198
|
+
}
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
function renderRuns() {
|
|
202
|
+
const container = document.getElementById('runs-list');
|
|
203
|
+
if (runsData.length === 0) {
|
|
204
|
+
container.innerHTML = '<div class="text-center text-gray-500 py-12">No runs yet. Click "+ New Run" to start training.</div>';
|
|
205
|
+
return;
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
container.innerHTML = runsData.map(r => {
|
|
209
|
+
const valLoss = r.best_val_loss ? parseFloat(r.best_val_loss).toFixed(4) : '--';
|
|
210
|
+
const iters = r.current_iter || 0;
|
|
211
|
+
const config = JSON.parse(r.config_json || '{}');
|
|
212
|
+
const maxIters = config.max_iters || '?';
|
|
213
|
+
const isActive = r.status === 'running' && currentTrainingRunId === r.id;
|
|
214
|
+
|
|
215
|
+
return `
|
|
216
|
+
<div class="bg-gray-800 rounded-lg p-4 hover:bg-gray-750 transition cursor-pointer flex items-center justify-between group"
|
|
217
|
+
onclick="window.location='/runs/${r.id}'">
|
|
218
|
+
<div class="flex items-center gap-4 min-w-0">
|
|
219
|
+
<div class="text-sm font-mono text-gray-500 w-8">#${r.id}</div>
|
|
220
|
+
<div class="min-w-0">
|
|
221
|
+
<div class="flex items-center gap-2">
|
|
222
|
+
<span class="font-medium text-sm">${r.dataset}</span>
|
|
223
|
+
${statusBadge(r.status)}
|
|
224
|
+
</div>
|
|
225
|
+
<div class="text-xs text-gray-500 mt-0.5">
|
|
226
|
+
${formatTime(r.started_at)}
|
|
227
|
+
· ${iters}/${maxIters} iters
|
|
228
|
+
· val_loss: ${valLoss}
|
|
229
|
+
</div>
|
|
230
|
+
</div>
|
|
231
|
+
</div>
|
|
232
|
+
<div class="flex items-center gap-2" onclick="event.stopPropagation()">
|
|
233
|
+
${r.status === 'running' ? `
|
|
234
|
+
<button onclick="stopRun()" class="px-3 py-1.5 bg-red-600/80 hover:bg-red-600 rounded text-xs font-medium transition">Stop</button>
|
|
235
|
+
` : ''}
|
|
236
|
+
${(r.status === 'stopped' || r.status === 'completed') && r.checkpoint_path ? `
|
|
237
|
+
<button onclick="resumeRun(${r.id}, '${r.dataset}', '${(r.checkpoint_path || '').replace(/'/g, "\\'")}')" class="px-3 py-1.5 bg-yellow-600/80 hover:bg-yellow-600 rounded text-xs font-medium transition">Resume</button>
|
|
238
|
+
` : ''}
|
|
239
|
+
<a href="/runs/${r.id}" class="px-3 py-1.5 bg-gray-700 hover:bg-gray-600 rounded text-xs font-medium transition">View</a>
|
|
240
|
+
</div>
|
|
241
|
+
</div>
|
|
242
|
+
`;
|
|
243
|
+
}).join('');
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
// ---- Training Status Listener ----
|
|
247
|
+
window.addEventListener('training-status', (e) => {
|
|
248
|
+
const state = e.detail;
|
|
249
|
+
const oldId = currentTrainingRunId;
|
|
250
|
+
currentTrainingRunId = state.run_id;
|
|
251
|
+
|
|
252
|
+
// Re-render if status changed
|
|
253
|
+
if (runsData.length > 0) {
|
|
254
|
+
const activeRun = runsData.find(r => r.id === state.run_id);
|
|
255
|
+
if (activeRun && activeRun.status !== state.status) {
|
|
256
|
+
activeRun.status = state.status;
|
|
257
|
+
activeRun.current_iter = state.current_iter;
|
|
258
|
+
activeRun.best_val_loss = state.best_val_loss;
|
|
259
|
+
renderRuns();
|
|
260
|
+
}
|
|
261
|
+
// If status changed from running to something else, reload to get final state
|
|
262
|
+
if (oldId && state.status !== 'running') {
|
|
263
|
+
loadRuns();
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
});
|
|
267
|
+
|
|
268
|
+
// ---- Modal Controls ----
|
|
269
|
+
function openNewRunModal() {
|
|
270
|
+
loadDatasets();
|
|
271
|
+
document.getElementById('new-run-modal').classList.remove('hidden');
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
function closeNewRunModal() {
|
|
275
|
+
document.getElementById('new-run-modal').classList.add('hidden');
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
function toggleImportModal() {
|
|
279
|
+
document.getElementById('import-modal').classList.toggle('hidden');
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
// Close modals on backdrop click
|
|
283
|
+
document.getElementById('new-run-modal')?.addEventListener('click', (e) => {
|
|
284
|
+
if (e.target === e.currentTarget) closeNewRunModal();
|
|
285
|
+
});
|
|
286
|
+
document.getElementById('import-modal')?.addEventListener('click', (e) => {
|
|
287
|
+
if (e.target === e.currentTarget) toggleImportModal();
|
|
288
|
+
});
|
|
289
|
+
|
|
290
|
+
// Close modals on Escape
|
|
291
|
+
document.addEventListener('keydown', (e) => {
|
|
292
|
+
if (e.key === 'Escape') {
|
|
293
|
+
closeNewRunModal();
|
|
294
|
+
document.getElementById('import-modal').classList.add('hidden');
|
|
295
|
+
}
|
|
296
|
+
});
|
|
297
|
+
|
|
298
|
+
// ---- Config Helpers ----
|
|
299
|
+
function getConfig() {
|
|
300
|
+
return {
|
|
301
|
+
dataset: document.getElementById('cfg-dataset').value,
|
|
302
|
+
batch_size: parseInt(document.getElementById('cfg-batch_size').value),
|
|
303
|
+
block_size: parseInt(document.getElementById('cfg-block_size').value),
|
|
304
|
+
n_layer: parseInt(document.getElementById('cfg-n_layer').value),
|
|
305
|
+
n_head: parseInt(document.getElementById('cfg-n_head').value),
|
|
306
|
+
n_embd: parseInt(document.getElementById('cfg-n_embd').value),
|
|
307
|
+
learning_rate: parseFloat(document.getElementById('cfg-learning_rate').value),
|
|
308
|
+
max_iters: parseInt(document.getElementById('cfg-max_iters').value),
|
|
309
|
+
eval_interval: parseInt(document.getElementById('cfg-eval_interval').value),
|
|
310
|
+
log_interval: parseInt(document.getElementById('cfg-log_interval').value),
|
|
311
|
+
dropout: parseFloat(document.getElementById('cfg-dropout').value),
|
|
312
|
+
gradient_accumulation_steps: parseInt(document.getElementById('cfg-gradient_accumulation_steps').value),
|
|
313
|
+
eval_iters: parseInt(document.getElementById('cfg-eval_iters').value),
|
|
314
|
+
weight_decay: parseFloat(document.getElementById('cfg-weight_decay').value),
|
|
315
|
+
grad_clip: parseFloat(document.getElementById('cfg-grad_clip').value),
|
|
316
|
+
warmup_iters: parseInt(document.getElementById('cfg-warmup_iters').value),
|
|
317
|
+
lr_decay_iters: parseInt(document.getElementById('cfg-lr_decay_iters').value),
|
|
318
|
+
min_lr: parseFloat(document.getElementById('cfg-min_lr').value),
|
|
319
|
+
bias: document.getElementById('cfg-bias').value === 'true',
|
|
320
|
+
always_save_checkpoint: document.getElementById('cfg-always_save_checkpoint').value === 'true'
|
|
321
|
+
};
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
// ---- Actions ----
|
|
325
|
+
async function startRun() {
|
|
326
|
+
const btn = document.getElementById('btn-start-run');
|
|
327
|
+
btn.disabled = true;
|
|
328
|
+
btn.textContent = 'Starting...';
|
|
329
|
+
|
|
330
|
+
try {
|
|
331
|
+
const res = await fetch('/train/start', {
|
|
332
|
+
method: 'POST',
|
|
333
|
+
headers: { 'Content-Type': 'application/json' },
|
|
334
|
+
body: JSON.stringify(getConfig())
|
|
335
|
+
});
|
|
336
|
+
const data = await res.json();
|
|
337
|
+
if (res.ok) {
|
|
338
|
+
closeNewRunModal();
|
|
339
|
+
// Navigate to the new run
|
|
340
|
+
window.location = '/runs/' + data.run_id;
|
|
341
|
+
} else {
|
|
342
|
+
alert(data.error);
|
|
343
|
+
}
|
|
344
|
+
} finally {
|
|
345
|
+
btn.disabled = false;
|
|
346
|
+
btn.textContent = 'Start Training';
|
|
347
|
+
}
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
async function stopRun() {
|
|
351
|
+
await fetch('/train/stop', {
|
|
352
|
+
method: 'POST',
|
|
353
|
+
headers: { 'Content-Type': 'application/json' },
|
|
354
|
+
body: '{}'
|
|
355
|
+
});
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
async function resumeRun(runId, dataset, checkpointPath) {
|
|
359
|
+
// Open modal pre-filled, but just start immediately with the same config
|
|
360
|
+
const run = runsData.find(r => r.id === runId);
|
|
361
|
+
const config = run ? JSON.parse(run.config_json || '{}') : {};
|
|
362
|
+
config.checkpoint_path = checkpointPath;
|
|
363
|
+
|
|
364
|
+
const res = await fetch('/train/resume', {
|
|
365
|
+
method: 'POST',
|
|
366
|
+
headers: { 'Content-Type': 'application/json' },
|
|
367
|
+
body: JSON.stringify(config)
|
|
368
|
+
});
|
|
369
|
+
const data = await res.json();
|
|
370
|
+
if (res.ok) {
|
|
371
|
+
window.location = '/runs/' + data.run_id;
|
|
372
|
+
} else {
|
|
373
|
+
alert(data.error);
|
|
374
|
+
}
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
// ---- Datasets ----
|
|
378
|
+
async function loadDatasets() {
|
|
379
|
+
try {
|
|
380
|
+
const res = await fetch('/datasets');
|
|
381
|
+
const data = await res.json();
|
|
382
|
+
const select = document.getElementById('cfg-dataset');
|
|
383
|
+
select.innerHTML = '';
|
|
384
|
+
data.datasets.forEach(d => {
|
|
385
|
+
const opt = document.createElement('option');
|
|
386
|
+
opt.value = d;
|
|
387
|
+
opt.textContent = d;
|
|
388
|
+
select.appendChild(opt);
|
|
389
|
+
});
|
|
390
|
+
} catch (e) {}
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
// ---- Import ----
|
|
394
|
+
document.getElementById('import-file')?.addEventListener('change', function() {
|
|
395
|
+
const nameInput = document.getElementById('import-name');
|
|
396
|
+
if (!nameInput.value && this.files[0]) {
|
|
397
|
+
nameInput.value = this.files[0].name.replace(/\.[^/.]+$/, '').replace(/[^a-zA-Z0-9_-]/g, '_');
|
|
398
|
+
}
|
|
399
|
+
});
|
|
400
|
+
|
|
401
|
+
async function uploadDataset() {
|
|
402
|
+
const btn = document.getElementById('btn-upload');
|
|
403
|
+
const status = document.getElementById('import-status');
|
|
404
|
+
btn.disabled = true;
|
|
405
|
+
status.textContent = 'Uploading and preparing...';
|
|
406
|
+
status.className = 'text-xs text-yellow-400';
|
|
407
|
+
|
|
408
|
+
const form = document.getElementById('upload-form');
|
|
409
|
+
const formData = new FormData(form);
|
|
410
|
+
try {
|
|
411
|
+
const res = await fetch('/datasets/upload', { method: 'POST', body: formData });
|
|
412
|
+
const data = await res.json();
|
|
413
|
+
if (data.ok) {
|
|
414
|
+
status.textContent = 'Dataset "' + data.dataset + '" ready!';
|
|
415
|
+
status.className = 'text-xs text-green-400';
|
|
416
|
+
setTimeout(() => {
|
|
417
|
+
document.getElementById('import-modal').classList.add('hidden');
|
|
418
|
+
status.textContent = '';
|
|
419
|
+
}, 1500);
|
|
420
|
+
} else {
|
|
421
|
+
status.textContent = 'Error: ' + data.error;
|
|
422
|
+
status.className = 'text-xs text-red-400';
|
|
423
|
+
}
|
|
424
|
+
} catch (err) {
|
|
425
|
+
status.textContent = 'Upload failed: ' + err.message;
|
|
426
|
+
status.className = 'text-xs text-red-400';
|
|
427
|
+
} finally {
|
|
428
|
+
btn.disabled = false;
|
|
429
|
+
}
|
|
430
|
+
}
|
|
431
|
+
|
|
432
|
+
// ---- Init ----
|
|
433
|
+
loadRuns();
|
|
434
|
+
</script>
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "fileutils"
|
|
4
|
+
|
|
5
|
+
module NanoGPT
|
|
6
|
+
module Web
|
|
7
|
+
# Training loop with web dashboard hooks
|
|
8
|
+
# Composes with the same model/data_loader/config as Trainer
|
|
9
|
+
# but adds stop flag checking, metric recording, and SSE broadcasting
|
|
10
|
+
class WebTrainer
|
|
11
|
+
attr_reader :model, :config, :iter_num, :best_val_loss
|
|
12
|
+
|
|
13
|
+
def initialize(model:, data_loader:, config:, training_state:, metrics_store:, sse_notifier:)
|
|
14
|
+
@model = model
|
|
15
|
+
@data_loader = data_loader
|
|
16
|
+
@config = config.is_a?(Hash) ? config.transform_keys(&:to_sym) : config.to_h
|
|
17
|
+
@training_state = training_state
|
|
18
|
+
@metrics_store = metrics_store
|
|
19
|
+
@sse_notifier = sse_notifier
|
|
20
|
+
|
|
21
|
+
@iter_num = 0
|
|
22
|
+
@best_val_loss = Float::INFINITY
|
|
23
|
+
|
|
24
|
+
setup_optimizer
|
|
25
|
+
setup_lr_scheduler
|
|
26
|
+
end
|
|
27
|
+
|
|
28
|
+
def train(run_id)
|
|
29
|
+
@run_id = run_id
|
|
30
|
+
@training_state.update(status: "running", run_id: run_id, max_iters: @config[:max_iters])
|
|
31
|
+
|
|
32
|
+
puts "Starting training... max_iters=#{@config[:max_iters]} eval_interval=#{@config[:eval_interval]} eval_iters=#{@config[:eval_iters]} device=#{@config[:device]}"
|
|
33
|
+
|
|
34
|
+
@model.train
|
|
35
|
+
x, y = @data_loader.get_batch(:train)
|
|
36
|
+
t0 = Time.now
|
|
37
|
+
|
|
38
|
+
while @iter_num <= @config[:max_iters]
|
|
39
|
+
# Check stop flag
|
|
40
|
+
if @training_state.stop_requested?
|
|
41
|
+
save_checkpoint("final")
|
|
42
|
+
@training_state.update(status: "stopped")
|
|
43
|
+
@metrics_store.update_run(run_id, status: "stopped", stopped_at: Time.now.iso8601,
|
|
44
|
+
checkpoint_path: checkpoint_path("final"))
|
|
45
|
+
@sse_notifier.broadcast(type: "status", data: @training_state.to_h)
|
|
46
|
+
return
|
|
47
|
+
end
|
|
48
|
+
|
|
49
|
+
lr = @config[:decay_lr] ? @lr_scheduler.step(@optimizer, @iter_num) : @config[:learning_rate]
|
|
50
|
+
|
|
51
|
+
if @iter_num % @config[:eval_interval] == 0
|
|
52
|
+
puts "iter #{@iter_num}: running eval..."
|
|
53
|
+
losses = estimate_loss
|
|
54
|
+
val_loss = losses[:val]
|
|
55
|
+
train_loss = losses[:train]
|
|
56
|
+
puts "iter #{@iter_num}: eval done - train_loss=#{format('%.4f', train_loss)} val_loss=#{format('%.4f', val_loss)}"
|
|
57
|
+
|
|
58
|
+
@training_state.update(best_val_loss: [@best_val_loss, val_loss].min)
|
|
59
|
+
@metrics_store.record_metrics(run_id, @iter_num, { val_loss: val_loss, eval_train_loss: train_loss })
|
|
60
|
+
@sse_notifier.broadcast(type: "eval", data: {
|
|
61
|
+
iteration: @iter_num, val_loss: val_loss, train_loss: train_loss
|
|
62
|
+
})
|
|
63
|
+
|
|
64
|
+
if val_loss < @best_val_loss || @config[:always_save_checkpoint]
|
|
65
|
+
@best_val_loss = [val_loss, @best_val_loss].min
|
|
66
|
+
save_checkpoint("best") if @iter_num > 0
|
|
67
|
+
@metrics_store.update_run(run_id,
|
|
68
|
+
best_val_loss: @best_val_loss,
|
|
69
|
+
checkpoint_path: checkpoint_path("best")
|
|
70
|
+
)
|
|
71
|
+
end
|
|
72
|
+
end
|
|
73
|
+
|
|
74
|
+
break if @iter_num == 0 && @config[:eval_only]
|
|
75
|
+
|
|
76
|
+
@optimizer.zero_grad
|
|
77
|
+
|
|
78
|
+
accumulated_loss = 0.0
|
|
79
|
+
@config[:gradient_accumulation_steps].times do
|
|
80
|
+
_logits, loss = @model.call(x, targets: y)
|
|
81
|
+
loss = loss / @config[:gradient_accumulation_steps]
|
|
82
|
+
accumulated_loss += loss.item
|
|
83
|
+
loss.backward
|
|
84
|
+
x, y = @data_loader.get_batch(:train)
|
|
85
|
+
end
|
|
86
|
+
|
|
87
|
+
clip_grad_norm(@model.parameters, @config[:grad_clip]) if @config[:grad_clip] > 0.0
|
|
88
|
+
@optimizer.step
|
|
89
|
+
|
|
90
|
+
t1 = Time.now
|
|
91
|
+
dt = t1 - t0
|
|
92
|
+
t0 = t1
|
|
93
|
+
|
|
94
|
+
if @iter_num % @config[:log_interval] == 0
|
|
95
|
+
@training_state.update(
|
|
96
|
+
current_iter: @iter_num,
|
|
97
|
+
current_loss: accumulated_loss
|
|
98
|
+
)
|
|
99
|
+
@metrics_store.record_metrics(run_id, @iter_num, { train_loss: accumulated_loss, lr: lr, iter_time_ms: dt * 1000 })
|
|
100
|
+
@metrics_store.update_run(run_id, current_iter: @iter_num)
|
|
101
|
+
@sse_notifier.broadcast(type: "train", data: {
|
|
102
|
+
iteration: @iter_num, loss: accumulated_loss, lr: lr, time_ms: (dt * 1000).round(2)
|
|
103
|
+
})
|
|
104
|
+
end
|
|
105
|
+
|
|
106
|
+
@iter_num += 1
|
|
107
|
+
end
|
|
108
|
+
|
|
109
|
+
save_checkpoint("final")
|
|
110
|
+
@training_state.update(status: "completed")
|
|
111
|
+
@metrics_store.update_run(run_id, status: "completed", stopped_at: Time.now.iso8601,
|
|
112
|
+
checkpoint_path: checkpoint_path("best"))
|
|
113
|
+
@sse_notifier.broadcast(type: "status", data: @training_state.to_h)
|
|
114
|
+
end
|
|
115
|
+
|
|
116
|
+
def load_checkpoint(path)
|
|
117
|
+
checkpoint = Torch.load(path)
|
|
118
|
+
@model.load_state_dict(checkpoint["model"])
|
|
119
|
+
@iter_num = checkpoint["iter_num"]
|
|
120
|
+
@best_val_loss = checkpoint["best_val_loss"]
|
|
121
|
+
setup_optimizer
|
|
122
|
+
checkpoint
|
|
123
|
+
end
|
|
124
|
+
|
|
125
|
+
private
|
|
126
|
+
|
|
127
|
+
def estimate_loss
|
|
128
|
+
@model.eval
|
|
129
|
+
out = {}
|
|
130
|
+
|
|
131
|
+
[:train, :val].each do |split|
|
|
132
|
+
losses = []
|
|
133
|
+
@config[:eval_iters].times do |i|
|
|
134
|
+
break if @training_state.stop_requested?
|
|
135
|
+
x, y = @data_loader.get_batch(split)
|
|
136
|
+
Torch.no_grad do
|
|
137
|
+
_logits, loss = @model.call(x, targets: y)
|
|
138
|
+
losses << loss.item
|
|
139
|
+
end
|
|
140
|
+
end
|
|
141
|
+
out[split] = losses.empty? ? Float::INFINITY : losses.sum / losses.size
|
|
142
|
+
end
|
|
143
|
+
|
|
144
|
+
@model.train
|
|
145
|
+
out
|
|
146
|
+
end
|
|
147
|
+
|
|
148
|
+
def save_checkpoint(suffix = "best")
|
|
149
|
+
FileUtils.mkdir_p(@config[:out_dir])
|
|
150
|
+
path = checkpoint_path(suffix)
|
|
151
|
+
checkpoint = {
|
|
152
|
+
"model" => @model.state_dict,
|
|
153
|
+
"model_args" => stringify_keys(@model.config.to_h),
|
|
154
|
+
"iter_num" => @iter_num,
|
|
155
|
+
"best_val_loss" => @best_val_loss,
|
|
156
|
+
"config" => stringify_keys(@config)
|
|
157
|
+
}
|
|
158
|
+
Torch.save(checkpoint, path)
|
|
159
|
+
@metrics_store.record_checkpoint(@run_id,
|
|
160
|
+
path: path, suffix: suffix, iteration: @iter_num, val_loss: @best_val_loss)
|
|
161
|
+
end
|
|
162
|
+
|
|
163
|
+
def checkpoint_path(suffix = "best")
|
|
164
|
+
File.join(@config[:out_dir], "ckpt_run#{@run_id}_#{suffix}.pt")
|
|
165
|
+
end
|
|
166
|
+
|
|
167
|
+
def stringify_keys(hash)
|
|
168
|
+
hash.transform_keys(&:to_s).transform_values do |v|
|
|
169
|
+
v.is_a?(Hash) ? stringify_keys(v) : v
|
|
170
|
+
end
|
|
171
|
+
end
|
|
172
|
+
|
|
173
|
+
def clip_grad_norm(parameters, max_norm)
|
|
174
|
+
total_norm = 0.0
|
|
175
|
+
parameters.each do |p|
|
|
176
|
+
next unless p.grad
|
|
177
|
+
param_norm = p.grad.data.norm(2).item
|
|
178
|
+
total_norm += param_norm**2
|
|
179
|
+
end
|
|
180
|
+
total_norm = Math.sqrt(total_norm)
|
|
181
|
+
|
|
182
|
+
clip_coef = max_norm / (total_norm + 1e-6)
|
|
183
|
+
if clip_coef < 1
|
|
184
|
+
parameters.each do |p|
|
|
185
|
+
next unless p.grad
|
|
186
|
+
p.grad.data.mul!(clip_coef)
|
|
187
|
+
end
|
|
188
|
+
end
|
|
189
|
+
end
|
|
190
|
+
|
|
191
|
+
def setup_optimizer
|
|
192
|
+
@optimizer = @model.configure_optimizers(
|
|
193
|
+
weight_decay: @config[:weight_decay] || 1e-1,
|
|
194
|
+
learning_rate: @config[:learning_rate],
|
|
195
|
+
betas: [@config[:beta1] || 0.9, @config[:beta2] || 0.99],
|
|
196
|
+
device_type: @config[:device]
|
|
197
|
+
)
|
|
198
|
+
end
|
|
199
|
+
|
|
200
|
+
def setup_lr_scheduler
|
|
201
|
+
@lr_scheduler = LRScheduler.new(
|
|
202
|
+
learning_rate: @config[:learning_rate],
|
|
203
|
+
min_lr: @config[:min_lr] || 1e-4,
|
|
204
|
+
warmup_iters: @config[:warmup_iters] || 100,
|
|
205
|
+
lr_decay_iters: @config[:lr_decay_iters] || 5000
|
|
206
|
+
)
|
|
207
|
+
end
|
|
208
|
+
end
|
|
209
|
+
end
|
|
210
|
+
end
|
data/lib/nano_gpt/web.rb
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "bpe_textfile_preparer"
|
|
4
|
+
require_relative "web/metrics_store"
|
|
5
|
+
require_relative "web/training_state"
|
|
6
|
+
require_relative "web/sse_notifier"
|
|
7
|
+
require_relative "web/web_trainer"
|
|
8
|
+
require_relative "web/training_worker"
|
|
9
|
+
require_relative "web/server"
|
data/lib/nano_gpt.rb
CHANGED
data/nanogpt.gemspec
CHANGED
|
@@ -32,6 +32,10 @@ Gem::Specification.new do |spec|
|
|
|
32
32
|
spec.add_dependency "torch-rb", "~> 0.14"
|
|
33
33
|
spec.add_dependency "numo-narray", "~> 0.9"
|
|
34
34
|
spec.add_dependency "tiktoken_ruby", "~> 0.0"
|
|
35
|
+
spec.add_dependency "sinatra", "~> 4.0"
|
|
36
|
+
spec.add_dependency "rackup", "~> 2.0"
|
|
37
|
+
spec.add_dependency "webrick", "~> 1.8"
|
|
38
|
+
spec.add_dependency "sqlite3", "~> 2.0"
|
|
35
39
|
|
|
36
40
|
spec.add_development_dependency "rspec", "~> 3.12"
|
|
37
41
|
end
|