torch-rb 0.1.3

Files changed (44)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +28 -0
  3. data/LICENSE.txt +46 -0
  4. data/README.md +426 -0
  5. data/ext/torch/ext.cpp +839 -0
  6. data/ext/torch/extconf.rb +25 -0
  7. data/lib/torch-rb.rb +1 -0
  8. data/lib/torch.rb +422 -0
  9. data/lib/torch/ext.bundle +0 -0
  10. data/lib/torch/inspector.rb +85 -0
  11. data/lib/torch/nn/alpha_dropout.rb +9 -0
  12. data/lib/torch/nn/conv2d.rb +37 -0
  13. data/lib/torch/nn/convnd.rb +41 -0
  14. data/lib/torch/nn/dropout.rb +9 -0
  15. data/lib/torch/nn/dropout2d.rb +9 -0
  16. data/lib/torch/nn/dropout3d.rb +9 -0
  17. data/lib/torch/nn/dropoutnd.rb +15 -0
  18. data/lib/torch/nn/embedding.rb +52 -0
  19. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  20. data/lib/torch/nn/functional.rb +100 -0
  21. data/lib/torch/nn/init.rb +30 -0
  22. data/lib/torch/nn/linear.rb +36 -0
  23. data/lib/torch/nn/module.rb +85 -0
  24. data/lib/torch/nn/mse_loss.rb +13 -0
  25. data/lib/torch/nn/parameter.rb +14 -0
  26. data/lib/torch/nn/relu.rb +13 -0
  27. data/lib/torch/nn/sequential.rb +29 -0
  28. data/lib/torch/optim/adadelta.rb +57 -0
  29. data/lib/torch/optim/adagrad.rb +71 -0
  30. data/lib/torch/optim/adam.rb +81 -0
  31. data/lib/torch/optim/adamax.rb +68 -0
  32. data/lib/torch/optim/adamw.rb +82 -0
  33. data/lib/torch/optim/asgd.rb +65 -0
  34. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  35. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  36. data/lib/torch/optim/optimizer.rb +62 -0
  37. data/lib/torch/optim/rmsprop.rb +76 -0
  38. data/lib/torch/optim/rprop.rb +68 -0
  39. data/lib/torch/optim/sgd.rb +60 -0
  40. data/lib/torch/tensor.rb +196 -0
  41. data/lib/torch/utils/data/data_loader.rb +27 -0
  42. data/lib/torch/utils/data/tensor_dataset.rb +22 -0
  43. data/lib/torch/version.rb +3 -0
  44. metadata +169 -0

data/lib/torch/optim/asgd.rb
@@ -0,0 +1,65 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/asgd.py
+ module Torch
+   module Optim
+     class ASGD < Optimizer
+       def initialize(params, lr: 1e-2, lambd: 1e-4, alpha: 0.75, t0: 1e6, weight_decay: 0)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+
+         defaults = {lr: lr, lambd: lambd, alpha: alpha, t0: t0, weight_decay: weight_decay}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "ASGD does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:eta] = group[:lr]
+               state[:mu] = 1
+               state[:ax] = Torch.zeros_like(p.data)
+             end
+
+             state[:step] += 1
+
+             if group[:weight_decay] != 0
+               grad = grad.add(group[:weight_decay], p.data)
+             end
+
+             # decay term
+             p.data.mul!(1 - group[:lambd] * state[:eta])
+
+             # update parameter
+             p.data.add!(-state[:eta], grad)
+
+             # averaging
+             if state[:mu] != 1
+               state[:ax].add!(p.data.sub(state[:ax]).mul(state[:mu]))
+             else
+               state[:ax].copy!(p.data)
+             end
+
+             # update eta and mu
+             state[:eta] = (group[:lr] / ((1 + group[:lambd] * group[:lr] * state[:step]) ** group[:alpha]))
+             state[:mu] = 1 / [1, state[:step] - group[:t0]].max
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
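
For context, a minimal usage sketch (not part of the gem) of driving this optimizer; the model, loss, and data below are illustrative assumptions, following the same zero_grad / backward / step cycle used by the other optimizers in this release.

  require "torch"

  model = Torch::NN::Linear.new(10, 1)                       # shipped in this release
  optimizer = Torch::Optim::ASGD.new(model.parameters, lr: 1e-2)
  criterion = Torch::NN::MSELoss.new

  x = Torch.rand(8, 10)   # illustrative data
  y = Torch.rand(8, 1)

  10.times do
    optimizer.zero_grad                     # defined on the Optimizer base class
    loss = criterion.call(model.call(x), y)
    loss.backward
    optimizer.step                          # runs the ASGD update above
  end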

data/lib/torch/optim/lr_scheduler/lr_scheduler.rb
@@ -0,0 +1,33 @@
+ module Torch
+   module Optim
+     module LRScheduler
+       class LRScheduler
+         def initialize(optimizer, last_epoch)
+           @optimizer = optimizer
+           if last_epoch == -1
+             optimizer.param_groups.each do |group|
+               group[:initial_lr] ||= group[:lr]
+             end
+             last_epoch = 0
+           else
+             raise NotImplementedYet
+           end
+           @base_lrs = optimizer.param_groups.map { |group| group[:initial_lr] }
+           @last_epoch = last_epoch
+
+           @step_count = 0
+           step(last_epoch)
+         end
+
+         def step(epoch = nil)
+           @step_count += 1
+           epoch ||= @last_epoch + 1
+           @last_epoch = epoch
+           @optimizer.param_groups.zip(get_lr).each do |param_group, lr|
+             param_group[:lr] = lr
+           end
+         end
+       end
+     end
+   end
+ end
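
The base class leaves get_lr to subclasses: each call to step asks get_lr for one learning rate per param group and writes it back into the optimizer. A hypothetical subclass (not in this release) showing that contract, using only the @base_lrs and @last_epoch state set up above:

  module Torch
    module Optim
      module LRScheduler
        # Hypothetical scheduler: multiply each base lr by gamma every epoch.
        class ExponentialDecayLR < LRScheduler
          def initialize(optimizer, gamma:, last_epoch: -1)
            @gamma = gamma                 # must be set before super, which calls step
            super(optimizer, last_epoch)
          end

          def get_lr
            @base_lrs.map { |base_lr| base_lr * @gamma ** @last_epoch }
          end
        end
      end
    end
  end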

data/lib/torch/optim/lr_scheduler/step_lr.rb
@@ -0,0 +1,17 @@
+ module Torch
+   module Optim
+     module LRScheduler
+       class StepLR < LRScheduler
+         def initialize(optimizer, step_size:, gamma: 0.1, last_epoch: -1)
+           @step_size = step_size
+           @gamma = gamma
+           super(optimizer, last_epoch)
+         end
+
+         def get_lr
+           @base_lrs.map { |base_lr| base_lr * @gamma ** (@last_epoch / @step_size).floor }
+         end
+       end
+     end
+   end
+ end
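
A sketch (not from the gem) of pairing StepLR with one of the optimizers in this release; train_one_epoch is a hypothetical helper. With step_size: 30 and gamma: 0.1 the learning rate is 0.1 for epochs 0-29, 0.01 for 30-59, and so on.

  optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)
  scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 30, gamma: 0.1)

  90.times do
    train_one_epoch(model, optimizer)   # hypothetical training helper
    scheduler.step                      # writes the new lr into each param group
  end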

data/lib/torch/optim/optimizer.rb
@@ -0,0 +1,62 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
+ module Torch
+   module Optim
+     class Optimizer
+       attr_reader :param_groups
+
+       def initialize(params, defaults)
+         @defaults = defaults
+         @state = Hash.new { |hash, key| hash[key] = {} }
+         @param_groups = []
+
+         param_groups = params
+         if param_groups.empty?
+           raise ArgumentError, "optimizer got an empty parameter list"
+         end
+         if !param_groups[0].is_a?(Hash)
+           param_groups = [{params: param_groups}]
+         end
+
+         param_groups.each do |param_group|
+           add_param_group(param_group)
+         end
+       end
+
+       def add_param_group(param_group)
+         # TODO more advanced logic
+         @param_groups << @defaults.merge(param_group)
+       end
+
+       def load_state_dict(state_dict)
+         raise NotImplementedYet
+       end
+
+       def state_dict
+         pack_group = lambda do |group|
+           packed = group.select { |k, _| k != :params }.to_h
+           packed[:params] = group[:params].map { |p| p.object_id }
+           packed
+         end
+
+         param_groups = @param_groups.map { |g| pack_group.call(g) }
+         packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
+
+         {
+           state: packed_state,
+           param_groups: param_groups
+         }
+       end
+
+       def zero_grad
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             if p.grad
+               p.grad.detach!
+               p.grad.zero!
+             end
+           end
+         end
+       end
+     end
+   end
+ end
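
A hypothetical subclass (not in the gem) illustrating the base-class contract the optimizers in this release follow: pass per-group defaults to super, then walk @param_groups and per-parameter @state inside step.

  module Torch
    module Optim
      # Hypothetical optimizer: plain gradient descent with no per-parameter state.
      class PlainGD < Optimizer
        def initialize(params, lr: 1e-3)
          super(params, {lr: lr})
        end

        def step(closure = nil)
          loss = closure ? closure.call : nil
          @param_groups.each do |group|
            group[:params].each do |p|
              next unless p.grad
              p.data.add!(-group[:lr], p.grad.data)   # Tensor#add! from tensor.rb below
            end
          end
          loss
        end
      end
    end
  end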

data/lib/torch/optim/rmsprop.rb
@@ -0,0 +1,76 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
+ module Torch
+   module Optim
+     class RMSprop < Optimizer
+       def initialize(params, lr: 1e-2, alpha: 0.99, eps: 1e-8, weight_decay: 0, momentum: 0, centered: false)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
+         raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+         raise ArgumentError, "Invalid momentum alpha: #{alpha}" if alpha < 0
+
+         defaults = {lr: lr, momentum: momentum, alpha: alpha, eps: eps, centered: centered, weight_decay: weight_decay}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "RMSprop does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:square_avg] = Torch.zeros_like(p.data)
+               if group[:momentum] > 0
+                 state[:momentum_buffer] = Torch.zeros_like(p.data)
+               end
+               if group[:centered]
+                 state[:grad_avg] = Torch.zeros_like(p.data)
+               end
+             end
+
+             square_avg = state[:square_avg]
+             alpha = group[:alpha]
+
+             state[:step] += 1
+
+             if group[:weight_decay] != 0
+               grad = grad.add(group[:weight_decay], p.data)
+             end
+
+             square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+
+             if group[:centered]
+               grad_avg = state[:grad_avg]
+               grad_avg.mul!(alpha).add!(1 - alpha, grad)
+               avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+             else
+               avg = square_avg.sqrt.add!(group[:eps])
+             end
+
+             if group[:momentum] > 0
+               buf = state[:momentum_buffer]
+               buf.mul!(group[:momentum]).addcdiv!(grad, avg)
+               p.data.add!(-group[:lr], buf)
+             else
+               p.data.addcdiv!(-group[:lr], grad, avg)
+             end
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
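
A construction sketch (not from the gem); model is an illustrative assumption. With centered: true the update divides by an estimate of the gradient variance (square_avg minus grad_avg squared) rather than the raw second moment, and momentum > 0 enables the momentum_buffer branch above.

  optimizer = Torch::Optim::RMSprop.new(
    model.parameters,
    lr: 1e-3, alpha: 0.99, eps: 1e-8, momentum: 0.9, centered: true
  )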

data/lib/torch/optim/rprop.rb
@@ -0,0 +1,68 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
+ module Torch
+   module Optim
+     class Rprop < Optimizer
+       def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}" if etas[0] < 0 || etas[0] >= 1 || etas[1] < 1
+
+         defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         # TODO implement []=
+         raise NotImplementedYet
+
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "Rprop does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:prev] = Torch.zeros_like(p.data)
+               state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
+             end
+
+             etaminus, etaplus = group[:etas]
+             step_size_min, step_size_max = group[:step_sizes]
+             step_size = state[:step_size]
+
+             state[:step] += 1
+
+             sign = grad.mul(state[:prev]).sign
+             sign[sign.gt(0)] = etaplus
+             sign[sign.lt(0)] = etaminus
+             sign[sign.eq(0)] = 1
+
+             # update stepsizes with step size updates
+             step_size.mul!(sign).clamp!(step_size_min, step_size_max)
+
+             # for dir<0, dfdx=0
+             # for dir>=0 dfdx=dfdx
+             grad = grad.clone
+             grad[sign.eq(etaminus)] = 0
+
+             # update parameters
+             p.data.addcmul!(-1, grad.sign, step_size)
+
+             state[:prev].copy!(grad)
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
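
Note that step currently raises NotImplementedYet because Tensor#[]= (mask assignment) is still a TODO in tensor.rb. A pure-Ruby scalar illustration (no tensors, not from the gem) of the rule the method ports: grow the step size while the gradient keeps its sign, shrink it on a sign flip, and clamp to the configured bounds.

  etaminus, etaplus = [0.5, 1.2]            # defaults from the class above
  step_size_min, step_size_max = [1e-6, 50]

  step_size = 0.01
  prev_grad = 0.3
  grad = 0.2                                # same sign as prev_grad => grow the step

  product = grad * prev_grad
  factor =
    if product > 0
      etaplus
    elsif product < 0
      etaminus
    else
      1
    end
  step_size = (step_size * factor).clamp(step_size_min, step_size_max)
  # the parameter update would then be: param -= step_size * (grad <=> 0)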

data/lib/torch/optim/sgd.rb
@@ -0,0 +1,60 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
+ module Torch
+   module Optim
+     class SGD < Optimizer
+       def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
+         raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0
+
+         defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}
+
+         if nesterov && (momentum <= 0 || dampening != 0)
+           raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
+         end
+
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           weight_decay = group[:weight_decay]
+           momentum = group[:momentum]
+           dampening = group[:dampening]
+           nesterov = group[:nesterov]
+
+           group[:params].each do |p|
+             next unless p.grad
+             d_p = p.grad.data
+             if weight_decay != 0
+               d_p.add!(weight_decay, p.data)
+             end
+             if momentum != 0
+               param_state = @state[p]
+               if !param_state.key?(:momentum_buffer)
+                 buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
+               else
+                 buf = param_state[:momentum_buffer]
+                 buf.mul!(momentum).add!(1 - dampening, d_p)
+               end
+               if nesterov
+                 d_p = d_p.add(momentum, buf)
+               else
+                 d_p = buf
+               end
+             end
+
+             p.data.add!(-group[:lr], d_p)
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
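
A sketch (not from the gem) of the closure form of step accepted by these optimizers: the closure re-evaluates the loss before the parameters are updated. net, criterion, input, and target are illustrative assumptions.

  optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01, momentum: 0.9, nesterov: true)

  loss = optimizer.step(-> {
    optimizer.zero_grad
    output = net.call(input)
    l = criterion.call(output, target)
    l.backward
    l
  })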

data/lib/torch/tensor.rb
@@ -0,0 +1,196 @@
+ module Torch
+   class Tensor
+     include Comparable
+     include Inspector
+
+     alias_method :requires_grad?, :requires_grad
+
+     def self.new(*size)
+       if size.length == 1 && size.first.is_a?(Tensor)
+         size.first
+       else
+         Torch.empty(*size)
+       end
+     end
+
+     def dtype
+       dtype = ENUM_TO_DTYPE[_dtype]
+       raise Error, "Unknown type: #{_dtype}" unless dtype
+       dtype
+     end
+
+     def layout
+       _layout.downcase.to_sym
+     end
+
+     def to_s
+       inspect
+     end
+
+     def to_a
+       reshape_arr(_data, shape)
+     end
+
+     # TODO support dtype
+     def to(device, non_blocking: false, copy: false)
+       device = Device.new(device) if device.is_a?(String)
+       _to(device, _dtype, non_blocking, copy)
+     end
+
+     def size(dim = nil)
+       if dim
+         _size(dim)
+       else
+         shape
+       end
+     end
+
+     def shape
+       dim.times.map { |i| size(i) }
+     end
+
+     def view(*size)
+       _view(size)
+     end
+
+     def item
+       if numel != 1
+         raise Error, "only one element tensors can be converted to Ruby scalars"
+       end
+       _data.first
+     end
+
+     # unsure if this is correct
+     def new
+       Torch.empty(0, dtype: dtype)
+     end
+
+     def backward(gradient = nil)
+       if gradient
+         _backward_gradient(gradient)
+       else
+         _backward
+       end
+     end
+
+     # TODO read directly from memory
+     def numo
+       raise Error, "Numo not found" unless defined?(Numo::NArray)
+       cls = Torch._dtype_to_numo[dtype]
+       raise Error, "Cannot convert #{dtype} to Numo" unless cls
+       cls.cast(_data).reshape(*shape)
+     end
+
+     def new_ones(*size, **options)
+       Torch.ones_like(Torch.empty(*size), **options)
+     end
+
+     def requires_grad!(requires_grad = true)
+       _requires_grad!(requires_grad)
+     end
+
+     def type(dtype)
+       enum = DTYPE_TO_ENUM[dtype]
+       raise Error, "Unknown type: #{dtype}" unless enum
+       _type(enum)
+     end
+
+     def add!(value = 1, other)
+       if other.is_a?(Numeric)
+         _add_scalar!(other * value)
+       else
+         # need to use alpha for sparse tensors instead of multiplying
+         _add_alpha!(other, value)
+       end
+     end
+
+     def mul!(other)
+       if other.is_a?(Numeric)
+         _mul_scalar!(other)
+       else
+         _mul!(other)
+       end
+     end
+
+     # operations
+     %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
+       define_method(op) do |*args, **options, &block|
+         if options.any?
+           Torch.send(op, self, *args, **options, &block)
+         else
+           Torch.send(op, self, *args, &block)
+         end
+       end
+     end
+
+     def +(other)
+       add(other)
+     end
+
+     def -(other)
+       sub(other)
+     end
+
+     def *(other)
+       mul(other)
+     end
+
+     def /(other)
+       div(other)
+     end
+
+     def %(other)
+       remainder(other)
+     end
+
+     def **(other)
+       pow(other)
+     end
+
+     def -@
+       neg
+     end
+
+     def <=>(other)
+       item <=> other
+     end
+
+     # based on python_variable_indexing.cpp
+     def [](*indexes)
+       result = self
+       dim = 0
+       indexes.each do |index|
+         if index.is_a?(Numeric)
+           result = result._select(dim, index)
+         elsif index.is_a?(Range)
+           finish = index.end
+           finish += 1 unless index.exclude_end?
+           result = result._slice(dim, index.begin, finish, 1)
+           dim += 1
+         else
+           raise Error, "Unsupported index type"
+         end
+       end
+       result
+     end
+
+     # TODO
+     # based on python_variable_indexing.cpp
+     # def []=(index, value)
+     # end
+
+     private
+
+     def reshape_arr(arr, dims)
+       if dims.empty?
+         arr
+       else
+         arr = arr.flatten
+         dims[1..-1].reverse.each do |dim|
+           arr = arr.each_slice(dim)
+         end
+         arr.to_a
+       end
+     end
+   end
+ end
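
A short sketch exercising the indexing and operator methods defined above; it assumes Torch.tensor from lib/torch.rb in this release, and the values in the comments follow directly from the inputs.

  x = Torch.tensor([[1, 2, 3], [4, 5, 6]])

  x.shape              # => [2, 3]
  x[0].to_a            # => [1, 2, 3]               (Integer index -> _select)
  x[0, 1..2].to_a      # => [2, 3]                  (Range index -> _slice; end inclusive)
  (x + x).to_a         # => [[2, 4, 6], [8, 10, 12]]
  (-x)[1].sum.item     # => -15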