nanogpt 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile.lock +30 -1
- data/docs/ARCHITECTURE.md +429 -0
- data/exe/nanogpt +210 -233
- data/lib/nano_gpt/bpe_textfile_preparer.rb +105 -0
- data/lib/nano_gpt/data_loader.rb +5 -20
- data/lib/nano_gpt/layers/block.rb +6 -1
- data/lib/nano_gpt/layers/causal_self_attention.rb +11 -1
- data/lib/nano_gpt/model.rb +1 -7
- data/lib/nano_gpt/textfile_preparer.rb +189 -0
- data/lib/nano_gpt/train_config.rb +80 -146
- data/lib/nano_gpt/trainer.rb +21 -48
- data/lib/nano_gpt/version.rb +1 -1
- data/lib/nano_gpt/web/metrics_store.rb +136 -0
- data/lib/nano_gpt/web/server.rb +294 -0
- data/lib/nano_gpt/web/sse_notifier.rb +37 -0
- data/lib/nano_gpt/web/training_state.rb +56 -0
- data/lib/nano_gpt/web/training_worker.rb +153 -0
- data/lib/nano_gpt/web/views/layout.erb +78 -0
- data/lib/nano_gpt/web/views/run_detail.erb +432 -0
- data/lib/nano_gpt/web/views/runs.erb +434 -0
- data/lib/nano_gpt/web/web_trainer.rb +210 -0
- data/lib/nano_gpt/web.rb +9 -0
- data/lib/nano_gpt.rb +1 -0
- data/nanogpt.gemspec +4 -0
- metadata +71 -2
data/lib/nano_gpt/trainer.rb
CHANGED
|
@@ -4,13 +4,24 @@ require "fileutils"
|
|
|
4
4
|
|
|
5
5
|
module NanoGPT
|
|
6
6
|
# Training loop for GPT models
|
|
7
|
+
# Accepts a TrainConfig (or hash with same keys) for all configuration
|
|
7
8
|
class Trainer
|
|
9
|
+
# Default optimizer parameters (can be overridden via config)
|
|
10
|
+
OPTIMIZER_DEFAULTS = {
|
|
11
|
+
weight_decay: 1e-1,
|
|
12
|
+
beta1: 0.9,
|
|
13
|
+
beta2: 0.99,
|
|
14
|
+
grad_clip: 1.0,
|
|
15
|
+
always_save_checkpoint: false,
|
|
16
|
+
eval_only: false
|
|
17
|
+
}.freeze
|
|
18
|
+
|
|
8
19
|
attr_reader :model, :optimizer, :config, :iter_num, :best_val_loss
|
|
9
20
|
|
|
10
|
-
def initialize(model:, data_loader:, config:
|
|
21
|
+
def initialize(model:, data_loader:, config:)
|
|
11
22
|
@model = model
|
|
12
23
|
@data_loader = data_loader
|
|
13
|
-
@config =
|
|
24
|
+
@config = OPTIMIZER_DEFAULTS.merge(symbolize_keys(config.is_a?(Hash) ? config : config.to_h))
|
|
14
25
|
|
|
15
26
|
@iter_num = 0
|
|
16
27
|
@best_val_loss = Float::INFINITY
|
|
@@ -19,36 +30,6 @@ module NanoGPT
|
|
|
19
30
|
setup_lr_scheduler
|
|
20
31
|
end
|
|
21
32
|
|
|
22
|
-
def default_config
|
|
23
|
-
{
|
|
24
|
-
out_dir: "out",
|
|
25
|
-
eval_interval: 250,
|
|
26
|
-
log_interval: 10,
|
|
27
|
-
eval_iters: 200,
|
|
28
|
-
eval_only: false,
|
|
29
|
-
always_save_checkpoint: false,
|
|
30
|
-
|
|
31
|
-
# Optimizer
|
|
32
|
-
learning_rate: 1e-3,
|
|
33
|
-
weight_decay: 1e-1,
|
|
34
|
-
beta1: 0.9,
|
|
35
|
-
beta2: 0.99,
|
|
36
|
-
grad_clip: 1.0,
|
|
37
|
-
|
|
38
|
-
# LR scheduler
|
|
39
|
-
decay_lr: true,
|
|
40
|
-
warmup_iters: 100,
|
|
41
|
-
lr_decay_iters: 5000,
|
|
42
|
-
min_lr: 1e-4,
|
|
43
|
-
|
|
44
|
-
# Training
|
|
45
|
-
max_iters: 5000,
|
|
46
|
-
gradient_accumulation_steps: 1,
|
|
47
|
-
|
|
48
|
-
device: "cpu"
|
|
49
|
-
}
|
|
50
|
-
end
|
|
51
|
-
|
|
52
33
|
def train
|
|
53
34
|
puts "Starting training..."
|
|
54
35
|
puts "Tokens per iteration: #{tokens_per_iter}"
|
|
@@ -58,10 +39,8 @@ module NanoGPT
|
|
|
58
39
|
t0 = Time.now
|
|
59
40
|
|
|
60
41
|
while @iter_num <= @config[:max_iters]
|
|
61
|
-
# Set learning rate for this iteration
|
|
62
42
|
lr = @config[:decay_lr] ? @lr_scheduler.step(@optimizer, @iter_num) : @config[:learning_rate]
|
|
63
43
|
|
|
64
|
-
# Evaluate and checkpoint
|
|
65
44
|
if @iter_num % @config[:eval_interval] == 0
|
|
66
45
|
losses = estimate_loss
|
|
67
46
|
puts "step #{@iter_num}: train loss #{losses[:train].round(4)}, val loss #{losses[:val].round(4)}"
|
|
@@ -74,29 +53,24 @@ module NanoGPT
|
|
|
74
53
|
|
|
75
54
|
break if @iter_num == 0 && @config[:eval_only]
|
|
76
55
|
|
|
77
|
-
# Forward/backward with gradient accumulation
|
|
78
56
|
@optimizer.zero_grad
|
|
79
57
|
|
|
80
58
|
accumulated_loss = 0.0
|
|
81
|
-
@config[:gradient_accumulation_steps].times do |
|
|
82
|
-
|
|
59
|
+
@config[:gradient_accumulation_steps].times do |_micro_step|
|
|
60
|
+
_logits, loss = @model.call(x, targets: y)
|
|
83
61
|
loss = loss / @config[:gradient_accumulation_steps]
|
|
84
62
|
accumulated_loss += loss.item
|
|
85
63
|
loss.backward
|
|
86
64
|
|
|
87
|
-
# Prefetch next batch
|
|
88
65
|
x, y = @data_loader.get_batch(:train)
|
|
89
66
|
end
|
|
90
67
|
|
|
91
|
-
# Gradient clipping (manual implementation since torch.rb lacks clip_grad_norm_)
|
|
92
68
|
if @config[:grad_clip] > 0.0
|
|
93
69
|
clip_grad_norm(@model.parameters, @config[:grad_clip])
|
|
94
70
|
end
|
|
95
71
|
|
|
96
|
-
# Optimizer step
|
|
97
72
|
@optimizer.step
|
|
98
73
|
|
|
99
|
-
# Logging
|
|
100
74
|
t1 = Time.now
|
|
101
75
|
dt = t1 - t0
|
|
102
76
|
t0 = t1
|
|
@@ -135,9 +109,7 @@ module NanoGPT
|
|
|
135
109
|
FileUtils.mkdir_p(@config[:out_dir])
|
|
136
110
|
path = File.join(@config[:out_dir], "ckpt.pt")
|
|
137
111
|
|
|
138
|
-
#
|
|
139
|
-
# We save model state and training metadata
|
|
140
|
-
# Convert symbol keys to strings for Torch.save compatibility
|
|
112
|
+
# Torch.save requires string keys
|
|
141
113
|
checkpoint = {
|
|
142
114
|
"model" => @model.state_dict,
|
|
143
115
|
"model_args" => stringify_keys(@model.config.to_h),
|
|
@@ -157,7 +129,6 @@ module NanoGPT
|
|
|
157
129
|
@iter_num = checkpoint["iter_num"]
|
|
158
130
|
@best_val_loss = checkpoint["best_val_loss"]
|
|
159
131
|
|
|
160
|
-
# Reinitialize optimizer (since we can't restore optimizer state in torch.rb)
|
|
161
132
|
setup_optimizer
|
|
162
133
|
|
|
163
134
|
puts "Loaded checkpoint from #{path} (iter #{@iter_num})"
|
|
@@ -166,21 +137,23 @@ module NanoGPT
|
|
|
166
137
|
|
|
167
138
|
private
|
|
168
139
|
|
|
169
|
-
|
|
140
|
+
def symbolize_keys(hash)
|
|
141
|
+
hash.transform_keys(&:to_sym)
|
|
142
|
+
end
|
|
143
|
+
|
|
170
144
|
def stringify_keys(hash)
|
|
171
145
|
hash.transform_keys(&:to_s).transform_values do |v|
|
|
172
146
|
v.is_a?(Hash) ? stringify_keys(v) : v
|
|
173
147
|
end
|
|
174
148
|
end
|
|
175
149
|
|
|
176
|
-
# Manual gradient clipping (torch.rb doesn't have clip_grad_norm_)
|
|
177
150
|
def clip_grad_norm(parameters, max_norm)
|
|
178
151
|
total_norm = 0.0
|
|
179
152
|
parameters.each do |p|
|
|
180
153
|
next unless p.grad
|
|
181
154
|
|
|
182
155
|
param_norm = p.grad.data.norm(2).item
|
|
183
|
-
total_norm += param_norm
|
|
156
|
+
total_norm += param_norm**2
|
|
184
157
|
end
|
|
185
158
|
total_norm = Math.sqrt(total_norm)
|
|
186
159
|
|
data/lib/nano_gpt/version.rb
CHANGED
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "sqlite3"
|
|
4
|
+
require "json"
|
|
5
|
+
|
|
6
|
+
module NanoGPT
  module Web
    # SQLite-backed storage for training runs, per-iteration metrics and
    # checkpoint records.
    #
    # Uses WAL mode for concurrent read/write access. A single
    # SQLite3::Database handle is shared by every caller, so each operation
    # is serialized through a mutex; this also keeps an INSERT and the
    # following last_insert_row_id read together, which would otherwise race
    # with inserts issued from another thread on the same connection.
    class MetricsStore
      # Columns of training_runs that #update_run may set. Keys are
      # interpolated into the SQL text, so only these names are accepted.
      UPDATABLE_RUN_COLUMNS = %w[
        dataset config_json status started_at stopped_at
        current_iter best_val_loss checkpoint_path
      ].freeze

      attr_reader :db_path

      # @param db_path [String] path of the SQLite database file (created if missing)
      def initialize(db_path = "nanogpt_metrics.db")
        @db_path = db_path
        @mutex = Mutex.new
        @db = SQLite3::Database.new(db_path)
        @db.results_as_hash = true
        @db.execute("PRAGMA journal_mode=WAL")
        @db.execute("PRAGMA synchronous=NORMAL")
        create_tables
      end

      # Insert a new training run.
      #
      # @param dataset [String] dataset name
      # @param config [Hash] training configuration, stored as JSON
      # @param status [String] initial status
      # @return [Integer] id of the inserted run
      def create_run(dataset:, config:, status: "running")
        synchronize do
          @db.execute(
            "INSERT INTO training_runs (dataset, config_json, status, started_at) VALUES (?, ?, ?, ?)",
            [dataset, JSON.generate(config), status, Time.now.iso8601]
          )
          @db.last_insert_row_id
        end
      end

      # Update columns of an existing run, e.g. update_run(3, status: "stopped").
      #
      # @param run_id [Integer]
      # @param attrs [Hash{Symbol=>Object}] column name => new value
      # @return [nil] when attrs is empty (nothing to update, no SQL issued)
      # @raise [ArgumentError] if a key is not one of UPDATABLE_RUN_COLUMNS
      def update_run(run_id, **attrs)
        return nil if attrs.empty?

        columns = attrs.keys.map(&:to_s)
        unknown = columns - UPDATABLE_RUN_COLUMNS
        raise ArgumentError, "unknown training_runs column(s): #{unknown.join(', ')}" unless unknown.empty?

        assignments = columns.map { |column| "#{column} = ?" }.join(", ")
        synchronize do
          @db.execute("UPDATE training_runs SET #{assignments} WHERE id = ?", attrs.values + [run_id])
        end
      end

      # Record one value per metric type for a given iteration.
      #
      # Values that are nil, non-numeric, NaN or infinite are skipped: the
      # value column is NOT NULL (SQLite stores a bound NaN as NULL), and
      # JSON.generate rejects NaN/Infinity by default, so such rows would break
      # JSON responses built from them. All rows of one call are written in a
      # single transaction.
      #
      # @param run_id [Integer]
      # @param iteration [Integer]
      # @param metrics_hash [Hash{#to_s=>Numeric}] metric type => value
      # @return [void]
      def record_metrics(run_id, iteration, metrics_hash)
        recorded_at = Time.now.iso8601
        rows = metrics_hash.select { |_metric_type, value| finite_number?(value) }
        return if rows.empty?

        synchronize do
          @db.transaction do
            rows.each do |metric_type, value|
              @db.execute(
                "INSERT INTO metrics (run_id, iteration, metric_type, value, recorded_at) VALUES (?, ?, ?, ?, ?)",
                [run_id, iteration, metric_type.to_s, value, recorded_at]
              )
            end
          end
        end
      end

      # All metric rows of a run, ordered by iteration.
      def metrics_for_run(run_id)
        synchronize do
          @db.execute(
            "SELECT iteration, metric_type, value FROM metrics WHERE run_id = ? ORDER BY iteration",
            [run_id]
          )
        end
      end

      # The most recently created run, or nil.
      def latest_run
        synchronize { @db.get_first_row("SELECT * FROM training_runs ORDER BY id DESC LIMIT 1") }
      end

      # The run with the given id, or nil.
      def get_run(run_id)
        synchronize { @db.get_first_row("SELECT * FROM training_runs WHERE id = ?", [run_id]) }
      end

      # Up to +limit+ runs, newest first.
      def list_runs(limit: 50)
        synchronize { @db.execute("SELECT * FROM training_runs ORDER BY id DESC LIMIT ?", [limit]) }
      end

      # Record a saved checkpoint. A non-finite val_loss (e.g. the initial
      # Float::INFINITY best loss) is stored as NULL.
      def record_checkpoint(run_id, path:, suffix:, iteration:, val_loss:)
        stored_loss = finite_number?(val_loss) ? val_loss : nil
        synchronize do
          @db.execute(
            "INSERT INTO checkpoints (run_id, path, suffix, iteration, val_loss, saved_at) VALUES (?, ?, ?, ?, ?, ?)",
            [run_id, path, suffix, iteration, stored_loss, Time.now.iso8601]
          )
        end
      end

      # Checkpoints of a run, newest iteration first.
      def checkpoints_for_run(run_id)
        synchronize do
          @db.execute(
            "SELECT * FROM checkpoints WHERE run_id = ? ORDER BY iteration DESC",
            [run_id]
          )
        end
      end

      def close
        synchronize { @db.close }
      end

      private

      # Run the block while holding the connection mutex.
      def synchronize(&block)
        @mutex.synchronize(&block)
      end

      # True for finite numeric values (excludes nil, NaN and +/-Infinity).
      def finite_number?(value)
        value.is_a?(Numeric) && value.finite?
      end

      def create_tables
        @db.execute(<<~SQL)
          CREATE TABLE IF NOT EXISTS training_runs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            dataset TEXT NOT NULL,
            config_json TEXT NOT NULL,
            status TEXT NOT NULL DEFAULT 'running',
            started_at TEXT NOT NULL,
            stopped_at TEXT,
            current_iter INTEGER DEFAULT 0,
            best_val_loss REAL,
            checkpoint_path TEXT
          )
        SQL

        @db.execute(<<~SQL)
          CREATE TABLE IF NOT EXISTS metrics (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            run_id INTEGER NOT NULL,
            iteration INTEGER NOT NULL,
            metric_type TEXT NOT NULL,
            value REAL NOT NULL,
            recorded_at TEXT NOT NULL,
            FOREIGN KEY (run_id) REFERENCES training_runs(id)
          )
        SQL

        @db.execute(<<~SQL)
          CREATE INDEX IF NOT EXISTS idx_metrics_run_id ON metrics(run_id, iteration)
        SQL

        @db.execute(<<~SQL)
          CREATE TABLE IF NOT EXISTS checkpoints (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            run_id INTEGER NOT NULL,
            path TEXT NOT NULL,
            suffix TEXT NOT NULL,
            iteration INTEGER NOT NULL,
            val_loss REAL,
            saved_at TEXT NOT NULL,
            FOREIGN KEY (run_id) REFERENCES training_runs(id)
          )
        SQL
      end
    end
  end
end
|
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "sinatra/base"
|
|
4
|
+
require "json"
|
|
5
|
+
require "tempfile"
|
|
6
|
+
require "fileutils"
|
|
7
|
+
|
|
8
|
+
module NanoGPT
  module Web
    # Sinatra application for the training dashboard.
    #
    # Serves the HTML pages plus a JSON API to start/stop/resume training,
    # poll metrics, list runs and checkpoints, generate text and upload
    # datasets. Training, generation and dataset preparation are handed to
    # +training_worker+; the class-level collaborators below must be assigned
    # before the server starts.
    class Server < Sinatra::Base
      set :views, File.join(__dir__, "views")
      set :server, :webrick
      set :logging, true

      # Shared state - initialized before server starts
      class << self
        attr_accessor :training_state, :metrics_store, :sse_notifier, :training_worker
      end

      # ---- Pages ----

      get "/" do
        erb :runs, layout: :layout
      end

      get "/runs/:id" do
        run = self.class.metrics_store.get_run(params[:id].to_i)
        halt 404, "Run not found" unless run
        @run = run
        erb :run_detail, layout: :layout
      end

      # ---- Training Status ----

      get "/train/status" do
        content_type :json
        self.class.training_state.to_json
      end

      # ---- Start Training ----

      post "/train/start" do
        content_type :json
        state = self.class.training_state

        if state[:status] == "running"
          halt 409, JSON.generate(error: "Training already running")
        end

        body = json_body
        config = build_train_config(body)
        # The dataset name comes from the request and becomes a directory name.
        halt_unless_safe_dataset!(config[:dataset])

        data_dir = File.join("data", config[:dataset])
        train_bin = File.join(data_dir, "train.bin")
        unless File.exist?(train_bin)
          halt 422, JSON.generate(error: "Dataset not found: #{config[:dataset]}. Run prepare first.")
        end

        state.reset_stop!
        state.update(status: "running", dataset: config[:dataset], max_iters: config[:max_iters])

        run_id = self.class.metrics_store.create_run(
          dataset: config[:dataset],
          config: config
        )
        state.update(run_id: run_id)

        self.class.training_worker.enqueue(:start, config: config, data_dir: data_dir, run_id: run_id)

        JSON.generate(run_id: run_id, status: "started")
      end

      # ---- Stop Training ----

      post "/train/stop" do
        content_type :json
        request.body&.read # drain any request body; its content is ignored
        state = self.class.training_state

        unless state[:status] == "running"
          halt 409, JSON.generate(error: "No training in progress")
        end

        state.request_stop!
        JSON.generate(status: "stop_requested")
      end

      # ---- Resume Training ----

      post "/train/resume" do
        content_type :json
        state = self.class.training_state

        if state[:status] == "running"
          halt 409, JSON.generate(error: "Training already running")
        end

        body = json_body
        config = build_train_config(body)
        halt_unless_safe_dataset!(config[:dataset])

        # Fall back to the run's default checkpoint when none (or a missing /
        # non-string path) was supplied.
        ckpt_path = body["checkpoint_path"]
        if !ckpt_path.is_a?(String) || !File.exist?(ckpt_path)
          ckpt_path = File.join(config[:out_dir], "ckpt.pt")
        end
        unless File.exist?(ckpt_path)
          halt 422, JSON.generate(error: "No checkpoint found at #{ckpt_path}")
        end

        data_dir = File.join("data", config[:dataset])
        unless File.exist?(File.join(data_dir, "train.bin"))
          halt 422, JSON.generate(error: "Dataset not found: #{config[:dataset]}")
        end

        state.reset_stop!
        state.update(status: "running", dataset: config[:dataset], max_iters: config[:max_iters])

        run_id = self.class.metrics_store.create_run(
          dataset: config[:dataset],
          config: config,
          status: "running"
        )
        state.update(run_id: run_id)

        self.class.training_worker.enqueue(:resume,
          config: config, data_dir: data_dir, run_id: run_id, checkpoint_path: ckpt_path)

        JSON.generate(run_id: run_id, status: "resumed")
      end

      # ---- Metrics polling ----

      get "/metrics/poll" do
        content_type :json
        state = self.class.training_state.to_h
        run_id = state[:run_id]
        since_iter = (params[:since_iter] || 0).to_i
        recent = {}
        if run_id
          all_metrics = self.class.metrics_store.metrics_for_run(run_id)
          all_metrics.each do |row|
            next if row["iteration"] <= since_iter

            type = row["metric_type"]
            recent[type] ||= []
            recent[type] << { iteration: row["iteration"], value: row["value"] }
          end
        end
        JSON.generate(state: state, metrics: recent)
      end

      # ---- Runs API ----

      get "/api/runs" do
        content_type :json
        runs = self.class.metrics_store.list_runs
        JSON.generate(runs: runs)
      end

      get "/api/runs/:id" do
        content_type :json
        run = self.class.metrics_store.get_run(params[:id].to_i)
        halt 404, JSON.generate(error: "Run not found") unless run
        JSON.generate(run: run)
      end

      get "/api/runs/:id/metrics" do
        content_type :json
        run_id = params[:id].to_i
        run = self.class.metrics_store.get_run(run_id)
        halt 404, JSON.generate(error: "Run not found") unless run

        metrics = self.class.metrics_store.metrics_for_run(run_id)
        grouped = {}
        metrics.each do |row|
          type = row["metric_type"]
          grouped[type] ||= []
          grouped[type] << { iteration: row["iteration"], value: row["value"] }
        end

        JSON.generate(run: run, metrics: grouped)
      end

      get "/api/runs/:id/checkpoints" do
        content_type :json
        run_id = params[:id].to_i
        checkpoints = self.class.metrics_store.checkpoints_for_run(run_id).dup

        # Also include legacy checkpoint_path from the run record
        run = self.class.metrics_store.get_run(run_id)
        if run && run["checkpoint_path"] && File.exist?(run["checkpoint_path"])
          unless checkpoints.any? { |c| c["path"] == run["checkpoint_path"] }
            checkpoints.unshift({
              "id" => nil, "run_id" => run_id, "path" => run["checkpoint_path"],
              "suffix" => "legacy", "iteration" => run["current_iter"] || 0,
              "val_loss" => run["best_val_loss"], "saved_at" => run["stopped_at"]
            })
          end
        end

        JSON.generate(checkpoints: checkpoints)
      end

      # ---- Text Generation API ----

      post "/generate/run" do
        content_type :json
        body = json_body

        prompt = body["prompt"] || "\n"
        temperature = (body["temperature"] || 0.8).to_f
        max_tokens = (body["max_tokens"] || 200).to_i
        top_k = body["top_k"]&.to_i
        checkpoint_path = body["checkpoint_path"]
        dataset = body["dataset"] || "shakespeare_char"
        halt_unless_safe_dataset!(dataset)

        if checkpoint_path && !(checkpoint_path.is_a?(String) && File.exist?(checkpoint_path))
          halt 422, JSON.generate(error: "Checkpoint not found: #{checkpoint_path}")
        end

        unless checkpoint_path
          out_dir = body["out_dir"] || "out-shakespeare-char"
          checkpoint_path = File.join(out_dir, "ckpt.pt")
          unless File.exist?(checkpoint_path)
            halt 422, JSON.generate(error: "No checkpoint found at #{checkpoint_path}")
          end
        end

        if self.class.training_state[:status] == "running"
          halt 409, JSON.generate(error: "Cannot generate while training is running. Stop training first.")
        end

        result = self.class.training_worker.enqueue_sync(:generate,
          prompt: prompt, temperature: temperature, max_tokens: max_tokens,
          top_k: top_k, checkpoint_path: checkpoint_path, dataset: dataset)
        if result[:ok]
          JSON.generate(text: result[:text])
        else
          halt 500, JSON.generate(error: result[:error])
        end
      end

      # ---- Dataset Upload ----

      post "/datasets/upload" do
        content_type :json

        upload = params[:file]
        unless upload.is_a?(Hash) && upload[:tempfile]
          halt 422, JSON.generate(error: "No file uploaded")
        end

        name = params[:name]&.strip
        name = File.basename(upload[:filename].to_s, ".*") if name.nil? || name.empty?
        # Sanitize user-supplied and filename-derived names alike: the name is
        # embedded in a tmp file path and becomes the output dataset directory,
        # so separators or ".." must not survive.
        name = name.gsub(/[^a-zA-Z0-9_-]/, "_")
        name = "upload" if name.empty?
        tokenizer_type = params[:tokenizer] || "char"
        val_ratio = (params[:val_ratio] || 0.1).to_f

        tmp_path = File.join(Dir.tmpdir, "nanogpt_upload_#{name}_#{Time.now.to_i}.txt")
        begin
          FileUtils.cp(upload[:tempfile].path, tmp_path)
          result = self.class.training_worker.enqueue_sync(:prepare_dataset,
            input_path: tmp_path, output_name: name, tokenizer: tokenizer_type, val_ratio: val_ratio)
        ensure
          # Remove the copy even when preparation raises.
          FileUtils.rm_f(tmp_path)
        end

        if result[:ok]
          JSON.generate(ok: true, dataset: result[:dataset])
        else
          halt 500, JSON.generate(ok: false, error: result[:error])
        end
      end

      # ---- Available Datasets ----

      get "/datasets" do
        content_type :json
        datasets = Dir.glob("data/*/train.bin").map do |path|
          File.basename(File.dirname(path))
        end.sort
        JSON.generate(datasets: datasets)
      end

      private

      # Parse the request body as a JSON object. An empty body, malformed JSON
      # or a non-object document (array, null, number) yields {} so routes fall
      # back to their defaults.
      def json_body
        raw = request.body&.read
        return {} if raw.nil? || raw.strip.empty?

        parsed = JSON.parse(raw)
        parsed.is_a?(Hash) ? parsed : {}
      rescue JSON::ParserError
        {}
      end

      # True when +name+ is a non-empty string usable as a single path
      # component: no separators, not "." or "..", no NUL byte.
      def safe_path_component?(name)
        name.is_a?(String) && !name.empty? && !name.include?("\0") &&
          !%w[. ..].include?(name) && File.basename(name) == name
      end

      # Halt with 422 unless +name+ is a safe dataset directory name, so
      # request input cannot address paths outside data/.
      def halt_unless_safe_dataset!(name)
        return if safe_path_component?(name)

        halt 422, JSON.generate(error: "Invalid dataset name: #{name.inspect}")
      end

      # Start from TrainConfig defaults and apply request overrides for known
      # keys only; device is always auto-detected, and out_dir defaults to
      # "out-<dataset>" when a dataset (but no out_dir) was supplied.
      def build_train_config(overrides = {})
        config = TrainConfig.new.to_h
        overrides.each do |key, value|
          sym = key.to_sym
          config[sym] = value if config.key?(sym)
        end
        config[:device] = Device.auto
        config[:out_dir] = "out-#{config[:dataset]}" if overrides["dataset"] && !overrides["out_dir"]
        config
      end
    end
  end
end
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
|
|
5
|
+
module NanoGPT
  module Web
    # Manages Server-Sent Events connections and broadcasts events to them.
    #
    # A connection is any object responding to #<< (for example a Sinatra
    # stream). Connections that are already closed, or whose write fails,
    # are dropped from the list.
    class SSENotifier
      def initialize
        @connections = []
        @mutex = Mutex.new
      end

      # Register +connection+ to receive future broadcasts.
      def add(connection)
        @mutex.synchronize { @connections << connection }
      end

      # Unregister +connection+; returns it, or nil if it was not registered.
      def remove(connection)
        @mutex.synchronize { @connections.delete(connection) }
      end

      # Send one SSE event to every registered connection.
      #
      # @param type [#to_s] event name (the "event:" field)
      # @param data [Object] JSON-serializable payload (the "data:" field)
      def broadcast(type:, data:)
        payload = "event: #{type}\ndata: #{JSON.generate(data)}\n\n"
        @mutex.synchronize do
          # Writes happen while holding the lock, so two concurrent broadcasts
          # cannot interleave their events on the same stream.
          @connections.reject! { |conn| !deliver(conn, payload) }
        end
      end

      private

      # Write +payload+ to +conn+; returns false when the peer is gone.
      # Any IOError or socket-level error (EPIPE, ECONNRESET, ...) means the
      # client disconnected and must not propagate to the broadcaster.
      def deliver(conn, payload)
        return false if conn.respond_to?(:closed?) && conn.closed?

        conn << payload
        true
      rescue IOError, SystemCallError
        false
      end
    end
  end
end
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
|
|
5
|
+
module NanoGPT
  module Web
    # Mutex-guarded snapshot of the current training session, shared across
    # threads. Holds a fixed set of fields plus a separate "stop requested"
    # flag.
    class TrainingState
      STATUSES = %w[idle running stopped completed].freeze

      def initialize
        @stop_requested = false
        @state = {
          status: "idle",
          run_id: nil,
          current_iter: 0,
          current_loss: nil,
          best_val_loss: nil,
          max_iters: 0,
          dataset: nil
        }
        @mutex = Mutex.new
      end

      # Overwrite the given fields. Keys that are not part of the state are
      # silently ignored. Returns the +attrs+ hash it was given.
      def update(**attrs)
        locked do
          attrs.each_pair do |field, value|
            next unless @state.key?(field)

            @state[field] = value
          end
        end
      end

      # Current value of one field (nil for unknown keys).
      def [](key)
        locked { @state.fetch(key, nil) }
      end

      # Raise the stop flag; returns true.
      def request_stop!
        assign_stop_flag(true)
      end

      def stop_requested?
        locked { @stop_requested }
      end

      # Clear the stop flag; returns false.
      def reset_stop!
        assign_stop_flag(false)
      end

      # JSON object of all fields, in their declaration order.
      def to_json(*)
        locked { JSON.generate(@state) }
      end

      # Shallow copy of the fields; mutating it does not affect this object.
      def to_h
        locked { @state.dup }
      end

      private

      def locked(&block)
        @mutex.synchronize(&block)
      end

      def assign_stop_flag(flag)
        locked { @stop_requested = flag }
      end
    end
  end
end
|