nanogpt 0.2.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,434 @@
1
<!-- Header -->
<div class="flex items-center justify-between">
  <h1 class="text-lg font-semibold">Runs</h1>
  <div class="flex gap-2">
    <button type="button" onclick="toggleImportModal()" class="px-3 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg text-sm transition">
      Import Dataset
    </button>
    <button type="button" onclick="openNewRunModal()" class="px-4 py-2 bg-green-600 hover:bg-green-700 rounded-lg font-medium text-sm transition">
      + New Run
    </button>
  </div>
</div>

<!-- Runs List: replaced by renderRuns() once /api/runs responds -->
<div id="runs-list" class="space-y-2" aria-live="polite">
  <div class="text-center text-gray-500 py-8">Loading runs...</div>
</div>

<!-- New Run Modal: shown by openNewRunModal(); getConfig() reads the cfg-* fields -->
<div id="new-run-modal" class="fixed inset-0 bg-black/60 z-50 hidden flex items-start justify-center pt-16 px-4"
     role="dialog" aria-modal="true" aria-labelledby="new-run-title">
  <div class="bg-gray-800 rounded-xl shadow-2xl w-full max-w-2xl max-h-[80vh] overflow-y-auto">
    <div class="flex items-center justify-between px-6 py-4 border-b border-gray-700">
      <h2 id="new-run-title" class="text-base font-semibold">New Training Run</h2>
      <button type="button" onclick="closeNewRunModal()" aria-label="Close" class="text-gray-400 hover:text-white text-xl leading-none">&times;</button>
    </div>

    <div class="px-6 py-5 space-y-5">
      <!-- Dataset (options filled by loadDatasets()) -->
      <div>
        <label for="cfg-dataset" class="block text-xs text-gray-400 mb-1.5">Dataset</label>
        <select id="cfg-dataset" class="w-full bg-gray-700 rounded px-3 py-2 text-sm"></select>
      </div>

      <!-- Basic -->
      <div>
        <div class="text-xs text-gray-400 uppercase tracking-wide mb-2">Training</div>
        <div class="grid grid-cols-3 gap-3">
          <div>
            <label for="cfg-max_iters" class="block text-xs text-gray-400 mb-1">Max Iters</label>
            <input id="cfg-max_iters" type="number" value="5000" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-learning_rate" class="block text-xs text-gray-400 mb-1">Learning Rate</label>
            <input id="cfg-learning_rate" type="number" value="0.001" step="0.0001" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-batch_size" class="block text-xs text-gray-400 mb-1">Batch Size</label>
            <input id="cfg-batch_size" type="number" value="64" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
        </div>
      </div>

      <!-- Model -->
      <div>
        <div class="text-xs text-gray-400 uppercase tracking-wide mb-2">Model</div>
        <div class="grid grid-cols-3 gap-3">
          <div>
            <label for="cfg-block_size" class="block text-xs text-gray-400 mb-1">Block Size</label>
            <input id="cfg-block_size" type="number" value="256" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-n_layer" class="block text-xs text-gray-400 mb-1">Layers</label>
            <input id="cfg-n_layer" type="number" value="6" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-n_head" class="block text-xs text-gray-400 mb-1">Heads</label>
            <input id="cfg-n_head" type="number" value="6" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-n_embd" class="block text-xs text-gray-400 mb-1">Embedding Dim</label>
            <input id="cfg-n_embd" type="number" value="384" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-dropout" class="block text-xs text-gray-400 mb-1">Dropout</label>
            <input id="cfg-dropout" type="number" value="0.2" step="0.05" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
        </div>
      </div>

      <!-- Advanced (collapsed) -->
      <details>
        <summary class="text-xs text-gray-400 uppercase tracking-wide cursor-pointer hover:text-gray-300">Advanced</summary>
        <div class="grid grid-cols-3 gap-3 mt-3">
          <div>
            <label for="cfg-eval_interval" class="block text-xs text-gray-400 mb-1">Eval Interval</label>
            <input id="cfg-eval_interval" type="number" value="250" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-log_interval" class="block text-xs text-gray-400 mb-1">Log Interval</label>
            <input id="cfg-log_interval" type="number" value="10" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-eval_iters" class="block text-xs text-gray-400 mb-1">Eval Iters</label>
            <input id="cfg-eval_iters" type="number" value="20" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-gradient_accumulation_steps" class="block text-xs text-gray-400 mb-1">Grad Accumulation</label>
            <input id="cfg-gradient_accumulation_steps" type="number" value="1" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-weight_decay" class="block text-xs text-gray-400 mb-1">Weight Decay</label>
            <input id="cfg-weight_decay" type="number" value="0.1" step="0.01" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-grad_clip" class="block text-xs text-gray-400 mb-1">Grad Clip</label>
            <input id="cfg-grad_clip" type="number" value="1.0" step="0.1" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-warmup_iters" class="block text-xs text-gray-400 mb-1">Warmup Iters</label>
            <input id="cfg-warmup_iters" type="number" value="100" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-lr_decay_iters" class="block text-xs text-gray-400 mb-1">LR Decay Iters</label>
            <input id="cfg-lr_decay_iters" type="number" value="5000" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-min_lr" class="block text-xs text-gray-400 mb-1">Min LR</label>
            <input id="cfg-min_lr" type="number" value="0.0001" step="0.0001" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
          </div>
          <div>
            <label for="cfg-bias" class="block text-xs text-gray-400 mb-1">Bias</label>
            <select id="cfg-bias" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
              <option value="false" selected>false</option>
              <option value="true">true</option>
            </select>
          </div>
          <div>
            <label for="cfg-always_save_checkpoint" class="block text-xs text-gray-400 mb-1">Always Save Ckpt</label>
            <select id="cfg-always_save_checkpoint" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
              <option value="false" selected>false</option>
              <option value="true">true</option>
            </select>
          </div>
        </div>
      </details>
    </div>

    <div class="px-6 py-4 border-t border-gray-700 flex justify-end gap-3">
      <button type="button" onclick="closeNewRunModal()" class="px-4 py-2 bg-gray-700 hover:bg-gray-600 rounded-lg text-sm transition">Cancel</button>
      <button type="button" id="btn-start-run" onclick="startRun()" class="px-5 py-2 bg-green-600 hover:bg-green-700 rounded-lg font-medium text-sm transition">
        Start Training
      </button>
    </div>
  </div>
</div>

<!-- Import Dataset Modal: toggled by toggleImportModal(); submitted by uploadDataset() -->
<div id="import-modal" class="fixed inset-0 bg-black/60 z-50 hidden flex items-start justify-center pt-16 px-4"
     role="dialog" aria-modal="true" aria-labelledby="import-title">
  <div class="bg-gray-800 rounded-xl shadow-2xl w-full max-w-lg">
    <div class="flex items-center justify-between px-6 py-4 border-b border-gray-700">
      <h2 id="import-title" class="text-base font-semibold">Import Dataset</h2>
      <button type="button" onclick="toggleImportModal()" aria-label="Close" class="text-gray-400 hover:text-white text-xl leading-none">&times;</button>
    </div>
    <form id="upload-form" enctype="multipart/form-data" class="px-6 py-5 space-y-4">
      <div>
        <label for="import-file" class="block text-xs text-gray-400 mb-1">Text File</label>
        <input id="import-file" name="file" type="file" accept=".txt,.text" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm file:mr-3 file:py-1 file:px-3 file:rounded file:border-0 file:text-sm file:bg-gray-600 file:text-gray-200">
      </div>
      <div>
        <label for="import-name" class="block text-xs text-gray-400 mb-1">Dataset Name</label>
        <input id="import-name" name="name" type="text" placeholder="auto from filename" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
      </div>
      <div>
        <!-- Group caption for the radios below (a <label> can only label one control) -->
        <span id="import-tokenizer-label" class="block text-xs text-gray-400 mb-1">Tokenizer</span>
        <div class="flex gap-4 mt-1" role="radiogroup" aria-labelledby="import-tokenizer-label">
          <label class="flex items-center gap-1.5 text-sm"><input type="radio" name="tokenizer" value="char" checked class="accent-green-500"> Char</label>
          <label class="flex items-center gap-1.5 text-sm"><input type="radio" name="tokenizer" value="bpe" class="accent-green-500"> BPE (GPT-2)</label>
        </div>
      </div>
      <div>
        <label for="import-val-ratio" class="block text-xs text-gray-400 mb-1">Val Split</label>
        <input id="import-val-ratio" name="val_ratio" type="number" value="0.1" step="0.05" min="0" max="0.5" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
      </div>
    </form>
    <div class="px-6 py-4 border-t border-gray-700 flex items-center justify-between">
      <span id="import-status" class="text-xs text-gray-400" aria-live="polite"></span>
      <button type="button" id="btn-upload" onclick="uploadDataset()" class="px-4 py-2 bg-blue-600 hover:bg-blue-700 rounded-lg font-medium text-sm transition">
        Upload &amp; Prepare
      </button>
    </div>
  </div>
</div>
183
+
184
+ <script>
185
// Runs as last fetched from /api/runs.
let runsData = [];
// run_id carried by the most recent 'training-status' event (null until one arrives).
let currentTrainingRunId = null;

// ---- Load Runs ----
// Fetches all runs and re-renders the list. Any failure (network, non-2xx status,
// unparsable body) is logged and replaces the list with an inline error message;
// runsData is only replaced after a successful response.
async function loadRuns() {
  try {
    const res = await fetch('/api/runs');
    if (!res.ok) throw new Error(`HTTP ${res.status}`);
    const data = await res.json();
    runsData = data.runs || [];
    renderRuns();
  } catch (e) {
    console.error('Failed to load runs', e);
    document.getElementById('runs-list').innerHTML =
      '<div class="text-center text-red-400 py-8">Failed to load runs</div>';
  }
}
200
+
201
// Escapes a value for safe use as HTML text or inside a double-quoted attribute.
function escapeHtml(value) {
  const entities = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
  return String(value ?? '').replace(/[&<>"']/g, (ch) => entities[ch]);
}

// Encodes a value as a JS literal that is safe to embed in an inline event-handler attribute:
// JSON.stringify yields a valid JS literal, escapeHtml protects the surrounding attribute.
function jsAttrArg(value) {
  return escapeHtml(JSON.stringify(value ?? ''));
}

// Parses a run's config_json, returning {} when it is missing or malformed so one bad
// row cannot break rendering of the whole list.
function parseRunConfig(run) {
  try {
    return JSON.parse(run.config_json || '{}') || {};
  } catch (e) {
    console.warn('Invalid config_json for run', run.id, e);
    return {};
  }
}

// Renders runsData into #runs-list: one clickable row per run with Stop / Resume / View actions.
// Run fields come from the server (e.g. user-chosen dataset names), so every interpolated
// value is escaped for its context.
function renderRuns() {
  const container = document.getElementById('runs-list');
  if (runsData.length === 0) {
    container.innerHTML = '<div class="text-center text-gray-500 py-12">No runs yet. Click "+ New Run" to start training.</div>';
    return;
  }

  container.innerHTML = runsData.map(r => {
    const bestVal = parseFloat(r.best_val_loss);
    // Accept 0 as a real loss; only missing/non-finite values show as '--'.
    const valLoss = Number.isFinite(bestVal) ? bestVal.toFixed(4) : '--';
    const iters = r.current_iter || 0;
    const maxIters = parseRunConfig(r).max_iters || '?';
    const href = `/runs/${encodeURIComponent(r.id)}`;

    return `
      <div class="bg-gray-800 rounded-lg p-4 hover:bg-gray-750 transition cursor-pointer flex items-center justify-between group"
           onclick="window.location=${jsAttrArg(href)}">
        <div class="flex items-center gap-4 min-w-0">
          <div class="text-sm font-mono text-gray-500 w-8">#${escapeHtml(r.id)}</div>
          <div class="min-w-0">
            <div class="flex items-center gap-2">
              <span class="font-medium text-sm">${escapeHtml(r.dataset)}</span>
              ${statusBadge(r.status)}
            </div>
            <div class="text-xs text-gray-500 mt-0.5">
              ${formatTime(r.started_at)}
              &middot; ${escapeHtml(iters)}/${escapeHtml(maxIters)} iters
              &middot; val_loss: ${valLoss}
            </div>
          </div>
        </div>
        <div class="flex items-center gap-2" onclick="event.stopPropagation()">
          ${r.status === 'running' ? `
            <button onclick="stopRun()" class="px-3 py-1.5 bg-red-600/80 hover:bg-red-600 rounded text-xs font-medium transition">Stop</button>
          ` : ''}
          ${(r.status === 'stopped' || r.status === 'completed') && r.checkpoint_path ? `
            <button onclick="resumeRun(${jsAttrArg(r.id)}, ${jsAttrArg(r.dataset)}, ${jsAttrArg(r.checkpoint_path)})" class="px-3 py-1.5 bg-yellow-600/80 hover:bg-yellow-600 rounded text-xs font-medium transition">Resume</button>
          ` : ''}
          <a href="${escapeHtml(href)}" class="px-3 py-1.5 bg-gray-700 hover:bg-gray-600 rounded text-xs font-medium transition">View</a>
        </div>
      </div>
    `;
  }).join('');
}
245
+
246
// ---- Training Status Listener ----
// Status from the previous 'training-status' event; used to detect the moment a run
// leaves the 'running' state.
let lastTrainingStatus = null;

// Keeps the list in sync with live training events. When the matching run's status
// changes, it is patched in place and re-rendered. The list is refetched once a run
// transitions out of 'running' to pick up its final values.
window.addEventListener('training-status', (e) => {
  const state = e.detail || {};
  const wasRunning = lastTrainingStatus === 'running';
  lastTrainingStatus = state.status;
  currentTrainingRunId = state.run_id;

  if (runsData.length === 0) return;

  const activeRun = runsData.find(r => r.id === state.run_id);
  if (activeRun && activeRun.status !== state.status) {
    activeRun.status = state.status;
    activeRun.current_iter = state.current_iter;
    activeRun.best_val_loss = state.best_val_loss;
    renderRuns();
  }

  // Reload only on the running -> not-running transition. Checking merely "not running"
  // would refetch the list on every subsequent event.
  if (wasRunning && state.status !== 'running') {
    loadRuns();
  }
});
267
+
268
// ---- Modal Controls ----
// Modal visibility is driven entirely by Tailwind's `hidden` utility class.
function modalClassList(modalId) {
  return document.getElementById(modalId).classList;
}

// Shows the "New Training Run" modal, kicking off a dataset list refresh first.
function openNewRunModal() {
  loadDatasets();
  modalClassList('new-run-modal').remove('hidden');
}

// Hides the "New Training Run" modal (no-op if it is already hidden).
function closeNewRunModal() {
  modalClassList('new-run-modal').add('hidden');
}

// Flips the "Import Dataset" modal between shown and hidden.
function toggleImportModal() {
  modalClassList('import-modal').toggle('hidden');
}
281
+
282
// Dismiss a modal when the click lands on its dimmed overlay itself (not inside the dialog).
[
  ['new-run-modal', () => closeNewRunModal()],
  ['import-modal', () => toggleImportModal()]
].forEach(([overlayId, dismiss]) => {
  const overlay = document.getElementById(overlayId);
  if (!overlay) return;
  overlay.addEventListener('click', (event) => {
    if (event.target === event.currentTarget) dismiss();
  });
});

// Escape closes both modals.
document.addEventListener('keydown', (event) => {
  if (event.key !== 'Escape') return;
  closeNewRunModal();
  document.getElementById('import-modal').classList.add('hidden');
});
297
+
298
// ---- Config Helpers ----
// Collects the New Run form into the JSON payload for /train/start.
// A blank or unparsable numeric field falls back to the input's HTML default (its
// `value` attribute). Without this, NaN would be serialized as null and sent to the server.
function getConfig() {
  const field = (key) => document.getElementById('cfg-' + key);
  const numberFrom = (key, parse) => {
    const input = field(key);
    const parsed = parse(input.value);
    return Number.isFinite(parsed) ? parsed : parse(input.defaultValue);
  };
  const int = (key) => numberFrom(key, (text) => parseInt(text, 10));
  const float = (key) => numberFrom(key, parseFloat);
  const flag = (key) => field(key).value === 'true';

  return {
    dataset: field('dataset').value,
    batch_size: int('batch_size'),
    block_size: int('block_size'),
    n_layer: int('n_layer'),
    n_head: int('n_head'),
    n_embd: int('n_embd'),
    learning_rate: float('learning_rate'),
    max_iters: int('max_iters'),
    eval_interval: int('eval_interval'),
    log_interval: int('log_interval'),
    dropout: float('dropout'),
    gradient_accumulation_steps: int('gradient_accumulation_steps'),
    eval_iters: int('eval_iters'),
    weight_decay: float('weight_decay'),
    grad_clip: float('grad_clip'),
    warmup_iters: int('warmup_iters'),
    lr_decay_iters: int('lr_decay_iters'),
    min_lr: float('min_lr'),
    bias: flag('bias'),
    always_save_checkpoint: flag('always_save_checkpoint')
  };
}
323
+
324
// ---- Actions ----
// Submits the New Run form to POST /train/start and navigates to the new run's page.
// Server errors (the JSON `error` field) and network/parse failures are shown via alert.
// The button is disabled while the request is in flight and always re-enabled afterwards.
async function startRun() {
  const btn = document.getElementById('btn-start-run');
  btn.disabled = true;
  btn.textContent = 'Starting...';

  try {
    const res = await fetch('/train/start', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(getConfig())
    });
    const data = await res.json();
    if (res.ok) {
      closeNewRunModal();
      // Navigate to the new run
      window.location = '/runs/' + data.run_id;
    } else {
      alert(data.error || `Failed to start run (HTTP ${res.status})`);
    }
  } catch (err) {
    alert('Failed to start run: ' + err.message);
  } finally {
    btn.disabled = false;
    btn.textContent = 'Start Training';
  }
}
349
+
350
// Asks the server to stop the current run (POST /train/stop).
// Only failures are surfaced here: a non-2xx reply or a network error triggers an alert.
async function stopRun() {
  try {
    const res = await fetch('/train/stop', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: '{}'
    });
    if (!res.ok) {
      // The error body may not be JSON; fall back to the HTTP status.
      const data = await res.json().catch(() => ({}));
      alert(data.error || `Failed to stop run (HTTP ${res.status})`);
    }
  } catch (err) {
    alert('Failed to stop run: ' + err.message);
  }
}
357
+
358
// Resumes a stopped/completed run without opening the modal. It re-submits the run's saved
// config plus `checkpointPath` to POST /train/resume, then navigates to the returned run.
// `dataset` is used only when the saved config does not record one.
async function resumeRun(runId, dataset, checkpointPath) {
  const run = runsData.find(r => r.id === runId);
  let config = {};
  try {
    config = (run && JSON.parse(run.config_json || '{}')) || {};
  } catch (e) {
    console.warn('Invalid config_json for run', runId, e);
  }
  if (!config.dataset && dataset) config.dataset = dataset;
  config.checkpoint_path = checkpointPath;

  try {
    const res = await fetch('/train/resume', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(config)
    });
    const data = await res.json();
    if (res.ok) {
      window.location = '/runs/' + data.run_id;
    } else {
      alert(data.error || `Failed to resume run (HTTP ${res.status})`);
    }
  } catch (err) {
    alert('Failed to resume run: ' + err.message);
  }
}
376
+
377
// ---- Datasets ----
// Fills the New Run dataset <select> from GET /datasets.
// Fetch/parse failures are logged and leave the existing options in place.
async function loadDatasets() {
  try {
    const res = await fetch('/datasets');
    if (!res.ok) throw new Error(`HTTP ${res.status}`);
    const data = await res.json();
    const select = document.getElementById('cfg-dataset');
    select.innerHTML = '';
    (data.datasets || []).forEach(d => {
      const opt = document.createElement('option');
      opt.value = d;
      opt.textContent = d;
      select.appendChild(opt);
    });
  } catch (e) {
    console.error('Failed to load datasets', e);
  }
}
392
+
393
// ---- Import ----
// Once a file is picked, suggest a dataset name derived from it: the extension is stripped
// and characters outside [a-zA-Z0-9_-] become "_". Skipped if the user already typed a name.
const importFileInput = document.getElementById('import-file');
if (importFileInput) {
  importFileInput.addEventListener('change', (event) => {
    const nameInput = document.getElementById('import-name');
    if (nameInput.value) return;
    const picked = event.currentTarget.files[0];
    if (!picked) return;
    const stem = picked.name.replace(/\.[^/.]+$/, '');
    nameInput.value = stem.replace(/[^a-zA-Z0-9_-]/g, '_');
  });
}
400
+
401
// Uploads the import form (file, name, tokenizer, val split) to POST /datasets/upload and
// reports progress and errors in #import-status. On success the modal closes after 1.5 s.
// Refuses to submit when no file is selected.
async function uploadDataset() {
  const btn = document.getElementById('btn-upload');
  const status = document.getElementById('import-status');
  const setStatus = (text, colorClass) => {
    status.textContent = text;
    status.className = `text-xs ${colorClass}`;
  };

  const fileInput = document.getElementById('import-file');
  if (!fileInput || !fileInput.files || fileInput.files.length === 0) {
    setStatus('Please choose a text file first.', 'text-red-400');
    return;
  }

  btn.disabled = true;
  setStatus('Uploading and preparing...', 'text-yellow-400');

  try {
    const formData = new FormData(document.getElementById('upload-form'));
    const res = await fetch('/datasets/upload', { method: 'POST', body: formData });
    const data = await res.json();
    if (data.ok) {
      setStatus('Dataset "' + data.dataset + '" ready!', 'text-green-400');
      setTimeout(() => {
        document.getElementById('import-modal').classList.add('hidden');
        status.textContent = '';
      }, 1500);
    } else {
      setStatus('Error: ' + (data.error || `HTTP ${res.status}`), 'text-red-400');
    }
  } catch (err) {
    setStatus('Upload failed: ' + err.message, 'text-red-400');
  } finally {
    btn.disabled = false;
  }
}
431
+
432
// ---- Init ----
// Populate the runs list once when the script runs; loadRuns reports its own failures.
void loadRuns();
434
+ </script>
@@ -0,0 +1,210 @@
1
+ # frozen_string_literal: true
2
+
3
require "fileutils"
require "time"
4
+
5
module NanoGPT
  module Web
    # Training loop with web dashboard hooks.
    #
    # Composes with the same model/data_loader/config as Trainer, but adds stop-flag
    # checking, metric recording (metrics_store), and SSE broadcasting (sse_notifier).
    class WebTrainer
      attr_reader :model, :config, :iter_num, :best_val_loss

      # @param model [Object] model being trained; must respond to train, eval, call,
      #   parameters, state_dict, config, and configure_optimizers (all used below)
      # @param data_loader [Object] responds to get_batch(split) -> [x, y]
      # @param config [Hash, #to_h] hyperparameters; Hash keys are symbolized
      # @param training_state [Object] shared run state: update(**), stop_requested?, to_h
      # @param metrics_store [Object] persists runs, metrics, and checkpoints
      # @param sse_notifier [Object] responds to broadcast(type:, data:)
      def initialize(model:, data_loader:, config:, training_state:, metrics_store:, sse_notifier:)
        @model = model
        @data_loader = data_loader
        @config = config.is_a?(Hash) ? config.transform_keys(&:to_sym) : config.to_h
        @training_state = training_state
        @metrics_store = metrics_store
        @sse_notifier = sse_notifier

        @iter_num = 0
        @best_val_loss = Float::INFINITY
        # True once a "best" checkpoint file has actually been written by this trainer.
        @best_checkpoint_saved = false

        setup_optimizer
        setup_lr_scheduler
      end

      # Runs the training loop for +run_id+. It ends when max_iters is exceeded, after the
      # first eval when eval_only is set, or on a stop request from the training state.
      #
      # Both a stop and normal completion save a "final" checkpoint.
      # - A stop marks the run "stopped" and records the "final" checkpoint path.
      # - Completion marks the run "completed" and records the "best" checkpoint path when
      #   one was written, otherwise the "final" one.
      #
      # @param run_id [Object] id used for the metrics store and checkpoint file names
      # @return [nil] when stopped; otherwise the result of the final status broadcast
      def train(run_id)
        @run_id = run_id
        @training_state.update(status: "running", run_id: run_id, max_iters: @config[:max_iters])

        puts "Starting training... max_iters=#{@config[:max_iters]} eval_interval=#{@config[:eval_interval]} " \
             "eval_iters=#{@config[:eval_iters]} device=#{@config[:device]}"

        @model.train
        x, y = @data_loader.get_batch(:train)
        t0 = monotonic_now

        while @iter_num <= @config[:max_iters]
          if @training_state.stop_requested?
            finish_run("stopped", recorded_suffix: "final")
            return
          end

          lr = @config[:decay_lr] ? @lr_scheduler.step(@optimizer, @iter_num) : @config[:learning_rate]

          if (@iter_num % @config[:eval_interval]).zero?
            puts "iter #{@iter_num}: running eval..."
            losses = estimate_loss
            # estimate_loss bails out early on a stop request, leaving partial (possibly
            # infinite) averages. Loop back so the stop is handled before anything is recorded
            # or another optimizer step runs.
            next if @training_state.stop_requested?

            record_eval(losses)
          end

          break if @iter_num.zero? && @config[:eval_only]

          loss, x, y = train_step(x, y)

          t1 = monotonic_now
          dt = t1 - t0
          t0 = t1

          log_progress(lr, loss, dt) if (@iter_num % @config[:log_interval]).zero?

          @iter_num += 1
        end

        finish_run("completed", recorded_suffix: @best_checkpoint_saved ? "best" : "final")
      end

      # Restores model weights, iteration counter, and best validation loss from +path+.
      # The optimizer is rebuilt fresh; its state is not part of the checkpoint.
      #
      # @return [Hash] the loaded checkpoint
      def load_checkpoint(path)
        checkpoint = Torch.load(path)
        @model.load_state_dict(checkpoint["model"])
        @iter_num = checkpoint["iter_num"]
        @best_val_loss = checkpoint["best_val_loss"]
        setup_optimizer
        checkpoint
      end

      private

      # Seconds from a monotonic clock; used for per-iteration timing so wall-clock
      # adjustments cannot produce negative or skewed durations.
      def monotonic_now
        Process.clock_gettime(Process::CLOCK_MONOTONIC)
      end

      # Logs, records, and broadcasts one evaluation, then updates the best loss. When the loss
      # improved (or always_save_checkpoint is set), it writes the "best" checkpoint after
      # iteration 0 and records it on the run.
      def record_eval(losses)
        val_loss = losses[:val]
        train_loss = losses[:train]
        puts "iter #{@iter_num}: eval done - train_loss=#{format('%.4f', train_loss)} val_loss=#{format('%.4f', val_loss)}"

        @training_state.update(best_val_loss: [@best_val_loss, val_loss].min)
        @metrics_store.record_metrics(@run_id, @iter_num, { val_loss: val_loss, eval_train_loss: train_loss })
        @sse_notifier.broadcast(type: "eval", data: {
          iteration: @iter_num, val_loss: val_loss, train_loss: train_loss
        })

        return unless val_loss < @best_val_loss || @config[:always_save_checkpoint]

        @best_val_loss = [val_loss, @best_val_loss].min
        run_updates = { best_val_loss: @best_val_loss }
        # No checkpoint is written at iteration 0, so only advertise the "best" path once
        # the file really exists. Otherwise the dashboard would offer to resume from a
        # nonexistent file.
        if @iter_num.positive?
          save_checkpoint("best")
          @best_checkpoint_saved = true
          run_updates[:checkpoint_path] = checkpoint_path("best")
        end
        @metrics_store.update_run(@run_id, **run_updates)
      end

      # One optimizer step with gradient accumulation over gradient_accumulation_steps batches.
      #
      # @return [Array] [accumulated_loss (Float), next x, next y]
      def train_step(x, y)
        @optimizer.zero_grad

        micro_steps = @config[:gradient_accumulation_steps]
        accumulated_loss = 0.0
        micro_steps.times do
          _logits, loss = @model.call(x, targets: y)
          loss = loss / micro_steps
          accumulated_loss += loss.item
          loss.backward
          # Fetch the batch for the next micro-step (or the next iteration).
          x, y = @data_loader.get_batch(:train)
        end

        clip_grad_norm(@model.parameters, @config[:grad_clip]) if @config[:grad_clip] > 0.0
        @optimizer.step

        [accumulated_loss, x, y]
      end

      # Publishes the latest training loss, learning rate, and iteration time (dt seconds).
      def log_progress(lr, loss, dt)
        @training_state.update(current_iter: @iter_num, current_loss: loss)
        @metrics_store.record_metrics(@run_id, @iter_num, { train_loss: loss, lr: lr, iter_time_ms: dt * 1000 })
        @metrics_store.update_run(@run_id, current_iter: @iter_num)
        @sse_notifier.broadcast(type: "train", data: {
          iteration: @iter_num, loss: loss, lr: lr, time_ms: (dt * 1000).round(2)
        })
      end

      # Writes the "final" checkpoint and marks the run +status+ in the training state and
      # metrics store, recording the +recorded_suffix+ checkpoint path. Then broadcasts the
      # new status.
      def finish_run(status, recorded_suffix:)
        save_checkpoint("final")
        @training_state.update(status: status)
        @metrics_store.update_run(@run_id, status: status, stopped_at: Time.now.iso8601,
                                           checkpoint_path: checkpoint_path(recorded_suffix))
        @sse_notifier.broadcast(type: "status", data: @training_state.to_h)
      end

      # Mean loss over eval_iters batches for each of :train and :val, computed in eval mode
      # under Torch.no_grad. Sampling stops early on a stop request; a split with no samples
      # averages to Float::INFINITY. The model is always returned to train mode, even if an
      # error is raised.
      def estimate_loss
        @model.eval
        %i[train val].to_h do |split|
          losses = []
          @config[:eval_iters].times do
            break if @training_state.stop_requested?

            x, y = @data_loader.get_batch(split)
            Torch.no_grad do
              _logits, loss = @model.call(x, targets: y)
              losses << loss.item
            end
          end
          [split, losses.empty? ? Float::INFINITY : losses.sum / losses.size]
        end
      ensure
        @model.train
      end

      # Saves model weights, model args, progress, and config to checkpoint_path(suffix)
      # and registers the file with the metrics store.
      def save_checkpoint(suffix = "best")
        FileUtils.mkdir_p(@config[:out_dir])
        path = checkpoint_path(suffix)
        checkpoint = {
          "model" => @model.state_dict,
          "model_args" => stringify_keys(@model.config.to_h),
          "iter_num" => @iter_num,
          "best_val_loss" => @best_val_loss,
          "config" => stringify_keys(@config)
        }
        Torch.save(checkpoint, path)
        @metrics_store.record_checkpoint(@run_id,
                                         path: path, suffix: suffix, iteration: @iter_num, val_loss: @best_val_loss)
      end

      # Per-run checkpoint file name inside out_dir, e.g. "ckpt_run7_best.pt".
      def checkpoint_path(suffix = "best")
        File.join(@config[:out_dir], "ckpt_run#{@run_id}_#{suffix}.pt")
      end

      # Recursively converts Hash keys to strings (nested Hashes included).
      def stringify_keys(hash)
        hash.transform_keys(&:to_s).transform_values do |v|
          v.is_a?(Hash) ? stringify_keys(v) : v
        end
      end

      # Scales all gradients in place so their global L2 norm is at most +max_norm+.
      # Parameters without a gradient are ignored.
      def clip_grad_norm(parameters, max_norm)
        grads = parameters.map(&:grad).compact
        total_norm = Math.sqrt(grads.sum { |g| g.data.norm(2).item**2 })

        clip_coef = max_norm / (total_norm + 1e-6)
        return unless clip_coef < 1

        grads.each { |g| g.data.mul!(clip_coef) }
      end

      def setup_optimizer
        @optimizer = @model.configure_optimizers(
          weight_decay: @config[:weight_decay] || 1e-1,
          learning_rate: @config[:learning_rate],
          betas: [@config[:beta1] || 0.9, @config[:beta2] || 0.99],
          device_type: @config[:device]
        )
      end

      def setup_lr_scheduler
        @lr_scheduler = LRScheduler.new(
          learning_rate: @config[:learning_rate],
          min_lr: @config[:min_lr] || 1e-4,
          warmup_iters: @config[:warmup_iters] || 100,
          lr_decay_iters: @config[:lr_decay_iters] || 5000
        )
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
+ # frozen_string_literal: true
2
+
3
+ require_relative "bpe_textfile_preparer"
4
+ require_relative "web/metrics_store"
5
+ require_relative "web/training_state"
6
+ require_relative "web/sse_notifier"
7
+ require_relative "web/web_trainer"
8
+ require_relative "web/training_worker"
9
+ require_relative "web/server"
data/lib/nano_gpt.rb CHANGED
@@ -16,3 +16,4 @@ require_relative "nano_gpt/data_loader"
16
16
  require_relative "nano_gpt/lr_scheduler"
17
17
  require_relative "nano_gpt/trainer"
18
18
  require_relative "nano_gpt/train_config"
19
+ require_relative "nano_gpt/textfile_preparer"
data/nanogpt.gemspec CHANGED
@@ -32,6 +32,10 @@ Gem::Specification.new do |spec|
32
32
  spec.add_dependency "torch-rb", "~> 0.14"
33
33
  spec.add_dependency "numo-narray", "~> 0.9"
34
34
  spec.add_dependency "tiktoken_ruby", "~> 0.0"
35
+ spec.add_dependency "sinatra", "~> 4.0"
36
+ spec.add_dependency "rackup", "~> 2.0"
37
+ spec.add_dependency "webrick", "~> 1.8"
38
+ spec.add_dependency "sqlite3", "~> 2.0"
35
39
 
36
40
  spec.add_development_dependency "rspec", "~> 3.12"
37
41
  end