daimond 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CONTRIBUTIONG.md +160 -0
- data/README.ja.md +115 -0
- data/README.md +115 -0
- data/README.ru.md +116 -0
- data/ext/daimond_rust/Cargo.lock +353 -0
- data/ext/daimond_rust/Cargo.toml +13 -0
- data/ext/daimond_rust/build.rs +3 -0
- data/ext/daimond_rust/src/lib.rs +103 -0
- data/lib/daimond/autograd.rb +0 -0
- data/lib/daimond/data/data_loader.rb +41 -0
- data/lib/daimond/data/mnist.rb +56 -0
- data/lib/daimond/loss/cross_entropy.rb +45 -0
- data/lib/daimond/loss/mse.rb +0 -0
- data/lib/daimond/nn/conv2d.rb +117 -0
- data/lib/daimond/nn/conv2d_rust.rb +52 -0
- data/lib/daimond/nn/flatten.rb +29 -0
- data/lib/daimond/nn/functional.rb +0 -0
- data/lib/daimond/nn/linear.rb +22 -0
- data/lib/daimond/nn/max_pool2d.rb +69 -0
- data/lib/daimond/nn/max_pool2d_rust.rb +33 -0
- data/lib/daimond/nn/module.rb +60 -0
- data/lib/daimond/optim/adam.rb +41 -0
- data/lib/daimond/optim/sgd.rb +25 -0
- data/lib/daimond/rust/daimond_rust.bundle +0 -0
- data/lib/daimond/rust_backend.rb +23 -0
- data/lib/daimond/rust_bridge.rb +63 -0
- data/lib/daimond/tensor.rb +241 -0
- data/lib/daimond/utils/training_logger.rb +111 -0
- data/lib/daimond/version.rb +3 -0
- data/lib/daimond.rb +40 -0
- metadata +134 -0
data/lib/daimond/nn/conv2d.rb
@@ -0,0 +1,117 @@
+require_relative 'module'
+
+module Daimond
+  module NN
+    class Conv2d < Module
+      def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0)
+        super()
+        @in_channels = in_channels
+        @out_channels = out_channels
+        @kernel_size = kernel_size.is_a?(Array) ? kernel_size : [kernel_size, kernel_size]
+        @stride = stride
+        @padding = padding
+
+        # He (Kaiming) initialization for Conv: sqrt(2 / (in * k * k))
+        k_h, k_w = @kernel_size
+        limit = Math.sqrt(2.0 / (in_channels * k_h * k_w))
+
+        # Weights: [out_channels, in_channels, k_h, k_w]
+        @weight = Tensor.new(
+          Numo::DFloat.new(out_channels, in_channels, k_h, k_w).rand * 2 * limit - limit
+        )
+        @bias = Tensor.zeros(out_channels)
+
+        @parameters = [@weight, @bias]
+      end
+
+      def forward(input)
+        # input: [batch, in_channels, height, width]
+        batch_size = input.shape[0]
+        in_c = @in_channels
+        out_c = @out_channels
+        k_h, k_w = @kernel_size
+
+        # Input size
+        h_in = input.shape[2]
+        w_in = input.shape[3]
+
+        # Output size
+        h_out = ((h_in + 2 * @padding - k_h) / @stride).floor + 1
+        w_out = ((w_in + 2 * @padding - k_w) / @stride).floor + 1
+
+        # Output tensor
+        output = Numo::DFloat.zeros(batch_size, out_c, h_out, w_out)
+
+        # Apply padding if requested
+        if @padding > 0
+          padded = Numo::DFloat.zeros(batch_size, in_c, h_in + 2*@padding, w_in + 2*@padding)
+          padded[true, true, @padding...h_in+@padding, @padding...w_in+@padding] = input.data
+          x_data = padded
+        else
+          x_data = input.data
+        end
+
+        # Convolution (4 nested loops: slow, but easy to follow)
+        batch_size.times do |b|
+          out_c.times do |oc|
+            h_out.times do |i|
+              w_out.times do |j|
+                # Window coordinates
+                i0 = i * @stride
+                j0 = j * @stride
+
+                # Extract the window and compute the convolution
+                window = x_data[b, true, i0...i0+k_h, j0...j0+k_w]
+                kernel = @weight.data[oc, true, true, true]
+
+                output[b, oc, i, j] = (window * kernel).sum + @bias.data[oc]
+              end
+            end
+          end
+        end
+
+        out_tensor = Tensor.new(output, prev: [input, @weight, @bias], op: 'conv2d')
+
+        # Backward (simplified: correct only for stride=1, padding=0)
+        out_tensor._backward = lambda do
+          grad_output = out_tensor.grad # [batch, out_c, h_out, w_out]
+
+          # Gradient w.r.t. the weights
+          @out_channels.times do |oc|
+            @in_channels.times do |ic|
+              k_h.times do |kh|
+                k_w.times do |kw|
+                  # Sum over every position where this weight took part
+                  grad_sum = 0.0
+                  batch_size.times do |b|
+                    h_out.times do |i|
+                      w_out.times do |j|
+                        # Input coordinates
+                        i_in = i * @stride + kh
+                        j_in = j * @stride + kw
+
+                        grad_sum += x_data[b, ic, i_in, j_in] * grad_output[b, oc, i, j]
+                      end
+                    end
+                  end
+                  @weight.grad[oc, ic, kh, kw] += grad_sum
+                end
+              end
+            end
+
+            # Gradient w.r.t. the bias
+            @bias.grad[oc] += grad_output[true, oc, true, true].sum
+          end
+
+          # Gradient w.r.t. the input (if needed)
+          if input.grad
+            # full convolution with a rotated kernel
+            # omitted here (stride=1 simplification)
+          end
+        end
+
+        out_tensor
+      end
+    end
+  end
+end
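
For orientation, a minimal usage sketch (not part of the diff). It assumes Daimond::Tensor wraps a Numo::DFloat and exposes #shape, as the forward code above implies, and that layers are invoked through #call from module.rb:

    conv = Daimond::NN::Conv2d.new(1, 8, 3, stride: 1, padding: 1)
    x = Daimond::Tensor.new(Numo::DFloat.new(2, 1, 28, 28).rand) # [batch, c, h, w]
    y = conv.call(x)
    # h_out = (28 + 2*1 - 3) / 1 + 1 = 28, so the output shape is [2, 8, 28, 28]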

data/lib/daimond/nn/conv2d_rust.rb
@@ -0,0 +1,52 @@
+require_relative 'module'
+
+module Daimond
+  module NN
+    class Conv2dRust < Module
+      attr_reader :weight, :bias
+
+      def initialize(in_channels, out_channels, kernel_size)
+        super()
+        @in_channels = in_channels
+        @out_channels = out_channels
+        @kernel_size = kernel_size
+
+        # He (Kaiming) initialization
+        k = kernel_size
+        limit = Math.sqrt(2.0 / (in_channels * k * k))
+
+        @weight = Tensor.new(
+          Numo::DFloat.new(out_channels, in_channels, k, k).rand * 2 * limit - limit
+        )
+        @bias = Tensor.zeros(out_channels)
+        @parameters = [@weight, @bias]
+      end
+
+      def forward(input)
+        # input: [batch, in_c, h, w]
+        batch = input.shape[0]
+        in_c = @in_channels
+        out_c = @out_channels
+        h = input.shape[2]
+        w = input.shape[3]
+        k = @kernel_size
+
+        # Use the Rust backend
+        if Daimond::RustBackend.available?
+          output_data = Daimond::RustBackend.conv2d(
+            input.data, @weight.data, @bias.data,
+            batch, in_c, out_c, h, w, k
+          )
+          out = Tensor.new(output_data, prev: [input, @weight, @bias], op: 'conv2d_rust')
+
+          # Backward comes later; stub for now
+          out._backward = lambda {}
+
+          return out
+        else
+          raise "Rust backend required for Conv2dRust"
+        end
+      end
+    end
+  end
+end
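
Since Conv2dRust raises when the native extension is absent, a caller would likely branch on availability. A hedged sketch (the fallback wiring is an assumption, not shown in the diff):

    # Hypothetical fallback: prefer the Rust-backed layer when the extension loaded
    conv = if Daimond::RustBackend.available?
             Daimond::NN::Conv2dRust.new(1, 8, 3) # valid conv only: h_out = h - 3 + 1
           else
             Daimond::NN::Conv2d.new(1, 8, 3)     # pure Ruby; supports stride/padding
           end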

data/lib/daimond/nn/flatten.rb
@@ -0,0 +1,29 @@
+require_relative 'module'
+
+module Daimond
+  module NN
+    class Flatten < Module
+      def initialize(start_dim: 1, end_dim: -1)
+        super()
+        # NOTE: start_dim/end_dim are stored, but forward always flattens dims 1..-1
+        @start_dim = start_dim
+        @end_dim = end_dim
+        @input_shape = nil
+      end
+
+      def forward(input)
+        @input_shape = input.shape.dup
+        batch = input.shape[0]
+        rest = input.shape[1..-1].inject(:*)
+
+        out_data = input.data.reshape(batch, rest)
+        out = Tensor.new(out_data, prev: [input], op: 'flatten')
+
+        # Backward: reshape the gradient back to the original input shape
+        out._backward = lambda do
+          input.grad += out.grad.reshape(*@input_shape)
+        end
+
+        out
+      end
+    end
+  end
+end
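
A small shape sketch (hypothetical values, same Tensor assumptions as above):

    flat = Daimond::NN::Flatten.new
    x = Daimond::Tensor.new(Numo::DFloat.new(2, 8, 14, 14).rand)
    y = flat.call(x) # [2, 8, 14, 14] -> [2, 1568], since 8 * 14 * 14 = 1568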

data/lib/daimond/nn/functional.rb
File without changes

data/lib/daimond/nn/linear.rb
@@ -0,0 +1,22 @@
+require_relative 'module'
+
+module Daimond
+  module NN
+    class Linear < Module
+      def initialize(in_features, out_features)
+        super()
+        # Simple initialization: small random values
+        @weight = Tensor.new(Numo::DFloat.new(in_features, out_features).rand_norm * 0.01)
+        @bias = Tensor.zeros(out_features)
+        @parameters = [@weight, @bias]
+      end
+
+      def forward(input)
+        # Returns a Tensor with autograd support
+        input.dot(@weight) + @bias
+      end
+
+      attr_reader :weight, :bias
+    end
+  end
+end
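
Continuing the sketch, a Linear layer consumes the flattened activations; #dot and the broadcast bias addition are assumed to come from the gem's Tensor (tensor.rb is not shown in this section):

    fc = Daimond::NN::Linear.new(1568, 10)
    logits = fc.call(y) # [2, 1568].dot([1568, 10]) + [10] -> [2, 10]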

data/lib/daimond/nn/max_pool2d.rb
@@ -0,0 +1,69 @@
+require_relative 'module'
+
+module Daimond
+  module NN
+    class MaxPool2d < Module
+      def initialize(kernel_size, stride: nil)
+        super()
+        @kernel_size = kernel_size.is_a?(Array) ? kernel_size : [kernel_size, kernel_size]
+        # Stride defaults to the kernel size; it is used as a scalar below
+        @stride = stride || @kernel_size[0]
+        @mask = nil # for backward
+      end
+
+      def forward(input)
+        # input: [batch, channels, h, w]
+        batch_size = input.shape[0]
+        channels = input.shape[1]
+        h_in = input.shape[2]
+        w_in = input.shape[3]
+
+        k_h, k_w = @kernel_size
+        s = @stride
+
+        h_out = (h_in - k_h) / s + 1
+        w_out = (w_in - k_w) / s + 1
+
+        output = Numo::DFloat.zeros(batch_size, channels, h_out, w_out)
+        @mask = {} # remember the argmax indices
+
+        batch_size.times do |b|
+          channels.times do |c|
+            h_out.times do |i|
+              w_out.times do |j|
+                # Pooling window
+                i0 = i * s
+                j0 = j * s
+                window = input.data[b, c, i0...i0+k_h, j0...j0+k_w]
+
+                max_val = window.max
+                output[b, c, i, j] = max_val
+
+                # Store the argmax position for backward
+                max_idx = window.to_a.flatten.index(max_val)
+                @mask[[b, c, i, j]] = [i0 + max_idx / k_w, j0 + max_idx % k_w]
+              end
+            end
+          end
+        end
+
+        out = Tensor.new(output, prev: [input], op: 'maxpool2d')
+
+        # Backward: route the gradient only to the max element of each window
+        out._backward = lambda do
+          grad = out.grad
+          batch_size.times do |b|
+            channels.times do |c|
+              h_out.times do |i|
+                w_out.times do |j|
+                  idx_i, idx_j = @mask[[b, c, i, j]]
+                  input.grad[b, c, idx_i, idx_j] += grad[b, c, i, j]
+                end
+              end
+            end
+          end
+        end
+
+        out
+      end
+    end
+  end
+end
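
Shape check for the pooling layer (hypothetical values):

    pool = Daimond::NN::MaxPool2d.new(2) # 2x2 window; stride defaults to 2
    x = Daimond::Tensor.new(Numo::DFloat.new(2, 8, 28, 28).rand)
    y = pool.call(x) # h_out = (28 - 2) / 2 + 1 = 14 -> [2, 8, 14, 14]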

data/lib/daimond/nn/max_pool2d_rust.rb
@@ -0,0 +1,33 @@
+require_relative 'module'
+
+module Daimond
+  module NN
+    class MaxPool2dRust < Module
+      def initialize(kernel_size)
+        super()
+        @kernel_size = kernel_size
+      end
+
+      def forward(input)
+        # input: [batch, channels, h, w]
+        batch = input.shape[0]
+        channels = input.shape[1]
+        h = input.shape[2]
+        w = input.shape[3]
+        k = @kernel_size
+
+        if Daimond::RustBackend.available?
+          output_data = Daimond::RustBackend.maxpool2d(
+            input.data, batch, channels, h, w, k
+          )
+
+          out = Tensor.new(output_data, prev: [input], op: 'maxpool2d_rust')
+          # Backward comes later; stub for now
+          out._backward = lambda {}
+          return out
+        else
+          raise "Rust backend required for MaxPool2dRust"
+        end
+      end
+    end
+  end
+end

data/lib/daimond/nn/module.rb
@@ -0,0 +1,60 @@
+require 'fileutils'
+
+module Daimond
+  module NN
+    class Module
+      def initialize
+        @parameters = []
+      end
+
+      def parameters
+        @parameters
+      end
+
+      def zero_grad
+        @parameters.each do |p|
+          p.grad = Numo::DFloat.zeros(*p.shape)
+        end
+      end
+
+      def forward(*args)
+        raise NotImplementedError
+      end
+
+      def call(*args)
+        forward(*args)
+      end
+
+      # Save the model
+      def save(path)
+        FileUtils.mkdir_p(File.dirname(path)) if File.dirname(path) != '.'
+
+        # Save the weights as an array of Numo arrays
+        params_data = @parameters.map { |p| p.data }
+        File.open(path, 'wb') { |f| Marshal.dump(params_data, f) }
+
+        puts "Model saved to #{path} (#{@parameters.length} parameters)"
+      end
+
+      # Load the model
+      def load(path)
+        unless File.exist?(path)
+          raise "Model file not found: #{path}"
+        end
+
+        params_data = File.open(path, 'rb') { |f| Marshal.load(f) }
+
+        if params_data.length != @parameters.length
+          raise "Parameter count mismatch: saved #{params_data.length} vs current #{@parameters.length}"
+        end
+
+        @parameters.each_with_index do |param, i|
+          param.data = params_data[i]
+          param.grad = Numo::DFloat.zeros(*param.data.shape)
+        end
+
+        puts "Model loaded from #{path}"
+      end
+    end
+  end
+end
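
Checkpointing sketch built only on the methods above. Since load uses Marshal, checkpoints should only be loaded from trusted files, and the model must be constructed with the same architecture before calling load:

    model = Daimond::NN::Conv2d.new(1, 8, 3)
    model.save('checkpoints/conv.bin') # Marshal-dumps each parameter's data
    model.load('checkpoints/conv.bin') # restores data and zeroes the grads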

data/lib/daimond/optim/adam.rb
@@ -0,0 +1,41 @@
+module Daimond
+  module Optim
+    class Adam < SGD
+      def initialize(parameters, lr: 0.001, betas: [0.9, 0.999], eps: 1e-8)
+        super(parameters, lr: lr)
+        @betas = betas
+        @eps = eps
+
+        # First and second moments
+        @m = @parameters.map { |p| Numo::DFloat.zeros(*p.shape) } # first moment (mean)
+        @v = @parameters.map { |p| Numo::DFloat.zeros(*p.shape) } # second moment (squared grads)
+        @t = 0 # update step counter
+      end
+
+      def step
+        @t += 1
+        beta1, beta2 = @betas
+
+        @parameters.each_with_index do |param, i|
+          # Gradient
+          g = param.grad
+
+          # Update the moments
+          @m[i] = beta1 * @m[i] + (1 - beta1) * g
+          @v[i] = beta2 * @v[i] + (1 - beta2) * (g * g)
+
+          # Bias correction
+          m_hat = @m[i] / (1 - beta1**@t)
+          v_hat = @v[i] / (1 - beta2**@t)
+
+          # Parameter update
+          param.data -= @lr * m_hat / (Numo::NMath.sqrt(v_hat) + @eps)
+        end
+      end
+
+      def zero_grad
+        @parameters.each { |p| p.grad = Numo::DFloat.zeros(*p.shape) }
+      end
+    end
+  end
+end
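
A training-step sketch using the optimizer API above; the forward call reuses module.rb's #call, while the loss/backward step is an assumption about the gem's Tensor autograd, which this diff only hints at:

    model = Daimond::NN::Linear.new(784, 10)
    x = Daimond::Tensor.new(Numo::DFloat.new(32, 784).rand)
    opt = Daimond::Optim::Adam.new(model.parameters, lr: 0.001)

    opt.zero_grad                 # clear grads from the previous step
    out = model.call(x)
    # ... compute a loss over `out` and backpropagate to fill param.grad ...
    opt.step                      # moment update + bias-corrected parameter step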

data/lib/daimond/optim/sgd.rb
@@ -0,0 +1,25 @@
+module Daimond
+  module Optim
+    class SGD
+      def initialize(parameters, lr: 0.01, momentum: 0.9)
+        @parameters = parameters
+        @lr = lr
+        @momentum = momentum
+        @velocities = parameters.map { |p| Numo::DFloat.zeros(*p.shape) }
+      end
+
+      def step
+        @parameters.each_with_index do |param, i|
+          @velocities[i] = @momentum * @velocities[i] + param.grad
+          param.data -= @lr * @velocities[i]
+        end
+      end
+
+      def zero_grad
+        @parameters.each do |p|
+          p.grad = Numo::DFloat.zeros(*p.shape)
+        end
+      end
+    end
+  end
+end
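
For intuition about the momentum term: the update is v = momentum * v + grad, then param -= lr * v. With momentum 0.9 and a constant gradient of 1.0, v climbs 1.0, 1.9, 2.71, ... toward 1 / (1 - 0.9) = 10, i.e. roughly a tenfold effective step size. Given a model as above:

    opt = Daimond::Optim::SGD.new(model.parameters, lr: 0.01, momentum: 0.9)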

data/lib/daimond/rust/daimond_rust.bundle
Binary file

data/lib/daimond/rust_backend.rb
@@ -0,0 +1,23 @@
+begin
+  require_relative 'rust/daimond_rust'
+rescue LoadError
+  # Rust backend not compiled; fall back to pure Ruby
+end
+
+module Daimond
+  module RustBackend
+    # Availability check: report whether the native module actually loaded
+    def self.available?
+      defined?(Daimond::Rust) ? true : false
+    end
+
+    # Matrix multiplication wrapper
+    def self.matmul(a, b)
+      # Ruby -> Rust -> Ruby conversion code will go here
+      # For now just return an empty Rust tensor (placeholder)
+      Rust::Tensor.zeros(a.shape[0], b.shape[1])
+    end
+  end
+end

data/lib/daimond/rust_bridge.rb
@@ -0,0 +1,63 @@
+require_relative 'rust/daimond_rust' rescue nil
+
+module Daimond
+  # Check whether the Rust library actually loaded
+  def self.rust_available?
+    defined?(Daimond::Rust) && Daimond::Rust.respond_to?(:fast_matmul_flat)
+  end
+
+  # Wrapper module for the native calls
+  module RustBackend
+    class << self
+      def available?
+        Daimond.rust_available?
+      end
+
+      def conv2d(input_data, weight_data, bias_data, batch, in_c, out_c, h, w, k)
+        return nil unless available?
+
+        flat_input = input_data.flatten.to_a
+        flat_weight = weight_data.flatten.to_a
+        flat_bias = bias_data.to_a
+
+        result_flat = Daimond::Rust.conv2d_native(
+          flat_input, flat_weight, flat_bias,
+          batch, in_c, out_c, h, w, k
+        )
+
+        # Valid convolution (no padding): output is (h - k + 1) x (w - k + 1)
+        h_out = h - k + 1
+        w_out = w - k + 1
+        Numo::DFloat[*result_flat].reshape(batch, out_c, h_out, w_out)
+      end
+
+      def maxpool2d(input_data, batch, channels, h, w, k)
+        return nil unless available?
+
+        flat_input = input_data.flatten.to_a
+        result_flat = Daimond::Rust.maxpool2d_native(
+          flat_input, batch, channels, h, w, k
+        )
+
+        # Non-overlapping k x k windows shrink each spatial dim by a factor of k
+        h_out = h / k
+        w_out = w / k
+        Numo::DFloat[*result_flat].reshape(batch, channels, h_out, w_out)
+      end
+
+      def matmul_data(narray_a, narray_b)
+        return nil unless available?
+
+        shape_a = narray_a.shape
+        shape_b = narray_b.shape
+
+        flat_a = narray_a.flatten.to_a
+        flat_b = narray_b.flatten.to_a
+
+        result_flat = Daimond::Rust.fast_matmul_flat(
+          flat_a, flat_b, shape_a[0], shape_a[1], shape_b[1]
+        )
+
+        Numo::DFloat[*result_flat].reshape(shape_a[0], shape_b[1])
+      end
+    end
+  end
+end
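
The bridge round-trips through flat Ruby arrays: flatten the Numo arrays, call the native function, then reshape the flat result. A usage sketch for the matmul path (shapes hypothetical):

    a = Numo::DFloat.new(64, 128).rand
    b = Numo::DFloat.new(128, 32).rand
    c = Daimond::RustBackend.matmul_data(a, b) # nil when the extension is missing
    # c.shape == [64, 32] when the Rust backend is available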