torch-rb 0.1.2 → 0.1.3
This diff covers publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +46 -22
- data/README.md +14 -5
- data/ext/torch/ext.cpp +248 -31
- data/lib/torch.rb +80 -9
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +4 -3
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/conv2d.rb +12 -24
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +54 -12
- data/lib/torch/nn/linear.rb +2 -2
- data/lib/torch/nn/module.rb +30 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +56 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +48 -16
- data/lib/torch/tensor.rb +38 -4
- data/lib/torch/utils/data/data_loader.rb +10 -4
- data/lib/torch/utils/data/tensor_dataset.rb +3 -0
- data/lib/torch/version.rb +1 -1
- metadata +21 -3
data/lib/torch/optim/lr_scheduler/lr_scheduler.rb
ADDED
@@ -0,0 +1,33 @@
+module Torch
+  module Optim
+    module LRScheduler
+      class LRScheduler
+        def initialize(optimizer, last_epoch)
+          @optimizer = optimizer
+          if last_epoch == -1
+            optimizer.param_groups.each do |group|
+              group[:initial_lr] ||= group[:lr]
+            end
+            last_epoch = 0
+          else
+            raise NotImplementedYet
+          end
+          @base_lrs = optimizer.param_groups.map { |group| group[:initial_lr] }
+          @last_epoch = last_epoch
+
+          @step_count = 0
+          step(last_epoch)
+        end
+
+        def step(epoch = nil)
+          @step_count += 1
+          epoch ||= @last_epoch + 1
+          @last_epoch = epoch
+          @optimizer.param_groups.zip(get_lr).each do |param_group, lr|
+            param_group[:lr] = lr
+          end
+        end
+      end
+    end
+  end
+end
data/lib/torch/optim/lr_scheduler/step_lr.rb
ADDED
@@ -0,0 +1,17 @@
+module Torch
+  module Optim
+    module LRScheduler
+      class StepLR < LRScheduler
+        def initialize(optimizer, step_size:, gamma: 0.1, last_epoch: -1)
+          @step_size = step_size
+          @gamma = gamma
+          super(optimizer, last_epoch)
+        end
+
+        def get_lr
+          @base_lrs.map { |base_lr| base_lr * @gamma ** (@last_epoch / @step_size).floor }
+        end
+      end
+    end
+  end
+end
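Together these two files give a minimal scheduler API: `StepLR` decays each param group's `:lr` by `gamma` every `step_size` epochs. A usage sketch based on the classes above (`model.parameters` is assumed to come from a torch-rb `nn` module):

```ruby
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1, momentum: 0.9)
scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 30, gamma: 0.1)

90.times do
  # ... train for one epoch ...
  scheduler.step # writes get_lr back into each optimizer param group's :lr
end
```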
data/lib/torch/optim/optimizer.rb
CHANGED
@@ -1,6 +1,62 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
 module Torch
   module Optim
     class Optimizer
+      attr_reader :param_groups
+
+      def initialize(params, defaults)
+        @defaults = defaults
+        @state = Hash.new { |hash, key| hash[key] = {} }
+        @param_groups = []
+
+        param_groups = params
+        if param_groups.empty?
+          raise ArgumentError, "optimizer got an empty parameter list"
+        end
+        if !param_groups[0].is_a?(Hash)
+          param_groups = [{params: param_groups}]
+        end
+
+        param_groups.each do |param_group|
+          add_param_group(param_group)
+        end
+      end
+
+      def add_param_group(param_group)
+        # TODO more advanced logic
+        @param_groups << @defaults.merge(param_group)
+      end
+
+      def load_state_dict(state_dict)
+        raise NotImplementedYet
+      end
+
+      def state_dict
+        pack_group = lambda do |group|
+          packed = group.select { |k, _| k != :params }.to_h
+          packed[:params] = group[:params].map { |p| p.object_id }
+          packed
+        end
+
+        param_groups = @param_groups.map { |g| pack_group.call(g) }
+        packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
+
+        {
+          state: packed_state,
+          param_groups: param_groups
+        }
+      end
+
+      def zero_grad
+        @param_groups.each do |group|
+          group[:params].each do |p|
+            if p.grad
+              p.grad.detach!
+              p.grad.zero!
+            end
+          end
+        end
+      end
     end
   end
 end
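The new `Optimizer` base class accepts either a flat parameter list or an array of param-group hashes, merging each group over the subclass defaults. A sketch using `SGD` (the group variables are hypothetical):

```ruby
# Flat list: a single group built from the defaults.
opt = Torch::Optim::SGD.new(model.parameters, lr: 0.01)

# Explicit groups: per-group keys override the defaults.
opt = Torch::Optim::SGD.new(
  [
    {params: base_params, lr: 0.001},
    {params: head_params} # inherits lr: 0.01 from the defaults
  ],
  lr: 0.01
)

opt.zero_grad          # detaches and zeroes every parameter's grad
state = opt.state_dict # {state: ..., param_groups: ...}, with tensors keyed by object_id
```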
data/lib/torch/optim/rmsprop.rb
ADDED
@@ -0,0 +1,76 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
+module Torch
+  module Optim
+    class RMSprop < Optimizer
+      def initialize(params, lr: 1e-2, alpha: 0.99, eps: 1e-8, weight_decay: 0, momentum: 0, centered: false)
+        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
+        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0
+        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+        raise ArgumentError, "Invalid momentum alpha: #{alpha}" if alpha < 0
+
+        defaults = {lr: lr, momentum: momentum, alpha: alpha, eps: eps, centered: centered, weight_decay: weight_decay}
+        super(params, defaults)
+      end
+
+      def step(closure = nil)
+        loss = nil
+        if closure
+          loss = closure.call
+        end
+
+        @param_groups.each do |group|
+          group[:params].each do |p|
+            next unless p.grad
+            grad = p.grad.data
+            if grad.sparse?
+              raise Error, "RMSprop does not support sparse gradients"
+            end
+            state = @state[p]
+
+            # State initialization
+            if state.size == 0
+              state[:step] = 0
+              state[:square_avg] = Torch.zeros_like(p.data)
+              if group[:momentum] > 0
+                state[:momentum_buffer] = Torch.zeros_like(p.data)
+              end
+              if group[:centered]
+                state[:grad_avg] = Torch.zeros_like(p.data)
+              end
+            end
+
+            square_avg = state[:square_avg]
+            alpha = group[:alpha]
+
+            state[:step] += 1
+
+            if group[:weight_decay] != 0
+              grad = grad.add(group[:weight_decay], p.data)
+            end
+
+            square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+
+            if group[:centered]
+              grad_avg = state[:grad_avg]
+              grad_avg.mul!(alpha).add!(1 - alpha, grad)
+              avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+            else
+              avg = square_avg.sqrt.add!(group[:eps])
+            end
+
+            if group[:momentum] > 0
+              buf = state[:momentum_buffer]
+              buf.mul!(group[:momentum]).addcdiv!(grad, avg)
+              p.data.add!(-group[:lr], buf)
+            else
+              p.data.addcdiv!(-group[:lr], grad, avg)
+            end
+          end
+        end
+
+        loss
+      end
+    end
+  end
+end
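The keyword arguments validated above map directly onto the per-group options used in `step`; a construction sketch (again assuming `model.parameters`):

```ruby
optimizer = Torch::Optim::RMSprop.new(
  model.parameters,
  lr: 1e-3,
  alpha: 0.99,       # smoothing constant for the squared-gradient average
  eps: 1e-8,
  momentum: 0.9,     # enables the :momentum_buffer state
  centered: true,    # enables the :grad_avg state
  weight_decay: 1e-4
)
```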
data/lib/torch/optim/rprop.rb
ADDED
@@ -0,0 +1,68 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
+module Torch
+  module Optim
+    class Rprop < Optimizer
+      def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
+        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+        raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}" if etas[0] < 0 || etas[0] >= 1 || etas[1] < 1
+
+        defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
+        super(params, defaults)
+      end
+
+      def step(closure = nil)
+        # TODO implement []=
+        raise NotImplementedYet
+
+        loss = nil
+        if closure
+          loss = closure.call
+        end
+
+        @param_groups.each do |group|
+          group[:params].each do |p|
+            next unless p.grad
+            grad = p.grad.data
+            if grad.sparse?
+              raise Error, "Rprop does not support sparse gradients"
+            end
+            state = @state[p]
+
+            # State initialization
+            if state.size == 0
+              state[:step] = 0
+              state[:prev] = Torch.zeros_like(p.data)
+              state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
+            end
+
+            etaminus, etaplus = group[:etas]
+            step_size_min, step_size_max = group[:step_sizes]
+            step_size = state[:step_size]
+
+            state[:step] += 1
+
+            sign = grad.mul(state[:prev]).sign
+            sign[sign.gt(0)] = etaplus
+            sign[sign.lt(0)] = etaminus
+            sign[sign.eq(0)] = 1
+
+            # update stepsizes with step size updates
+            step_size.mul!(sign).clamp!(step_size_min, step_size_max)
+
+            # for dir<0, dfdx=0
+            # for dir>=0 dfdx=dfdx
+            grad = grad.clone
+            grad[sign.eq(etaminus)] = 0
+
+            # update parameters
+            p.data.addcmul!(-1, grad.sign, step_size)
+
+            state[:prev].copy!(grad)
+          end
+        end
+
+        loss
+      end
+    end
+  end
+end
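`Rprop` can be constructed, but its `step` raises `NotImplementedYet` in this release because it depends on `Tensor#[]=`, which is still a TODO in the tensor.rb diff below:

```ruby
# Constructor works; calling opt.step raises NotImplementedYet in 0.1.3.
opt = Torch::Optim::Rprop.new(model.parameters, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
```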
data/lib/torch/optim/sgd.rb
CHANGED
@@ -1,27 +1,59 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
 module Torch
   module Optim
     class SGD < Optimizer
-      def initialize(params, lr:)
-
-
-
+      def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
+        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
+        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
+        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0
 
-
-
-
-
-            param.grad.zero!
-          end
+        defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}
+
+        if nesterov && (momentum <= 0 || dampening != 0)
+          raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
         end
+
+        super(params, defaults)
      end
 
-      def step
-
-
-
-
-
+      def step(closure = nil)
+        loss = nil
+        if closure
+          loss = closure.call
+        end
+
+        @param_groups.each do |group|
+          weight_decay = group[:weight_decay]
+          momentum = group[:momentum]
+          dampening = group[:dampening]
+          nesterov = group[:nesterov]
+
+          group[:params].each do |p|
+            next unless p.grad
+            d_p = p.grad.data
+            if weight_decay != 0
+              d_p.add!(weight_decay, p.data)
+            end
+            if momentum != 0
+              param_state = @state[p]
+              if !param_state.key(:momentum_buffer)
+                buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
+              else
+                buf = param_state[:momentum_buffer]
+                buf.mul!(momentum).add!(1 - dampening, d_p)
+              end
+              if nesterov
+                d_p = d_p.add(momentum, buf)
+              else
+                d_p = buf
+              end
+            end
+
+            p.data.add!(-group[:lr], d_p)
+          end
         end
+
+        loss
      end
    end
  end
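The reworked `step` follows PyTorch's SGD update (weight decay, momentum with optional dampening, Nesterov) and accepts an optional closure that recomputes the loss. A minimal training-step sketch, where `net`, `criterion`, `x`, and `y` are placeholders for a torch-rb module, loss function, and data:

```ruby
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01, momentum: 0.9, nesterov: true)

# one optimization step
optimizer.zero_grad
loss = criterion.call(net.call(x), y)
loss.backward
optimizer.step
```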
data/lib/torch/tensor.rb
CHANGED
@@ -6,10 +6,10 @@ module Torch
     alias_method :requires_grad?, :requires_grad
 
     def self.new(*size)
-      if size.first.is_a?(Tensor)
+      if size.length == 1 && size.first.is_a?(Tensor)
         size.first
       else
-        Torch.
+        Torch.empty(*size)
       end
     end
 
@@ -31,6 +31,12 @@
       reshape_arr(_data, shape)
     end
 
+    # TODO support dtype
+    def to(device, non_blocking: false, copy: false)
+      device = Device.new(device) if device.is_a?(String)
+      _to(device, _dtype, non_blocking, copy)
+    end
+
     def size(dim = nil)
       if dim
         _size(dim)
@@ -54,6 +60,11 @@
       _data.first
     end
 
+    # unsure if this is correct
+    def new
+      Torch.empty(0, dtype: dtype)
+    end
+
     def backward(gradient = nil)
       if gradient
         _backward_gradient(gradient)
@@ -84,8 +95,25 @@
       _type(enum)
     end
 
+    def add!(value = 1, other)
+      if other.is_a?(Numeric)
+        _add_scalar!(other * value)
+      else
+        # need to use alpha for sparse tensors instead of multiplying
+        _add_alpha!(other, value)
+      end
+    end
+
+    def mul!(other)
+      if other.is_a?(Numeric)
+        _mul_scalar!(other)
+      else
+        _mul!(other)
+      end
+    end
+
     # operations
-    %w(add
+    %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
       define_method(op) do |*args, **options, &block|
         if options.any?
           Torch.send(op, self, *args, **options, &block)
@@ -127,10 +155,11 @@
       item <=> other
     end
 
+    # based on python_variable_indexing.cpp
     def [](*indexes)
       result = self
       dim = 0
-      indexes.
+      indexes.each do |index|
         if index.is_a?(Numeric)
           result = result._select(dim, index)
         elsif index.is_a?(Range)
@@ -145,6 +174,11 @@
       result
     end
 
+    # TODO
+    # based on python_variable_indexing.cpp
+    # def []=(index, value)
+    # end
+
     private
 
     def reshape_arr(arr, dims)
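The Tensor additions above include `to` for device moves and in-place `add!`/`mul!` that take either a scalar or another tensor, with `add!` optionally scaling the other operand. A small sketch, assuming `Torch.ones` as used elsewhere in the library:

```ruby
a = Torch.ones(2, 3)
b = Torch.ones(2, 3)

a.add!(b)     # in place: a += b
a.add!(2, b)  # in place: a += 2 * b (the leading argument scales `other`)
a.mul!(0.5)   # in-place scalar multiply

a.to("cpu")   # returns a tensor on the given device; dtype support is still a TODO
```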
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -2,19 +2,25 @@ module Torch
   module Utils
     module Data
       class DataLoader
+        include Enumerable
+
+        attr_reader :dataset
+
        def initialize(dataset, batch_size: 1)
          @dataset = dataset
          @batch_size = batch_size
        end
 
        def each
-          size
-
-          while start_index < size
+          size.times do |i|
+            start_index = i * @batch_size
            yield @dataset[start_index...(start_index + @batch_size)]
-            start_index += @batch_size
          end
        end
+
+        def size
+          (@dataset.size / @batch_size.to_f).ceil
+        end
      end
    end
  end
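With `Enumerable` mixed in and `size` reported in batches, the loader behaves like a standard Ruby collection. A usage sketch with `TensorDataset`, which also ships in this gem (the tensor shapes are illustrative):

```ruby
x = Torch.rand(100, 10)
y = Torch.rand(100, 1)

dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 32)

loader.size # => 4, i.e. (100 / 32.0).ceil batches
loader.each do |batch|
  # each batch is dataset[i * 32...(i + 1) * 32]
end
```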