ignis-dl 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis-dl.rb +48 -0
- data/lib/nnw/ai/gpt2_loader.rb +144 -0
- data/lib/nnw/ai/inference.rb +224 -0
- data/lib/nnw/ai/kv_cache.rb +79 -0
- data/lib/nnw/ai/llama_loader.rb +100 -0
- data/lib/nnw/ai/loss.rb +170 -0
- data/lib/nnw/ai/nn/dropout.rb +68 -0
- data/lib/nnw/ai/nn/embedding.rb +86 -0
- data/lib/nnw/ai/nn/layer_norm.rb +54 -0
- data/lib/nnw/ai/nn/linear.rb +80 -0
- data/lib/nnw/ai/nn/module.rb +178 -0
- data/lib/nnw/ai/nn/rms_norm.rb +43 -0
- data/lib/nnw/ai/nn/sequential.rb +52 -0
- data/lib/nnw/ai/optim/adam.rb +63 -0
- data/lib/nnw/ai/optim/adamw.rb +63 -0
- data/lib/nnw/ai/optim/base.rb +90 -0
- data/lib/nnw/ai/optim/lr_scheduler.rb +118 -0
- data/lib/nnw/ai/optim/sgd.rb +49 -0
- data/lib/nnw/ai/safetensors.rb +220 -0
- data/lib/nnw/ai/server.rb +268 -0
- data/lib/nnw/ai/tokenizer.rb +413 -0
- data/lib/nnw/ai/trainer.rb +245 -0
- data/lib/nnw/ai/transformer/attention.rb +89 -0
- data/lib/nnw/ai/transformer/block.rb +90 -0
- data/lib/nnw/ai/transformer/feed_forward.rb +53 -0
- data/lib/nnw/ai/transformer/model.rb +189 -0
- data/lib/nnw/ai/transformer/modern.rb +191 -0
- data/lib/nnw/ai/transformer/swiglu.rb +39 -0
- data/lib/nnw/ai/weight_map.rb +139 -0
- metadata +91 -0
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "module"
|
|
4
|
+
|
|
5
|
+
module Ignis
  module AI
    module NN
      # Sequential container: chains modules in order.
      # forward(x) feeds x through each child in turn and returns the output of
      # the last one.
      class Sequential < Module
        # @param modules [Array<Module>] modules to chain, in application order.
        #   Each one is registered as a child under its positional index
        #   ("0", "1", ...).
        def initialize(*modules)
          super()
          modules.each_with_index { |mod, i| register_module(i.to_s, mod) }
          @layers = modules
        end

        # Forward pass: threads the input through every layer via #call.
        # With no layers the input is returned unchanged.
        # @param x [Tensor]
        # @return [Tensor]
        def forward(x)
          @layers.reduce(x) { |h, layer| layer.call(h) }
        end

        # Number of layers.
        # @return [Integer]
        def length
          @layers.length
        end

        # Access layer by index (Array#[] semantics, so negative indices work
        # and out-of-range indices return nil).
        # @param idx [Integer]
        # @return [Module, nil]
        def [](idx)
          @layers[idx]
        end

        # Multi-line representation with one "(i): child" line per layer.
        # Continuation lines of multi-line children (e.g. a nested Sequential)
        # are indented one extra level so that the nesting stays readable.
        # @return [String]
        def to_s
          children = @layers.each_with_index.map do |layer, i|
            child = "#{layer}".gsub("\n", "\n  ")
            "  (#{i}): #{child}"
          end
          ["Sequential(", *children, ")"].join("\n")
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "base"
|
|
4
|
+
|
|
5
|
+
module Ignis
  module AI
    module Optim
      # Adam optimizer (Kingma & Ba, 2014).
      # Every parameter with a gradient is updated by one launch of the fused
      # adam_step kernel, which receives the parameter, its gradient and its
      # first/second moment buffers together.
      class Adam < Base
        # @param params [Array<Tensor>] parameters to optimize
        # @param lr [Float] learning rate
        # @param beta1 [Float] decay rate of the first-moment estimate
        # @param beta2 [Float] decay rate of the second-moment estimate
        # @param eps [Float] numerical-stability term
        # @param weight_decay [Float] L2 regularization strength
        def initialize(params, lr: 1e-3, beta1: 0.9, beta2: 0.999,
                       eps: 1e-8, weight_decay: 0.0)
          super(params, lr: lr)
          @beta1, @beta2, @eps, @weight_decay = [beta1, beta2, eps, weight_decay].map(&:to_f)

          # Moment buffers start at zero; they are keyed by the parameter's
          # object_id and match its shape, dtype and device.
          @m_states = {}
          @v_states = {}
          params.each do |param|
            @m_states[param.object_id] = zeroed_state_for(param)
            @v_states[param.object_id] = zeroed_state_for(param)
          end
        end

        # Advance the step counter, update every parameter that has a
        # gradient, then block until the device has finished.
        # @return [void]
        def step
          @step_count += 1
          launcher = Ignis::JIT::Kernels::Optimizer.adam_step

          # Bias corrections for the zero-initialized moments at this step.
          bc1 = 1.0 - (@beta1**@step_count)
          bc2 = 1.0 - (@beta2**@step_count)

          @params.each do |param|
            next unless param.grad

            count = param.numel
            first_moment = @m_states[param.object_id]
            second_moment = @v_states[param.object_id]
            launcher.launch(grid: [(count + 255) / 256], block: [256],
                            args: [param.data, param.grad, first_moment, second_moment,
                                   @lr, @beta1, @beta2, @eps, @weight_decay,
                                   bc1, bc2, count])
          end
          Ignis.synchronize
        end

        private

        # Allocate a device buffer shaped like +param+ and fill it with zeros.
        # @param param [Tensor]
        # @return [Ignis::Shared::NvArray]
        def zeroed_state_for(param)
          buffer = Ignis::Shared::NvArray.new(shape: param.shape, dtype: param.dtype, device_id: param.device_id)
          buffer.from_host(Array.new(param.numel, 0.0))
          buffer
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "base"
|
|
4
|
+
|
|
5
|
+
module Ignis
  module AI
    module Optim
      # AdamW optimizer (Loshchilov & Hutter, 2019).
      # Differs from Adam in that weight decay is meant to be *decoupled* from
      # the gradient-based update, i.e. applied directly to the parameters.
      # That update rule lives in the adamw_step kernel, which is not shown
      # here; this class only manages state and launches.
      class AdamW < Base
        # @param params [Array<Tensor>] parameters to optimize
        # @param lr [Float] learning rate
        # @param beta1 [Float] decay rate of the first-moment estimate
        # @param beta2 [Float] decay rate of the second-moment estimate
        # @param eps [Float] numerical-stability term
        # @param weight_decay [Float] decoupled weight decay
        def initialize(params, lr: 1e-3, beta1: 0.9, beta2: 0.999,
                       eps: 1e-8, weight_decay: 0.01)
          super(params, lr: lr)
          @beta1, @beta2, @eps, @weight_decay = [beta1, beta2, eps, weight_decay].map(&:to_f)

          # Zero-initialized moment buffers keyed by the parameter's object_id,
          # each matching the parameter's shape, dtype and device.
          @m_states = {}
          @v_states = {}
          params.each do |param|
            @m_states[param.object_id] = zeroed_state_for(param)
            @v_states[param.object_id] = zeroed_state_for(param)
          end
        end

        # Advance the step counter, update every parameter that has a
        # gradient, then block until the device has finished.
        # @return [void]
        def step
          @step_count += 1
          launcher = Ignis::JIT::Kernels::Optimizer.adamw_step

          # Bias corrections for the zero-initialized moments at this step.
          bc1 = 1.0 - (@beta1**@step_count)
          bc2 = 1.0 - (@beta2**@step_count)

          @params.each do |param|
            next unless param.grad

            count = param.numel
            first_moment = @m_states[param.object_id]
            second_moment = @v_states[param.object_id]
            launcher.launch(grid: [(count + 255) / 256], block: [256],
                            args: [param.data, param.grad, first_moment, second_moment,
                                   @lr, @beta1, @beta2, @eps, @weight_decay,
                                   bc1, bc2, count])
          end
          Ignis.synchronize
        end

        private

        # Allocate a device buffer shaped like +param+ and fill it with zeros.
        # @param param [Tensor]
        # @return [Ignis::Shared::NvArray]
        def zeroed_state_for(param)
          buffer = Ignis::Shared::NvArray.new(shape: param.shape, dtype: param.dtype, device_id: param.device_id)
          buffer.from_host(Array.new(param.numel, 0.0))
          buffer
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
  module AI
    module Optim
      # Base optimizer class.
      # Holds the parameter list, learning rate and step counter shared by all
      # optimizers, plus gradient utilities (zeroing and global-norm clipping).
      class Base
        # Threads per block for the gradient-norm kernels.
        THREADS_PER_BLOCK = 256
        private_constant :THREADS_PER_BLOCK

        # @return [Array<Tensor>] parameters being optimized
        attr_reader :params

        # @return [Integer] number of optimizer steps taken (read by LR schedulers)
        attr_reader :step_count

        # @return [Float] current learning rate
        attr_reader :lr

        # @param params [Array<Tensor>] parameters to optimize
        # @param lr [Float] learning rate (coerced with #to_f)
        def initialize(params, lr:)
          @params = params
          @lr = lr.to_f
          @step_count = 0
        end

        # Perform one optimization step using accumulated gradients.
        # @raise [NotImplementedError] always; subclasses must override
        # @return [void]
        def step
          raise NotImplementedError, "Subclass must implement #step"
        end

        # Zero all parameter gradients.
        # @return [void]
        def zero_grad!
          @params.each(&:zero_grad!)
        end

        # Set learning rate (for schedulers).
        # @param new_lr [Float] coerced with #to_f
        # @return [void]
        def lr=(new_lr)
          @lr = new_lr.to_f
        end

        # Clip the global L2 norm of all gradients to +max_norm+ in place.
        # Only parameters that currently have a gradient take part. When none
        # do (or there are no parameters), nothing is launched and 0.0 is
        # returned.
        # @param max_norm [Float]
        # @return [Float] the total norm measured before clipping
        def clip_grad_norm!(max_norm)
          with_grad = @params.select(&:grad)
          return 0.0 if with_grad.empty?

          total_norm = Math.sqrt(squared_grad_sum(with_grad))
          if total_norm > max_norm
            # The epsilon keeps the factor finite for tiny norms.
            scale_grads!(with_grad, max_norm / (total_norm + 1e-6))
          end
          total_norm
        end

        private

        # Phase 1: sum of squared gradient elements across +params+, reduced
        # on the device into a single float32 accumulator.
        # @param params [Array<Tensor>] non-empty; every element has a grad
        # @return [Float]
        def squared_grad_sum(params)
          accumulator = Ignis::Shared::NvArray.new(shape: [1], dtype: :float32,
                                                   device_id: params.first.device_id)
          accumulator.from_host([0.0])

          kernel = Ignis::JIT::Kernels::Optimizer.grad_squared_sum
          params.each do |p|
            n = p.numel
            kernel.launch(grid: [(n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK],
                          block: [THREADS_PER_BLOCK],
                          shared_mem: THREADS_PER_BLOCK * 4, # one float32 per thread
                          args: [p.grad, accumulator, n])
          end
          Ignis.synchronize
          accumulator.to_host[0]
        end

        # Phase 2: multiply every gradient in +params+ by +factor+ in place.
        # @param params [Array<Tensor>] every element has a grad
        # @param factor [Numeric]
        # @return [void]
        def scale_grads!(params, factor)
          kernel = Ignis::JIT::Kernels::Optimizer.grad_clip_scale
          params.each do |p|
            n = p.numel
            kernel.launch(grid: [(n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK],
                          block: [THREADS_PER_BLOCK],
                          args: [p.grad, factor.to_f, n])
          end
          Ignis.synchronize
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
  module AI
    module Optim
      # Learning rate schedulers.
      # Each scheduler derives the optimizer's learning rate purely from the
      # optimizer's step_count. Call #step after each optimizer step. The
      # schedule is also applied once on construction, so the optimizer's very
      # first update already uses the scheduled rate (e.g. 0 at the start of
      # warmup) rather than the un-warmed base rate.
      module LRScheduler
        # Cosine interpolation from +base_lr+ (progress 0.0) down to +min_lr+
        # (progress 1.0).
        # @param base_lr [Float]
        # @param min_lr [Float]
        # @param progress [Float] fraction of the decay phase completed
        # @return [Float]
        def self.cosine_lr(base_lr, min_lr, progress)
          min_lr + 0.5 * (base_lr - min_lr) * (1.0 + Math.cos(Math::PI * progress))
        end

        # Linear warmup: linearly increase LR from 0 to base_lr over warmup_steps.
        # After warmup, keep constant.
        class LinearWarmup
          # @param optimizer [Base] its current lr is captured as base_lr
          # @param warmup_steps [Integer]
          def initialize(optimizer, warmup_steps:)
            @optimizer = optimizer
            @warmup_steps = warmup_steps
            @base_lr = optimizer.lr
            step # apply the LR for the current step_count right away
          end

          # Call after each optimizer step to update LR.
          # @return [Float] new learning rate
          def step
            current = @optimizer.step_count
            new_lr = current < @warmup_steps ? @base_lr * (current.to_f / @warmup_steps) : @base_lr
            @optimizer.lr = new_lr
            new_lr
          end

          # @return [Float]
          def current_lr
            @optimizer.lr
          end
        end

        # Cosine annealing: decay LR following a cosine curve.
        # LR drops from base_lr to min_lr over total_steps, then stays at min_lr.
        class CosineAnnealing
          # @param optimizer [Base] its current lr is captured as base_lr
          # @param total_steps [Integer]
          # @param min_lr [Float]
          def initialize(optimizer, total_steps:, min_lr: 0.0)
            @optimizer = optimizer
            @total_steps = total_steps
            @min_lr = min_lr.to_f
            @base_lr = optimizer.lr
            step # apply the LR for the current step_count right away
          end

          # @return [Float] new learning rate
          def step
            current = @optimizer.step_count
            new_lr =
              if current >= @total_steps
                @min_lr
              else
                LRScheduler.cosine_lr(@base_lr, @min_lr, current.to_f / @total_steps)
              end
            @optimizer.lr = new_lr
            new_lr
          end

          # @return [Float]
          def current_lr
            @optimizer.lr
          end
        end

        # Cosine with linear warmup: common combo for Transformer training.
        # Linear ramp 0 → base_lr over warmup_steps, then cosine decay to
        # min_lr at total_steps, then constant min_lr.
        class CosineWithWarmup
          # @param optimizer [Base] its current lr is captured as base_lr
          # @param warmup_steps [Integer]
          # @param total_steps [Integer]
          # @param min_lr [Float]
          def initialize(optimizer, warmup_steps:, total_steps:, min_lr: 0.0)
            @optimizer = optimizer
            @warmup_steps = warmup_steps
            @total_steps = total_steps
            @min_lr = min_lr.to_f
            @base_lr = optimizer.lr
            step # apply the LR for the current step_count right away
          end

          # @return [Float] new learning rate
          def step
            current = @optimizer.step_count

            new_lr =
              if current < @warmup_steps
                # Linear warmup phase
                @base_lr * (current.to_f / @warmup_steps)
              elsif current >= @total_steps
                @min_lr
              else
                # Cosine decay phase; reachable only when total_steps > warmup_steps,
                # so the divisor below is positive.
                decay_steps = @total_steps - @warmup_steps
                progress = (current - @warmup_steps).to_f / decay_steps
                LRScheduler.cosine_lr(@base_lr, @min_lr, progress)
              end

            @optimizer.lr = new_lr
            new_lr
          end

          # @return [Float]
          def current_lr
            @optimizer.lr
          end
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "base"
|
|
4
|
+
|
|
5
|
+
module Ignis
  module AI
    module Optim
      # SGD optimizer with optional momentum and weight decay.
      # Each parameter with a gradient is updated by one launch of the
      # sgd_step kernel.
      class SGD < Base
        # @param params [Array<Tensor>] parameters to optimize
        # @param lr [Float] learning rate
        # @param momentum [Float] momentum factor (0.0 = no momentum)
        # @param weight_decay [Float] L2 penalty
        def initialize(params, lr:, momentum: 0.0, weight_decay: 0.0)
          super(params, lr: lr)
          @momentum = momentum.to_f
          @weight_decay = weight_decay.to_f

          # Velocity buffers are only allocated when momentum is enabled.
          @velocities = {}
          return unless @momentum > 0.0

          params.each do |param|
            buffer = Ignis::Shared::NvArray.new(shape: param.shape, dtype: param.dtype, device_id: param.device_id)
            buffer.from_host(Array.new(param.numel, 0.0))
            @velocities[param.object_id] = buffer
          end
        end

        # Update every parameter that has a gradient, block until the device
        # has finished, then advance the step counter.
        # @return [void]
        def step
          launcher = Ignis::JIT::Kernels::Optimizer.sgd_step

          @params.each do |param|
            next unless param.grad

            count = param.numel
            # Without momentum there is no velocity buffer, so the gradient
            # buffer fills that kernel argument instead.
            # NOTE(review): this aliases grad and velocity; confirm sgd_step
            # never writes the velocity slot when momentum is 0.
            velocity = @velocities.fetch(param.object_id) { param.grad }

            launcher.launch(grid: [(count + 255) / 256], block: [256],
                            args: [param.data, param.grad, velocity,
                                   @lr, @momentum, @weight_decay, count])
          end
          Ignis.synchronize

          @step_count += 1
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,220 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
|
|
5
|
+
module Ignis
  module AI
    # Safetensors — pure Ruby loader/saver for the HuggingFace safetensors format.
    #
    # Format:
    #   8 bytes: LE uint64 header length
    #   N bytes: JSON header {"tensor_name": {"dtype": "F16", "shape": [d1,d2], "data_offsets": [begin, end]}}
    #   Remaining: raw tensor data (contiguous, little-endian, row-major)
    #
    # Loading reads each tensor's bytes with IO#read and uploads them to the
    # device unchanged. The exception is dtypes that DTYPE_MAP marks for
    # host-side widening. Saving goes through host Ruby arrays.
    module Safetensors
      # Dtype mapping: safetensors dtype string → NvArray dtype symbol,
      # on-disk element size in bytes, and the Ruby pack/unpack directive for
      # the on-disk element type. An optional :widen_pack means the device
      # buffer uses a wider element type than the file. The payload is then
      # unpacked with :pack and repacked with :widen_pack before upload.
      DTYPE_MAP = {
        "F16" => { symbol: :float16, bytes: 2, pack: "v*" },
        "BF16" => { symbol: :bfloat16, bytes: 2, pack: "v*" },
        "F32" => { symbol: :float32, bytes: 4, pack: "e*" },
        "F64" => { symbol: :float64, bytes: 8, pack: "E*" },
        # NOTE(review): signed I8 data is uploaded bit-for-bit into a :uint8
        # buffer, so negative values read back as 128..255.
        "I8" => { symbol: :uint8, bytes: 1, pack: "c*" },
        # I16 lands in an :int32 buffer. Widen on the host so 2-byte elements
        # are not copied raw into a 4-byte-per-element allocation.
        "I16" => { symbol: :int32, bytes: 2, pack: "s<*", widen_pack: "l<*" },
        "I32" => { symbol: :int32, bytes: 4, pack: "l<*" },
        "I64" => { symbol: :int64, bytes: 8, pack: "q<*" },
        "U8" => { symbol: :uint8, bytes: 1, pack: "C*" },
        "BOOL" => { symbol: :uint8, bytes: 1, pack: "C*" }
      }.freeze

      # Maximum header size (100MB — DOS protection per spec)
      MAX_HEADER_SIZE = 100 * 1024 * 1024

      class << self
        # Load all tensors from a safetensors file.
        # @param path [String] path to .safetensors file
        # @param device_id [Integer] target GPU device
        # @return [Hash{String => Tensor}] name → Tensor mapping (requires_grad: false)
        # @raise [ArgumentError] if +path+ does not exist
        # @raise [RuntimeError] on a truncated or malformed file, an unsupported
        #   dtype, or data_offsets inconsistent with shape × dtype size
        def load(path, device_id: 0)
          raise ArgumentError, "File not found: #{path}" unless File.exist?(path)

          File.open(path, "rb") do |f|
            header, data_offset = read_header(f)
            tensors = {}

            header.each do |name, meta|
              next if name == "__metadata__"
              raise "Malformed header entry for #{name}" unless meta.is_a?(Hash)

              dtype_str = meta["dtype"]
              shape = meta["shape"]
              dtype_info = DTYPE_MAP[dtype_str]
              raise "Unsupported dtype: #{dtype_str}" unless dtype_info

              begin_offset, byte_count = tensor_byte_range(name, shape, meta["data_offsets"], dtype_info)

              # Offsets are relative to the start of the data section.
              f.seek(data_offset + begin_offset)
              raw_bytes = f.read(byte_count)
              raise "Short read for #{name}: expected #{byte_count}, got #{raw_bytes&.length}" unless raw_bytes&.length == byte_count

              nv_array = bytes_to_nv_array(raw_bytes, shape, dtype_info, device_id)
              tensors[name] = Tensor.new(data: nv_array, requires_grad: false)
            end

            tensors
          end
        end

        # Load tensors into a model using a weight map.
        # File tensors with no matching model parameter, or with a different
        # shape, are skipped. Shape mismatches are logged as warnings.
        # @param model [NN::Module] the model to load into
        # @param path [String] path to .safetensors file
        # @param weight_map [Hash{String => String}, nil] HF name → Ignis name
        #   mapping; names absent from the map are used unchanged
        # @param strict [Boolean] raise KeyError when a model parameter has no
        #   tensor of the same (mapped) name in the file. Unexpected or
        #   shape-mismatched file tensors are still only skipped.
        # @param device_id [Integer]
        # @return [NN::Module]
        def load_model(model, path, weight_map: nil, strict: false, device_id: 0)
          tensors = load(path, device_id: device_id)
          model_params = model.named_parameters
          resolve = ->(name) { weight_map ? (weight_map[name] || name) : name }

          loaded_count = 0
          skipped = []

          tensors.each do |name, tensor|
            mapped_name = resolve.call(name)
            unless model_params.key?(mapped_name)
              skipped << name
              next
            end

            param = model_params[mapped_name]
            if param.shape == tensor.shape
              param.data.from_host(tensor.to_host)
              loaded_count += 1
            else
              Ignis.logger.warn("Shape mismatch for #{mapped_name}: " \
                                "model=#{param.shape} file=#{tensor.shape}")
              skipped << name
            end
          end

          if strict
            missing = model_params.keys - tensors.keys.map(&resolve)
            raise KeyError, "Missing weights: #{missing.join(', ')}" unless missing.empty?
          end

          Ignis.logger.info("Loaded #{loaded_count}/#{tensors.size} tensors, skipped #{skipped.size}")
          model
        end

        # Save tensors to safetensors format.
        # @param tensors [Hash{String => Tensor}] name → Tensor
        # @param path [String] output file path
        # @param metadata [Hash, nil] optional metadata. Keys and values are
        #   stringified because the format requires a string → string map.
        # @return [void]
        def save(tensors, path, metadata: nil)
          header = {}
          header["__metadata__"] = metadata.to_h { |key, value| [key.to_s, value.to_s] } if metadata

          current_offset = 0
          payloads = []

          tensors.each do |name, tensor|
            host_data = tensor.to_host
            dtype_str = nv_dtype_to_safetensors(tensor.dtype)
            bytes = host_values_to_bytes(host_data, tensor.dtype)
            byte_count = bytes.bytesize

            header[name] = {
              "dtype" => dtype_str,
              "shape" => tensor.shape,
              "data_offsets" => [current_offset, current_offset + byte_count]
            }

            payloads << bytes
            current_offset += byte_count
          end

          header_json = JSON.generate(header)
          # Pad the header with spaces so the data section starts 8-byte
          # aligned. Sizes must be counted in bytes: JSON.generate emits
          # non-ASCII names and metadata as multi-byte UTF-8.
          padding = (8 - (header_json.bytesize % 8)) % 8
          header_json += " " * padding

          File.open(path, "wb") do |f|
            f.write([header_json.bytesize].pack("Q<"))
            f.write(header_json)
            payloads.each { |d| f.write(d) }
          end
        end

        private

        # Read and parse the length-prefixed JSON header.
        # @param f [File] opened in binary mode, positioned at offset 0
        # @return [Array(Hash, Integer)] parsed header and the absolute offset
        #   of the data section
        def read_header(f)
          size_bytes = f.read(8)
          raise "Truncated file: missing 8-byte header length" unless size_bytes&.bytesize == 8

          header_size = size_bytes.unpack1("Q<")
          raise "Header size #{header_size} exceeds maximum #{MAX_HEADER_SIZE}" if header_size > MAX_HEADER_SIZE

          header_json = f.read(header_size)
          raise "Truncated header: expected #{header_size} bytes" unless header_json&.bytesize == header_size

          header = JSON.parse(header_json)
          raise "Malformed header: expected a JSON object" unless header.is_a?(Hash)

          [header, 8 + header_size]
        end

        # Validate one tensor's offsets and shape against its dtype.
        # @return [Array(Integer, Integer)] begin offset and byte count
        def tensor_byte_range(name, shape, offsets, dtype_info)
          unless offsets.is_a?(Array) && offsets.size == 2 && offsets.all?(Integer) &&
                 offsets[0] >= 0 && offsets[1] >= offsets[0]
            raise "Invalid data_offsets for #{name}: #{offsets.inspect}"
          end
          unless shape.is_a?(Array) && shape.all? { |dim| dim.is_a?(Integer) && dim >= 0 }
            raise "Invalid shape for #{name}: #{shape.inspect}"
          end

          byte_count = offsets[1] - offsets[0]
          expected = shape.reduce(1, :*) * dtype_info[:bytes]
          unless byte_count == expected
            raise "Size mismatch for #{name}: data_offsets span #{byte_count} bytes, " \
                  "shape #{shape.inspect} needs #{expected}"
          end

          [offsets[0], byte_count]
        end

        # Convert raw bytes to NvArray on GPU.
        # The payload is already in the tensor's little-endian row-major layout,
        # including IEEE-754 F16/BF16, so it is uploaded as-is. The exception
        # is dtypes with :widen_pack, which are first repacked into the wider
        # device element type.
        # @param raw_bytes [String]
        # @param shape [Array<Integer>]
        # @param dtype_info [Hash]
        # @param device_id [Integer]
        # @return [Ignis::Shared::NvArray]
        def bytes_to_nv_array(raw_bytes, shape, dtype_info, device_id)
          nv = Ignis::Shared::NvArray.new(shape: shape, dtype: dtype_info[:symbol], device_id: device_id)
          widen = dtype_info[:widen_pack]
          payload = widen ? raw_bytes.unpack(dtype_info[:pack]).pack(widen) : raw_bytes
          nv.from_host_raw(payload)
          nv
        end

        # Convert Ignis dtype symbol to safetensors string.
        # @param dtype [Symbol]
        # @return [String]
        def nv_dtype_to_safetensors(dtype)
          case dtype
          when :float16 then "F16"
          when :bfloat16 then "BF16"
          when :float32 then "F32"
          when :float64 then "F64"
          when :int32 then "I32"
          when :int64 then "I64"
          when :uint8 then "U8"
          else raise "Unknown dtype: #{dtype}"
          end
        end

        # Convert host values to raw little-endian bytes for saving.
        # @param values [Array<Numeric>]
        # @param dtype [Symbol]
        # @return [String]
        def host_values_to_bytes(values, dtype)
          case dtype
          when :float32 then values.pack("e*")
          when :float64 then values.pack("E*")
          when :int32 then values.pack("l<*")
          when :int64 then values.pack("q<*")
          when :uint8 then values.pack("C*")
          when :float16 then values.map { |v| ::Ignis::Half.f32_to_f16(v) }.pack("v*")
          when :bfloat16 then values.map { |v| ::Ignis::Half.f32_to_bf16(v) }.pack("v*")
          else raise "Cannot save dtype: #{dtype}"
          end
        end
      end
    end
  end
end
|