nanogpt 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile.lock +30 -1
- data/docs/ARCHITECTURE.md +429 -0
- data/exe/nanogpt +210 -233
- data/lib/nano_gpt/bpe_textfile_preparer.rb +105 -0
- data/lib/nano_gpt/data_loader.rb +5 -20
- data/lib/nano_gpt/layers/block.rb +6 -1
- data/lib/nano_gpt/layers/causal_self_attention.rb +11 -1
- data/lib/nano_gpt/model.rb +1 -7
- data/lib/nano_gpt/textfile_preparer.rb +189 -0
- data/lib/nano_gpt/train_config.rb +80 -146
- data/lib/nano_gpt/trainer.rb +21 -48
- data/lib/nano_gpt/version.rb +1 -1
- data/lib/nano_gpt/web/metrics_store.rb +136 -0
- data/lib/nano_gpt/web/server.rb +294 -0
- data/lib/nano_gpt/web/sse_notifier.rb +37 -0
- data/lib/nano_gpt/web/training_state.rb +56 -0
- data/lib/nano_gpt/web/training_worker.rb +153 -0
- data/lib/nano_gpt/web/views/layout.erb +78 -0
- data/lib/nano_gpt/web/views/run_detail.erb +432 -0
- data/lib/nano_gpt/web/views/runs.erb +434 -0
- data/lib/nano_gpt/web/web_trainer.rb +210 -0
- data/lib/nano_gpt/web.rb +9 -0
- data/lib/nano_gpt.rb +1 -0
- data/nanogpt.gemspec +4 -0
- metadata +71 -2
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "thread"
require "time"
|
|
4
|
+
|
|
5
|
+
module NanoGPT
  module Web
    # Processes all Torch operations on the main thread.
    # The web server runs in a background thread while this worker owns the
    # main thread and processes commands from a Queue.
    #
    # Every queued message is a Hash of the form
    #   { command: Symbol, args: Hash, result: Queue (optional) }
    # When a :result queue is present the sender is blocked in #enqueue_sync,
    # so a failing or unknown command still receives a reply Hash instead of
    # leaving that sender waiting forever.
    class TrainingWorker
      attr_reader :queue

      # @param training_state [#update] receives status updates on failure and
      #   is handed to the WebTrainer
      # @param metrics_store [#update_run] receives run status updates on
      #   failure and is handed to the WebTrainer
      # @param sse_notifier [#broadcast] receives error events and is handed
      #   to the WebTrainer
      def initialize(training_state:, metrics_store:, sse_notifier:)
        @training_state = training_state
        @metrics_store = metrics_store
        @sse_notifier = sse_notifier
        @queue = Queue.new
      end

      # Run the command loop on the MAIN thread (blocks until a :shutdown
      # command is received). Call this AFTER starting the web server in a
      # background thread.
      #
      # A command that raises is logged and answered (see #process); it does
      # not terminate the loop. Only a failure of the loop itself (for example
      # a malformed message) reaches the rescue below.
      def run
        loop do
          msg = @queue.pop
          break if msg[:command] == :shutdown

          process(msg)
        end
      rescue => e
        puts "Training worker crashed: #{e.message}\n#{backtrace_head(e)}"
      end

      # Enqueue a fire-and-forget training command
      #
      # @param command [Symbol] one of :start, :resume, :generate,
      #   :prepare_dataset, :shutdown
      # @param args [Hash] keyword arguments forwarded to the command handler
      def enqueue(command, **args)
        @queue.push({ command: command, args: args })
      end

      # Enqueue a command and wait for the result (used for generation)
      #
      # @return [Hash] the reply pushed by the worker; failures are reported
      #   as { ok: false, error: String }
      def enqueue_sync(command, **args)
        result_queue = Queue.new
        @queue.push({ command: command, args: args, result: result_queue })
        result_queue.pop
      end

      private

      # Dispatches one message to its handler. Any StandardError raised while
      # handling (including bad keyword arguments) is logged and reported back
      # to a waiting synchronous caller, so the command loop keeps running.
      def process(msg)
        case msg[:command]
        when :start then handle_start(**msg[:args])
        when :resume then handle_resume(**msg[:args])
        when :generate then handle_generate(msg)
        when :prepare_dataset then handle_prepare_dataset(msg)
        else
          msg[:result]&.push({ ok: false, error: "Unknown command: #{msg[:command].inspect}" })
        end
      rescue => e
        puts "Training worker command #{msg[:command].inspect} failed: #{e.message}\n#{backtrace_head(e)}"
        msg[:result]&.push({ ok: false, error: e.message })
      end

      # Starts a fresh training run; failures are routed to #handle_training_error.
      def handle_start(config:, data_dir:, run_id:)
        run_training(config, data_dir, run_id)
      rescue => e
        handle_training_error(run_id, e)
      end

      # Continues a run from +checkpoint_path+; failures are routed to
      # #handle_training_error.
      def handle_resume(config:, data_dir:, run_id:, checkpoint_path:)
        run_training(config, data_dir, run_id, resume: checkpoint_path)
      rescue => e
        handle_training_error(run_id, e)
      end

      # Runs text generation and replies on the message's result queue, if any.
      def handle_generate(msg)
        args = msg[:args]
        result = generate_text(**args)
        msg[:result]&.push(result)
      end

      # Runs dataset preparation and replies on the message's result queue, if any.
      def handle_prepare_dataset(msg)
        args = msg[:args]
        result = prepare_dataset(**args)
        msg[:result]&.push(result)
      end

      # Builds a dataset from a text file using the BPE preparer when
      # +tokenizer+ is "bpe" and the plain text-file preparer otherwise.
      #
      # @return [Hash] { ok: true, dataset: name } or { ok: false, error: message }
      def prepare_dataset(input_path:, output_name:, tokenizer:, val_ratio: 0.1)
        preparer = if tokenizer == "bpe"
                     BPETextfilePreparer.new(input_path: input_path, output_name: output_name, val_ratio: val_ratio)
                   else
                     TextfilePreparer.new(input_path: input_path, output_name: output_name, val_ratio: val_ratio)
                   end
        name = preparer.prepare
        { ok: true, dataset: name }
      rescue => e
        { ok: false, error: e.message }
      end

      # Loads a checkpoint, rebuilds the model from its stored arguments and
      # samples a continuation of +prompt+.
      #
      # The checkpoint is read from +checkpoint_path+ or, when that is nil,
      # from "ckpt.pt" inside +out_dir+; one of the two must be given.
      #
      # NOTE(review): +checkpoint_path+ and +dataset+ appear to originate from
      # web requests and are used as file-system paths as-is — confirm that
      # the caller validates them before enqueueing.
      #
      # @return [Hash] { ok: true, text: String } or { ok: false, error: message }
      def generate_text(prompt:, temperature:, max_tokens:, top_k:, dataset:, checkpoint_path: nil, out_dir: nil)
        ckpt_path = checkpoint_path || (out_dir && File.join(out_dir, "ckpt.pt"))
        # Fail with a readable message instead of File.join's nil TypeError.
        raise ArgumentError, "checkpoint_path or out_dir is required" unless ckpt_path

        checkpoint = Torch.load(ckpt_path)
        model_args = checkpoint["model_args"].transform_keys(&:to_sym)
        model_config = GPTConfig.new(**model_args)
        gen_model = GPT.new(model_config)
        gen_model.load_state_dict(checkpoint["model"])
        gen_model.eval

        dataset_dir = File.join("data", dataset)
        tokenizer = Tokenizer.for_dataset(dataset_dir)
        start_ids = tokenizer.encode(prompt)
        x = Torch.tensor([start_ids], dtype: :long)

        y = gen_model.generate(x, max_tokens, temperature: temperature, top_k: top_k)
        text = tokenizer.decode(y[0].to_a)

        { ok: true, text: text }
      rescue => e
        { ok: false, error: e.message }
      end

      # Builds the model, data loader and trainer from +config+ and trains.
      # When +resume+ is given, that checkpoint is loaded before training.
      def run_training(config, data_dir, run_id, resume: nil)
        tokenizer = Tokenizer.for_dataset(data_dir)
        model_config = GPTConfig.new(
          block_size: config[:block_size],
          vocab_size: tokenizer.vocab_size,
          n_layer: config[:n_layer],
          n_head: config[:n_head],
          n_embd: config[:n_embd],
          dropout: config[:dropout],
          bias: config[:bias]
        )

        device = config[:device]
        model = GPT.new(model_config)
        model.to(device) if device != "cpu"

        data_loader = DataLoader.new(
          data_dir: data_dir,
          block_size: config[:block_size],
          batch_size: config[:batch_size],
          device: device
        )

        trainer = WebTrainer.new(
          model: model,
          data_loader: data_loader,
          config: config,
          training_state: @training_state,
          metrics_store: @metrics_store,
          sse_notifier: @sse_notifier
        )

        trainer.load_checkpoint(resume) if resume
        trainer.train(run_id)
      end

      # Marks the run as failed in the shared state and the metrics store,
      # notifies SSE listeners and logs the error.
      def handle_training_error(run_id, error)
        @training_state.update(status: "error")
        @metrics_store.update_run(run_id, status: "error", stopped_at: Time.now.iso8601)
        @sse_notifier.broadcast(type: "error", data: { message: error.message })
        puts "Training error: #{error.message}\n#{backtrace_head(error)}"
      end

      # First five backtrace lines joined by newlines; tolerates an exception
      # that has no backtrace.
      def backtrace_head(error)
        Array(error.backtrace).first(5).join("\n")
      end
    end
  end
end
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
<!DOCTYPE html>
|
|
2
|
+
<html lang="en">
|
|
3
|
+
<head>
|
|
4
|
+
<meta charset="UTF-8">
|
|
5
|
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
6
|
+
<title>nanoGPT</title>
|
|
7
|
+
<script src="https://cdn.tailwindcss.com"></script>
|
|
8
|
+
<script src="https://cdn.jsdelivr.net/npm/chart.js@4"></script>
|
|
9
|
+
<style>
|
|
10
|
+
.status-idle { color: #6b7280; }
|
|
11
|
+
.status-running { color: #10b981; }
|
|
12
|
+
.status-stopped { color: #f59e0b; }
|
|
13
|
+
.status-completed { color: #3b82f6; }
|
|
14
|
+
.status-error { color: #ef4444; }
|
|
15
|
+
.pulse { animation: pulse 2s infinite; }
|
|
16
|
+
@keyframes pulse { 0%, 100% { opacity: 1; } 50% { opacity: 0.5; } }
|
|
17
|
+
.badge { @apply px-2 py-0.5 rounded text-xs font-medium; }
|
|
18
|
+
.badge-running { @apply bg-green-900 text-green-300; }
|
|
19
|
+
.badge-stopped { @apply bg-yellow-900 text-yellow-300; }
|
|
20
|
+
.badge-completed { @apply bg-blue-900 text-blue-300; }
|
|
21
|
+
.badge-error { @apply bg-red-900 text-red-300; }
|
|
22
|
+
</style>
|
|
23
|
+
</head>
|
|
24
|
+
<body class="bg-gray-900 text-gray-100 min-h-screen">
|
|
25
|
+
<header class="bg-gray-800 border-b border-gray-700 px-6 py-3">
|
|
26
|
+
<div class="flex items-center justify-between max-w-7xl mx-auto">
|
|
27
|
+
<nav class="flex items-center gap-6">
|
|
28
|
+
<a href="/" class="text-lg font-bold tracking-tight hover:text-white">nanoGPT</a>
|
|
29
|
+
</nav>
|
|
30
|
+
<div class="flex items-center gap-2">
|
|
31
|
+
<span id="nav-status-dot" class="w-2.5 h-2.5 rounded-full bg-gray-500 inline-block"></span>
|
|
32
|
+
<span id="nav-status-text" class="text-xs font-medium status-idle">Idle</span>
|
|
33
|
+
</div>
|
|
34
|
+
</div>
|
|
35
|
+
</header>
|
|
36
|
+
|
|
37
|
+
<main class="max-w-7xl mx-auto px-6 py-6 space-y-6">
|
|
38
|
+
<%= yield %>
|
|
39
|
+
</main>
|
|
40
|
+
|
|
41
|
+
<script>
|
|
42
|
+
// Text and dot colour shown in the header for each known training status.
const NAV_STATUS_LABELS = { idle: 'Idle', running: 'Training...', stopped: 'Stopped', completed: 'Completed', error: 'Error' };
const NAV_STATUS_COLORS = { idle: 'bg-gray-500', running: 'bg-green-500', stopped: 'bg-yellow-500', completed: 'bg-blue-500', error: 'bg-red-500' };

// Fetches /train/status, refreshes the header indicator and re-broadcasts the
// state as a 'training-status' window event. Failures are ignored on purpose:
// the next tick simply tries again.
async function pollNavStatus() {
  try {
    const res = await fetch('/train/status');
    if (!res.ok) return;

    const state = JSON.parse(await res.text());
    const status = state.status || 'idle';

    const dot = document.getElementById('nav-status-dot');
    const text = document.getElementById('nav-status-text');
    if (!dot || !text) return;

    text.textContent = NAV_STATUS_LABELS[status] || status;
    text.className = 'text-xs font-medium status-' + status;

    const dotColor = NAV_STATUS_COLORS[status] || 'bg-gray-500';
    dot.className = 'w-2.5 h-2.5 rounded-full inline-block ' + dotColor;
    if (status === 'running') dot.classList.add('pulse');

    // Dispatch for page-specific listeners
    window.dispatchEvent(new CustomEvent('training-status', { detail: state }));
  } catch (e) {}
}

// Refresh once immediately, then every two seconds.
pollNavStatus();
setInterval(pollNavStatus, 2000);
|
|
66
|
+
|
|
67
|
+
// Returns the HTML for a small coloured badge showing run status `s`.
// The status text is HTML-escaped because callers insert the result as markup.
function statusBadge(s) {
  const colors = { running: 'bg-green-900 text-green-300', stopped: 'bg-yellow-900 text-yellow-300', completed: 'bg-blue-900 text-blue-300', error: 'bg-red-900 text-red-300' };
  const entities = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
  const label = String(s).replace(/[&<>"']/g, ch => entities[ch]);
  return '<span class="px-2 py-0.5 rounded text-xs font-medium ' + (colors[s] || 'bg-gray-700 text-gray-300') + '">' + label + '</span>';
}
|
|
71
|
+
|
|
72
|
+
// Formats an ISO timestamp as a short local date/time (e.g. "Jan 5, 03:04 PM").
// Returns '--' when the value is missing or cannot be parsed as a date.
function formatTime(iso) {
  if (!iso) return '--';
  const when = new Date(iso);
  // An unparseable string yields an invalid Date, which would otherwise be
  // rendered as the literal text "Invalid Date".
  if (Number.isNaN(when.getTime())) return '--';
  return when.toLocaleString(undefined, { month: 'short', day: 'numeric', hour: '2-digit', minute: '2-digit' });
}
|
|
76
|
+
</script>
|
|
77
|
+
</body>
|
|
78
|
+
</html>
|
|
@@ -0,0 +1,432 @@
|
|
|
1
|
+
<%
  # Locals used throughout this view.
  #
  # config    - the run's stored configuration as a Hash; falls back to {}
  #             when the JSON is missing, malformed or not a JSON object.
  # max_iters - the configured iteration count as an Integer; 0 when absent
  #             or not numeric, so the comparisons and divisions below never
  #             see a non-number.
  # is_running - whether this run's stored status is "running".
  config = begin
    parsed = JSON.parse(@run["config_json"] || "{}")
    parsed.is_a?(Hash) ? parsed : {}
  rescue StandardError
    {}
  end
  max_iters = Integer(config["max_iters"] || 0, exception: false) || 0
  is_running = @run["status"] == "running"
%>
|
|
6
|
+
|
|
7
|
+
<!-- Breadcrumb + Status -->
|
|
8
|
+
<div class="flex items-center justify-between">
|
|
9
|
+
<div class="flex items-center gap-3">
|
|
10
|
+
<a href="/" class="text-gray-400 hover:text-white text-sm">← Runs</a>
|
|
11
|
+
<span class="text-gray-600">/</span>
|
|
12
|
+
<h1 class="text-lg font-semibold">Run #<%= @run["id"] %></h1>
|
|
13
|
+
<span id="run-status-badge"><%= @run["status"] %></span>
|
|
14
|
+
</div>
|
|
15
|
+
<div class="flex gap-2" id="run-actions">
|
|
16
|
+
<% if is_running %>
|
|
17
|
+
<button onclick="stopRun()" id="btn-stop" class="px-4 py-2 bg-red-600 hover:bg-red-700 rounded-lg text-sm font-medium transition pulse">
|
|
18
|
+
Stop
|
|
19
|
+
</button>
|
|
20
|
+
<% elsif @run["status"] == "stopped" || @run["status"] == "completed" %>
|
|
21
|
+
<% if @run["checkpoint_path"] %>
|
|
22
|
+
<button onclick="resumeRun()" class="px-4 py-2 bg-yellow-600 hover:bg-yellow-700 rounded-lg text-sm font-medium transition">
|
|
23
|
+
Resume
|
|
24
|
+
</button>
|
|
25
|
+
<% end %>
|
|
26
|
+
<% end %>
|
|
27
|
+
</div>
|
|
28
|
+
</div>
|
|
29
|
+
|
|
30
|
+
<!-- Stats -->
|
|
31
|
+
<div class="bg-gray-800 rounded-lg p-4 grid grid-cols-2 md:grid-cols-5 gap-4">
|
|
32
|
+
<div>
|
|
33
|
+
<div class="text-xs text-gray-400 uppercase">Dataset</div>
|
|
34
|
+
<div class="text-sm font-mono"><%= @run["dataset"] %></div>
|
|
35
|
+
</div>
|
|
36
|
+
<div>
|
|
37
|
+
<div class="text-xs text-gray-400 uppercase">Iteration</div>
|
|
38
|
+
<div class="text-sm font-mono"><span id="stat-iter"><%= @run["current_iter"] || 0 %></span> / <%= max_iters %></div>
|
|
39
|
+
</div>
|
|
40
|
+
<div>
|
|
41
|
+
<div class="text-xs text-gray-400 uppercase">Best Val Loss</div>
|
|
42
|
+
<div class="text-sm font-mono" id="stat-val-loss"><%= @run["best_val_loss"] ? format("%.4f", @run["best_val_loss"]) : "--" %></div>
|
|
43
|
+
</div>
|
|
44
|
+
<div>
|
|
45
|
+
<div class="text-xs text-gray-400 uppercase">Started</div>
|
|
46
|
+
<div class="text-sm font-mono"><%= @run["started_at"] ? Time.parse(@run["started_at"]).strftime("%b %d %H:%M") : "--" %></div>
|
|
47
|
+
</div>
|
|
48
|
+
<div>
|
|
49
|
+
<div class="text-xs text-gray-400 uppercase">Status</div>
|
|
50
|
+
<div class="text-sm font-mono" id="stat-status"><%= @run["status"] %></div>
|
|
51
|
+
</div>
|
|
52
|
+
</div>
|
|
53
|
+
|
|
54
|
+
<!-- Progress Bar (if running) -->
|
|
55
|
+
<div id="progress-section" class="<%= 'hidden' unless is_running %>">
|
|
56
|
+
<div class="bg-gray-800 rounded-lg p-4">
|
|
57
|
+
<div class="flex items-center justify-between text-xs text-gray-400 mb-1.5">
|
|
58
|
+
<span id="progress-label"><%= @run["current_iter"] || 0 %> / <%= max_iters %></span>
|
|
59
|
+
<span id="progress-eta"></span>
|
|
60
|
+
</div>
|
|
61
|
+
<div class="w-full bg-gray-700 rounded-full h-2.5">
|
|
62
|
+
<% pct = max_iters > 0 ? [100, ((@run["current_iter"] || 0).to_f / max_iters * 100)].min : 0 %>
|
|
63
|
+
<div id="progress-bar" class="bg-green-500 h-2.5 rounded-full transition-all duration-300" style="width: <%= pct %>%"></div>
|
|
64
|
+
</div>
|
|
65
|
+
</div>
|
|
66
|
+
</div>
|
|
67
|
+
|
|
68
|
+
<!-- Loss Chart -->
|
|
69
|
+
<div class="bg-gray-800 rounded-lg p-4">
|
|
70
|
+
<h2 class="text-sm font-medium text-gray-300 mb-3">Loss</h2>
|
|
71
|
+
<div style="height: 350px;">
|
|
72
|
+
<canvas id="loss-chart"></canvas>
|
|
73
|
+
</div>
|
|
74
|
+
</div>
|
|
75
|
+
|
|
76
|
+
<!-- Config -->
|
|
77
|
+
<details class="bg-gray-800 rounded-lg">
|
|
78
|
+
<summary class="px-4 py-3 cursor-pointer text-sm font-medium text-gray-400 hover:text-white">Configuration</summary>
|
|
79
|
+
<div class="px-4 pb-4">
|
|
80
|
+
<pre class="text-xs font-mono text-gray-400 whitespace-pre-wrap"><%= JSON.pretty_generate(config) %></pre>
|
|
81
|
+
</div>
|
|
82
|
+
</details>
|
|
83
|
+
|
|
84
|
+
<!-- Checkpoints -->
|
|
85
|
+
<div class="bg-gray-800 rounded-lg p-4">
|
|
86
|
+
<h2 class="text-sm font-medium text-gray-300 mb-3">Checkpoints</h2>
|
|
87
|
+
<div id="checkpoints-list">
|
|
88
|
+
<div class="text-xs text-gray-500">Loading...</div>
|
|
89
|
+
</div>
|
|
90
|
+
</div>
|
|
91
|
+
|
|
92
|
+
<!-- Generate from Checkpoint -->
|
|
93
|
+
<div class="bg-gray-800 rounded-lg p-4" id="generate-section">
|
|
94
|
+
<h2 class="text-sm font-medium text-gray-300 mb-3">Generate Text</h2>
|
|
95
|
+
<div class="space-y-3">
|
|
96
|
+
<div>
|
|
97
|
+
<label class="block text-xs text-gray-400 mb-1">Checkpoint</label>
|
|
98
|
+
<select id="gen-checkpoint" class="w-full bg-gray-700 rounded px-3 py-2 text-sm font-mono">
|
|
99
|
+
<option value="">Select a checkpoint...</option>
|
|
100
|
+
</select>
|
|
101
|
+
</div>
|
|
102
|
+
<div class="grid grid-cols-1 md:grid-cols-4 gap-3">
|
|
103
|
+
<div class="md:col-span-2">
|
|
104
|
+
<label class="block text-xs text-gray-400 mb-1">Prompt</label>
|
|
105
|
+
<textarea id="gen-prompt" rows="3" class="w-full bg-gray-700 rounded px-3 py-2 text-sm font-mono resize-y" placeholder="Enter prompt text..."></textarea>
|
|
106
|
+
</div>
|
|
107
|
+
<div class="space-y-2">
|
|
108
|
+
<div>
|
|
109
|
+
<label class="block text-xs text-gray-400 mb-1">Temperature</label>
|
|
110
|
+
<input id="gen-temperature" type="number" value="0.8" step="0.1" min="0.1" max="2.0" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
111
|
+
</div>
|
|
112
|
+
<div>
|
|
113
|
+
<label class="block text-xs text-gray-400 mb-1">Max Tokens</label>
|
|
114
|
+
<input id="gen-max_tokens" type="number" value="200" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
115
|
+
</div>
|
|
116
|
+
</div>
|
|
117
|
+
<div>
|
|
118
|
+
<label class="block text-xs text-gray-400 mb-1">Top-K</label>
|
|
119
|
+
<input id="gen-top_k" type="number" value="200" class="w-full bg-gray-700 rounded px-2 py-1.5 text-sm">
|
|
120
|
+
</div>
|
|
121
|
+
</div>
|
|
122
|
+
<div class="flex items-center gap-3">
|
|
123
|
+
<button id="btn-generate" onclick="generateText()" class="px-4 py-2 bg-purple-600 hover:bg-purple-700 rounded-lg font-medium text-sm transition disabled:opacity-40 disabled:cursor-not-allowed">
|
|
124
|
+
Generate
|
|
125
|
+
</button>
|
|
126
|
+
<span id="gen-status" class="text-xs text-gray-400"></span>
|
|
127
|
+
</div>
|
|
128
|
+
<pre id="gen-output" class="hidden bg-gray-950 rounded-lg p-4 text-sm font-mono text-green-400 whitespace-pre-wrap min-h-[100px] max-h-[400px] overflow-auto"></pre>
|
|
129
|
+
</div>
|
|
130
|
+
</div>
|
|
131
|
+
|
|
132
|
+
<script>
|
|
133
|
+
// Page-level state for this run.
// String values are emitted as JSON literals (with "</" neutralised) so that
// quotes, backslashes or a "</script>" sequence in stored run data cannot
// terminate the string or the script element.
// NOTE(review): assumes this template is rendered without automatic HTML
// escaping of <%= %> output — confirm against the server's ERB settings.
const RUN_ID = <%= @run["id"] %>;
const RUN_DATASET = <%= JSON.generate(@run["dataset"].to_s).gsub("</", "<\\\\/") %>;
const RUN_STATUS = <%= JSON.generate(@run["status"].to_s).gsub("</", "<\\\\/") %>;
const MAX_ITERS = <%= max_iters %>;
let checkpointsData = [];   // checkpoints last returned by the API
let iterTimes = [];         // most recent iteration durations (ms), used for the ETA
let lastSeenIter = 0;       // highest iteration already plotted
let pollTimer = null;       // interval handle while live polling is active
|
|
141
|
+
|
|
142
|
+
// ---- Chart ----
|
|
143
|
+
// Builds one line series for the loss chart.
const lossSeries = (label, borderColor, backgroundColor, pointRadius, borderWidth) => ({
  label,
  borderColor,
  backgroundColor,
  data: [],
  pointRadius,
  borderWidth,
  tension: 0.1
});

// Shared axis styling; `extra` carries axis-specific options.
const lossAxis = (text, extra) => Object.assign({}, extra, {
  title: { display: true, text, color: '#9ca3af' },
  ticks: { color: '#6b7280' },
  grid: { color: '#374151' }
});

// Loss chart: dataset 0 is the training loss, dataset 1 the validation loss,
// both plotted against the iteration number.
const ctx = document.getElementById('loss-chart').getContext('2d');
const chart = new Chart(ctx, {
  type: 'line',
  data: {
    datasets: [
      lossSeries('Train Loss', '#3b82f6', 'rgba(59,130,246,0.1)', 0, 1.5),
      lossSeries('Val Loss', '#ef4444', 'rgba(239,68,68,0.1)', 2, 2)
    ]
  },
  options: {
    responsive: true,
    maintainAspectRatio: false,
    animation: false,
    interaction: { mode: 'nearest', intersect: false },
    scales: {
      x: lossAxis('Iteration', { type: 'linear' }),
      y: lossAxis('Loss', {})
    },
    plugins: {
      legend: { labels: { color: '#d1d5db' } }
    }
  }
});
|
|
191
|
+
|
|
192
|
+
// ---- Load Stored Metrics (from SQLite) ----
|
|
193
|
+
// Loads the metrics already stored for this run and replaces both chart
// series with them. Also remembers the last stored training iteration so
// live polling can continue from there. Errors are ignored (the chart is
// simply left as it is).
async function loadMetrics() {
  const toPoint = p => ({ x: p.iteration, y: p.value });

  try {
    const res = await fetch('/api/runs/' + RUN_ID + '/metrics');
    if (!res.ok) return;

    const data = await res.json();
    const trainSeries = data.metrics.train_loss;
    const valSeries = data.metrics.val_loss;

    if (trainSeries) {
      chart.data.datasets[0].data = trainSeries.map(p => toPoint(p));
      const newest = trainSeries[trainSeries.length - 1];
      if (newest) lastSeenIter = newest.iteration;
    }

    if (valSeries) {
      chart.data.datasets[1].data = valSeries.map(p => toPoint(p));
    }

    chart.update('none');
  } catch (e) {}
}
|
|
210
|
+
|
|
211
|
+
// ---- Live Polling (only for the active running run) ----
|
|
212
|
+
// Starts a once-per-second poll of /metrics/poll that appends new points to
// the chart, keeps the recent iteration times for the ETA and refreshes the
// live stats. Does nothing when polling is already active.
function startLivePolling() {
  if (pollTimer) return;

  const tick = async () => {
    try {
      const res = await fetch('/metrics/poll?since_iter=' + lastSeenIter);
      if (!res.ok) return;
      const data = await res.json();

      // Only update if this is still the active run
      if (data.state && data.state.run_id !== RUN_ID) return;

      const metrics = data.metrics || {};
      const [trainSeries, valSeries] = chart.data.datasets;
      let updated = false;

      if (metrics.train_loss) {
        for (const p of Array.from(metrics.train_loss)) {
          trainSeries.data.push({ x: p.iteration, y: p.value });
          if (p.iteration > lastSeenIter) lastSeenIter = p.iteration;
        }
        updated = true;
      }

      if (metrics.val_loss) {
        for (const p of Array.from(metrics.val_loss)) {
          valSeries.data.push({ x: p.iteration, y: p.value });
        }
        updated = true;
      }

      if (metrics.iter_time_ms) {
        for (const p of Array.from(metrics.iter_time_ms)) {
          // Keep only the ten most recent durations.
          iterTimes.push(p.value);
          if (iterTimes.length > 10) iterTimes.shift();
        }
      }

      if (updated) chart.update('none');

      // Update progress
      if (data.state) updateLiveStats(data.state);
    } catch (e) {}
  };

  pollTimer = setInterval(tick, 1000);
}
|
|
252
|
+
|
|
253
|
+
// Applies a polled training state to the stats row and the progress bar.
// While the run is running (and MAX_ITERS is known) the bar, label and ETA
// are refreshed; otherwise the progress section is hidden and, once the run
// is no longer running, polling stops and the checkpoint list is reloaded.
function updateLiveStats(state) {
  const el = id => document.getElementById(id);

  if (state.current_iter != null) el('stat-iter').textContent = state.current_iter;
  if (state.best_val_loss != null) el('stat-val-loss').textContent = state.best_val_loss.toFixed(4);
  if (state.status) el('stat-status').textContent = state.status;

  const section = el('progress-section');
  const running = state.status === 'running';

  if (!(running && MAX_ITERS > 0)) {
    section.classList.add('hidden');
    if (!running && pollTimer) {
      clearInterval(pollTimer);
      pollTimer = null;
      loadCheckpoints(); // Refresh checkpoints when run completes
    }
    return;
  }

  const done = state.current_iter || 0;
  const pct = Math.min(100, (done / MAX_ITERS) * 100);

  section.classList.remove('hidden');
  el('progress-bar').style.width = pct.toFixed(1) + '%';
  el('progress-label').textContent = done + ' / ' + MAX_ITERS + ' (' + pct.toFixed(0) + '%)';

  if (iterTimes.length === 0) return;

  // ETA = remaining iterations x average of the recent iteration times.
  const avgMs = iterTimes.reduce((sum, ms) => sum + ms, 0) / iterTimes.length;
  const etaSec = ((MAX_ITERS - done) * avgMs) / 1000;
  let etaText;
  if (etaSec > 60) {
    etaText = '~' + Math.floor(etaSec / 60) + 'm' + Math.floor(etaSec % 60) + 's remaining';
  } else {
    etaText = '~' + Math.floor(etaSec) + 's remaining';
  }
  el('progress-eta').textContent = etaText;
}
|
|
292
|
+
|
|
293
|
+
// ---- Checkpoints ----
|
|
294
|
+
// Fetches this run's checkpoints and renders them both as a list and as the
// options of the "Generate Text" checkpoint picker. On any failure (network
// error, non-2xx response, malformed body) the list shows an error message.
async function loadCheckpoints() {
  // Values from the API are escaped before being placed into markup.
  const entities = { '&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;' };
  const esc = value => String(value).replace(/[&<>"']/g, ch => entities[ch]);
  // A loss of 0 is a real value; only a missing one is shown as '--'.
  const fmtLoss = cp => (cp.val_loss != null && cp.val_loss !== ''
    ? parseFloat(cp.val_loss).toFixed(4)
    : '--');

  try {
    const res = await fetch('/api/runs/' + RUN_ID + '/checkpoints');
    if (!res.ok) throw new Error('HTTP ' + res.status);
    const data = await res.json();

    // Every checkpoint the server reports is shown; whether the file still
    // exists is not checked here.
    checkpointsData = Array.isArray(data.checkpoints) ? data.checkpoints : [];

    const container = document.getElementById('checkpoints-list');
    const select = document.getElementById('gen-checkpoint');

    if (checkpointsData.length === 0) {
      container.innerHTML = '<div class="text-xs text-gray-500">No checkpoints saved yet.</div>';
      select.innerHTML = '<option value="">No checkpoints available</option>';
      return;
    }

    container.innerHTML = '<div class="space-y-1">' + checkpointsData.map((cp, i) => `
          <div class="flex items-center justify-between text-sm py-1.5 px-2 rounded hover:bg-gray-700/50">
            <div class="flex items-center gap-3 font-mono text-xs">
              <span class="px-1.5 py-0.5 rounded bg-gray-700 text-gray-300">${esc(cp.suffix)}</span>
              <span class="text-gray-400">iter ${esc(cp.iteration)}</span>
              <span class="text-gray-500">val_loss: ${esc(fmtLoss(cp))}</span>
            </div>
            <button onclick="selectCheckpoint(${i})" class="text-xs text-purple-400 hover:text-purple-300">Generate</button>
          </div>
        `).join('') + '</div>';

    // Option values are indexes into checkpointsData.
    select.innerHTML = checkpointsData.map((cp, i) =>
      `<option value="${i}">${esc(cp.suffix)} -- iter ${esc(cp.iteration)} (val_loss: ${esc(fmtLoss(cp))})</option>`
    ).join('');
  } catch (e) {
    document.getElementById('checkpoints-list').innerHTML =
      '<div class="text-xs text-red-400">Failed to load checkpoints</div>';
  }
}
|
|
336
|
+
|
|
337
|
+
// Selects checkpoint `idx` in the generation picker and scrolls it into view.
function selectCheckpoint(idx) {
  const picker = document.getElementById('gen-checkpoint');
  picker.value = idx;
  picker.scrollIntoView({ behavior: 'smooth', block: 'center' });
}
|
|
341
|
+
|
|
342
|
+
// ---- Generate ----
|
|
343
|
+
// Sends the prompt and sampling settings for the selected checkpoint to
// /generate/run and shows either the generated text or the error. The
// Generate button is disabled while the request is in flight.
async function generateText() {
  const byId = id => document.getElementById(id);
  const button = byId('btn-generate');
  const outputEl = byId('gen-output');
  const statusEl = byId('gen-status');
  const chosen = byId('gen-checkpoint').value;

  const report = (outputText, statusText, statusClass) => {
    outputEl.textContent = outputText;
    statusEl.textContent = statusText;
    statusEl.className = statusClass;
  };

  if (chosen === '' || !checkpointsData[chosen]) {
    alert('Please select a checkpoint first.');
    return;
  }

  const cp = checkpointsData[chosen];
  button.disabled = true;
  statusEl.textContent = 'Loading model and generating...';
  statusEl.className = 'text-xs text-yellow-400';
  outputEl.classList.remove('hidden');
  outputEl.textContent = 'Loading...';

  try {
    const payload = {
      // An empty prompt is sent as a single newline.
      prompt: byId('gen-prompt').value || '\n',
      temperature: parseFloat(byId('gen-temperature').value),
      max_tokens: parseInt(byId('gen-max_tokens').value),
      top_k: parseInt(byId('gen-top_k').value),
      checkpoint_path: cp.path,
      dataset: RUN_DATASET
    };
    const res = await fetch('/generate/run', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(payload)
    });
    const data = await res.json();

    if (!res.ok) {
      report('Error: ' + data.error, 'Failed', 'text-xs text-red-400');
    } else {
      report(data.text, 'Done!', 'text-xs text-green-400');
    }
  } catch (e) {
    report('Error: ' + e.message, 'Failed', 'text-xs text-red-400');
  } finally {
    button.disabled = false;
  }
}
|
|
392
|
+
|
|
393
|
+
// ---- Run Actions ----
|
|
394
|
+
// Asks the server to stop the current training run. The Stop button is
// disabled while the request is pending; if the request fails it is enabled
// again and the user is told, so the run can still be stopped by retrying.
async function stopRun() {
  const btn = document.getElementById('btn-stop');
  if (btn) btn.disabled = true;

  try {
    const res = await fetch('/train/stop', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: '{}'
    });
    if (!res.ok) throw new Error('HTTP ' + res.status);
  } catch (e) {
    if (btn) btn.disabled = false;
    alert('Failed to stop run: ' + e.message);
  }
}
|
|
403
|
+
|
|
404
|
+
// Resumes this run from its stored checkpoint: posts the run's configuration
// plus the checkpoint path to /train/resume and navigates to the run page
// returned by the server, or shows the server's error message.
//
// Both server-side values are embedded as JSON literals (with "</"
// neutralised) so quotes, backslashes, newlines or "</script>" in stored
// data cannot break out of the script.
async function resumeRun() {
  const config = <%= JSON.generate(config).gsub("</", "<\\\\/") %>;
  config.checkpoint_path = <%= JSON.generate(@run["checkpoint_path"].to_s).gsub("</", "<\\\\/") %>;

  const res = await fetch('/train/resume', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(config)
  });
  const data = await res.json();
  if (res.ok) {
    window.location = '/runs/' + data.run_id;
  } else {
    alert(data.error);
  }
}
|
|
420
|
+
|
|
421
|
+
// ---- Init ----
|
|
422
|
+
// Page start-up: load stored metrics, then the checkpoint list, and begin
// live polling only when the run was running at render time.
async function init() {
  await loadMetrics();
  await loadCheckpoints();

  if (RUN_STATUS !== 'running') return;
  startLivePolling();
}

init();
|
|
432
|
+
</script>
|