torch-rb 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +28 -0
  3. data/LICENSE.txt +46 -0
  4. data/README.md +426 -0
  5. data/ext/torch/ext.cpp +839 -0
  6. data/ext/torch/extconf.rb +25 -0
  7. data/lib/torch-rb.rb +1 -0
  8. data/lib/torch.rb +422 -0
  9. data/lib/torch/ext.bundle +0 -0
  10. data/lib/torch/inspector.rb +85 -0
  11. data/lib/torch/nn/alpha_dropout.rb +9 -0
  12. data/lib/torch/nn/conv2d.rb +37 -0
  13. data/lib/torch/nn/convnd.rb +41 -0
  14. data/lib/torch/nn/dropout.rb +9 -0
  15. data/lib/torch/nn/dropout2d.rb +9 -0
  16. data/lib/torch/nn/dropout3d.rb +9 -0
  17. data/lib/torch/nn/dropoutnd.rb +15 -0
  18. data/lib/torch/nn/embedding.rb +52 -0
  19. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  20. data/lib/torch/nn/functional.rb +100 -0
  21. data/lib/torch/nn/init.rb +30 -0
  22. data/lib/torch/nn/linear.rb +36 -0
  23. data/lib/torch/nn/module.rb +85 -0
  24. data/lib/torch/nn/mse_loss.rb +13 -0
  25. data/lib/torch/nn/parameter.rb +14 -0
  26. data/lib/torch/nn/relu.rb +13 -0
  27. data/lib/torch/nn/sequential.rb +29 -0
  28. data/lib/torch/optim/adadelta.rb +57 -0
  29. data/lib/torch/optim/adagrad.rb +71 -0
  30. data/lib/torch/optim/adam.rb +81 -0
  31. data/lib/torch/optim/adamax.rb +68 -0
  32. data/lib/torch/optim/adamw.rb +82 -0
  33. data/lib/torch/optim/asgd.rb +65 -0
  34. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  35. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  36. data/lib/torch/optim/optimizer.rb +62 -0
  37. data/lib/torch/optim/rmsprop.rb +76 -0
  38. data/lib/torch/optim/rprop.rb +68 -0
  39. data/lib/torch/optim/sgd.rb +60 -0
  40. data/lib/torch/tensor.rb +196 -0
  41. data/lib/torch/utils/data/data_loader.rb +27 -0
  42. data/lib/torch/utils/data/tensor_dataset.rb +22 -0
  43. data/lib/torch/version.rb +3 -0
  44. metadata +169 -0
data/lib/torch/optim/asgd.rb
@@ -0,0 +1,65 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/asgd.py
+ module Torch
+   module Optim
+     class ASGD < Optimizer
+       def initialize(params, lr: 1e-2, lambd: 1e-4, alpha: 0.75, t0: 1e6, weight_decay: 0)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+
+         defaults = {lr: lr, lambd: lambd, alpha: alpha, t0: t0, weight_decay: weight_decay}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "ASGD does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:eta] = group[:lr]
+               state[:mu] = 1
+               state[:ax] = Torch.zeros_like(p.data)
+             end
+
+             state[:step] += 1
+
+             if group[:weight_decay] != 0
+               grad = grad.add(group[:weight_decay], p.data)
+             end
+
+             # decay term
+             p.data.mul!(1 - group[:lambd] * state[:eta])
+
+             # update parameter
+             p.data.add!(-state[:eta], grad)
+
+             # averaging
+             if state[:mu] != 1
+               state[:ax].add!(p.data.sub(state[:ax]).mul(state[:mu]))
+             else
+               state[:ax].copy!(p.data)
+             end
+
+             # update eta and mu
+             state[:eta] = (group[:lr] / ((1 + group[:lambd] * group[:lr] * state[:step]) ** group[:alpha]))
+             state[:mu] = 1 / [1, state[:step] - group[:t0]].max
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
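A minimal usage sketch for the ASGD optimizer above (not part of the diff; `model` and `compute_loss` are placeholders, and `model.parameters` is assumed to return the trainable tensors, mirroring the PyTorch API this class is ported from):

  optimizer = Torch::Optim::ASGD.new(model.parameters, lr: 1e-2, t0: 1e6)
  optimizer.zero_grad       # inherited from Optimizer (see optimizer.rb below)
  loss = compute_loss.call  # any scalar tensor produced by the forward pass
  loss.backward             # populates p.grad for each parameter
  optimizer.step            # applies the decay, update, and averaging shown above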
data/lib/torch/optim/lr_scheduler/lr_scheduler.rb
@@ -0,0 +1,33 @@
+ module Torch
+   module Optim
+     module LRScheduler
+       class LRScheduler
+         def initialize(optimizer, last_epoch)
+           @optimizer = optimizer
+           if last_epoch == -1
+             optimizer.param_groups.each do |group|
+               group[:initial_lr] ||= group[:lr]
+             end
+             last_epoch = 0
+           else
+             raise NotImplementedYet
+           end
+           @base_lrs = optimizer.param_groups.map { |group| group[:initial_lr] }
+           @last_epoch = last_epoch
+
+           @step_count = 0
+           step(last_epoch)
+         end
+
+         def step(epoch = nil)
+           @step_count += 1
+           epoch ||= @last_epoch + 1
+           @last_epoch = epoch
+           @optimizer.param_groups.zip(get_lr).each do |param_group, lr|
+             param_group[:lr] = lr
+           end
+         end
+       end
+     end
+   end
+ end
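Subclasses only need to supply `get_lr`; `step` then copies the returned rates into each parameter group. A minimal custom schedule built on the base class above might look like this (hypothetical sketch, not shipped with the gem):

  # Halve every learning rate once per epoch (illustration only).
  class HalvingLR < Torch::Optim::LRScheduler::LRScheduler
    def get_lr
      @base_lrs.map { |base_lr| base_lr * 0.5 ** @last_epoch }
    end
  end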
data/lib/torch/optim/lr_scheduler/step_lr.rb
@@ -0,0 +1,17 @@
+ module Torch
+   module Optim
+     module LRScheduler
+       class StepLR < LRScheduler
+         def initialize(optimizer, step_size:, gamma: 0.1, last_epoch: -1)
+           @step_size = step_size
+           @gamma = gamma
+           super(optimizer, last_epoch)
+         end
+
+         def get_lr
+           @base_lrs.map { |base_lr| base_lr * @gamma ** (@last_epoch / @step_size).floor }
+         end
+       end
+     end
+   end
+ end
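With `step_size: 30` and `gamma: 0.1`, the schedule above keeps the initial rate for epochs 0-29, multiplies it by 0.1 for epochs 30-59, and so on. A usage sketch (assumes an `optimizer` built elsewhere, for example the SGD class later in this diff):

  scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 30, gamma: 0.1)
  90.times do
    # ... train for one epoch (placeholder) ...
    scheduler.step  # epochs 0..29 -> lr, 30..59 -> lr * 0.1, 60..89 -> lr * 0.01
  end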
data/lib/torch/optim/optimizer.rb
@@ -0,0 +1,62 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
+ module Torch
+   module Optim
+     class Optimizer
+       attr_reader :param_groups
+
+       def initialize(params, defaults)
+         @defaults = defaults
+         @state = Hash.new { |hash, key| hash[key] = {} }
+         @param_groups = []
+
+         param_groups = params
+         if param_groups.empty?
+           raise ArgumentError, "optimizer got an empty parameter list"
+         end
+         if !param_groups[0].is_a?(Hash)
+           param_groups = [{params: param_groups}]
+         end
+
+         param_groups.each do |param_group|
+           add_param_group(param_group)
+         end
+       end
+
+       def add_param_group(param_group)
+         # TODO more advanced logic
+         @param_groups << @defaults.merge(param_group)
+       end
+
+       def load_state_dict(state_dict)
+         raise NotImplementedYet
+       end
+
+       def state_dict
+         pack_group = lambda do |group|
+           packed = group.select { |k, _| k != :params }.to_h
+           packed[:params] = group[:params].map { |p| p.object_id }
+           packed
+         end
+
+         param_groups = @param_groups.map { |g| pack_group.call(g) }
+         packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
+
+         {
+           state: packed_state,
+           param_groups: param_groups
+         }
+       end
+
+       def zero_grad
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             if p.grad
+               p.grad.detach!
+               p.grad.zero!
+             end
+           end
+         end
+       end
+     end
+   end
+ end
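The constructor above accepts either a plain array of parameters or an array of group hashes; a plain array is wrapped in a single group and merged with the defaults. A short sketch of what that produces (hypothetical values, assuming `model.parameters` as in the other examples):

  opt = Torch::Optim::SGD.new(model.parameters, lr: 0.05, momentum: 0.9)
  opt.param_groups.size                      # => 1, one group wrapping the array
  opt.param_groups[0][:lr]                   # => 0.05, defaults merged into the group
  opt.state_dict[:param_groups][0][:params]  # => tensor object_ids, as packed above
  opt.zero_grad                              # detaches and zeroes every p.grad before the next backward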
data/lib/torch/optim/rmsprop.rb
@@ -0,0 +1,76 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
+ module Torch
+   module Optim
+     class RMSprop < Optimizer
+       def initialize(params, lr: 1e-2, alpha: 0.99, eps: 1e-8, weight_decay: 0, momentum: 0, centered: false)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
+         raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+         raise ArgumentError, "Invalid momentum alpha: #{alpha}" if alpha < 0
+
+         defaults = {lr: lr, momentum: momentum, alpha: alpha, eps: eps, centered: centered, weight_decay: weight_decay}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "RMSprop does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:square_avg] = Torch.zeros_like(p.data)
+               if group[:momentum] > 0
+                 state[:momentum_buffer] = Torch.zeros_like(p.data)
+               end
+               if group[:centered]
+                 state[:grad_avg] = Torch.zeros_like(p.data)
+               end
+             end
+
+             square_avg = state[:square_avg]
+             alpha = group[:alpha]
+
+             state[:step] += 1
+
+             if group[:weight_decay] != 0
+               grad = grad.add(group[:weight_decay], p.data)
+             end
+
+             square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+
+             if group[:centered]
+               grad_avg = state[:grad_avg]
+               grad_avg.mul!(alpha).add!(1 - alpha, grad)
+               avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+             else
+               avg = square_avg.sqrt.add!(group[:eps])
+             end
+
+             if group[:momentum] > 0
+               buf = state[:momentum_buffer]
+               buf.mul!(group[:momentum]).addcdiv!(grad, avg)
+               p.data.add!(-group[:lr], buf)
+             else
+               p.data.addcdiv!(-group[:lr], grad, avg)
+             end
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
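For reference, the non-centered, zero-momentum path above is the standard RMSprop update, sketched per parameter below (a recap of the code, not additional code in the gem; the constructor call assumes `model.parameters` as elsewhere):

  # square_avg = alpha * square_avg + (1 - alpha) * grad**2    (square_avg.mul!/addcmul!)
  # p          = p - lr * grad / (sqrt(square_avg) + eps)      (p.data.addcdiv!)
  optimizer = Torch::Optim::RMSprop.new(model.parameters, lr: 1e-3, alpha: 0.99, centered: true)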
data/lib/torch/optim/rprop.rb
@@ -0,0 +1,68 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
+ module Torch
+   module Optim
+     class Rprop < Optimizer
+       def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}" if etas[0] < 0 || etas[0] >= 1 || etas[1] < 1
+
+         defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         # TODO implement []=
+         raise NotImplementedYet
+
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "Rprop does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:prev] = Torch.zeros_like(p.data)
+               state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
+             end
+
+             etaminus, etaplus = group[:etas]
+             step_size_min, step_size_max = group[:step_sizes]
+             step_size = state[:step_size]
+
+             state[:step] += 1
+
+             sign = grad.mul(state[:prev]).sign
+             sign[sign.gt(0)] = etaplus
+             sign[sign.lt(0)] = etaminus
+             sign[sign.eq(0)] = 1
+
+             # update stepsizes with step size updates
+             step_size.mul!(sign).clamp!(step_size_min, step_size_max)
+
+             # for dir<0, dfdx=0
+             # for dir>=0 dfdx=dfdx
+             grad = grad.clone
+             grad[sign.eq(etaminus)] = 0
+
+             # update parameters
+             p.data.addcmul!(-1, grad.sign, step_size)
+
+             state[:prev].copy!(grad)
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
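`step` raises `NotImplementedYet` in this release because the element-wise assignment it relies on (`sign[...] = value`) is not available yet; see the `[]=` TODO in tensor.rb below. The intended rule is the classic sign-based Rprop update. A plain-Ruby illustration of the step-size logic for a single scalar parameter (illustration only, no tensors involved):

  etaminus, etaplus = [0.5, 1.2]
  step_size_min, step_size_max = [1e-6, 50]
  prev_grad, grad, step_size = 0.3, 0.4, 0.01

  factor =
    if (grad * prev_grad) > 0 then etaplus      # gradient kept its sign: grow the step
    elsif (grad * prev_grad) < 0 then etaminus  # sign flipped: shrink the step
    else 1
    end
  step_size = (step_size * factor).clamp(step_size_min, step_size_max)  # => 0.012
  # the parameter then moves opposite the gradient sign: p -= sign(grad) * step_size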
data/lib/torch/optim/sgd.rb
@@ -0,0 +1,60 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
+ module Torch
+   module Optim
+     class SGD < Optimizer
+       def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
+         raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0
+
+         defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}
+
+         if nesterov && (momentum <= 0 || dampening != 0)
+           raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
+         end
+
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           weight_decay = group[:weight_decay]
+           momentum = group[:momentum]
+           dampening = group[:dampening]
+           nesterov = group[:nesterov]
+
+           group[:params].each do |p|
+             next unless p.grad
+             d_p = p.grad.data
+             if weight_decay != 0
+               d_p.add!(weight_decay, p.data)
+             end
+             if momentum != 0
+               param_state = @state[p]
+               if !param_state.key?(:momentum_buffer)
+                 buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
+               else
+                 buf = param_state[:momentum_buffer]
+                 buf.mul!(momentum).add!(1 - dampening, d_p)
+               end
+               if nesterov
+                 d_p = d_p.add(momentum, buf)
+               else
+                 d_p = buf
+               end
+             end
+
+             p.data.add!(-group[:lr], d_p)
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
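A minimal training-step sketch for the SGD class above (not part of the diff; `model` and `loss` are placeholders, with `model.parameters` and a scalar loss tensor assumed as in the other optimizers):

  optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.01, momentum: 0.9, nesterov: true)
  optimizer.zero_grad  # clear stale gradients
  loss.backward        # scalar loss tensor from the forward pass
  optimizer.step       # momentum buffer plus Nesterov correction as implemented above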
data/lib/torch/tensor.rb
@@ -0,0 +1,196 @@
+ module Torch
+   class Tensor
+     include Comparable
+     include Inspector
+
+     alias_method :requires_grad?, :requires_grad
+
+     def self.new(*size)
+       if size.length == 1 && size.first.is_a?(Tensor)
+         size.first
+       else
+         Torch.empty(*size)
+       end
+     end
+
+     def dtype
+       dtype = ENUM_TO_DTYPE[_dtype]
+       raise Error, "Unknown type: #{_dtype}" unless dtype
+       dtype
+     end
+
+     def layout
+       _layout.downcase.to_sym
+     end
+
+     def to_s
+       inspect
+     end
+
+     def to_a
+       reshape_arr(_data, shape)
+     end
+
+     # TODO support dtype
+     def to(device, non_blocking: false, copy: false)
+       device = Device.new(device) if device.is_a?(String)
+       _to(device, _dtype, non_blocking, copy)
+     end
+
+     def size(dim = nil)
+       if dim
+         _size(dim)
+       else
+         shape
+       end
+     end
+
+     def shape
+       dim.times.map { |i| size(i) }
+     end
+
+     def view(*size)
+       _view(size)
+     end
+
+     def item
+       if numel != 1
+         raise Error, "only one element tensors can be converted to Ruby scalars"
+       end
+       _data.first
+     end
+
+     # unsure if this is correct
+     def new
+       Torch.empty(0, dtype: dtype)
+     end
+
+     def backward(gradient = nil)
+       if gradient
+         _backward_gradient(gradient)
+       else
+         _backward
+       end
+     end
+
+     # TODO read directly from memory
+     def numo
+       raise Error, "Numo not found" unless defined?(Numo::NArray)
+       cls = Torch._dtype_to_numo[dtype]
+       raise Error, "Cannot convert #{dtype} to Numo" unless cls
+       cls.cast(_data).reshape(*shape)
+     end
+
+     def new_ones(*size, **options)
+       Torch.ones_like(Torch.empty(*size), **options)
+     end
+
+     def requires_grad!(requires_grad = true)
+       _requires_grad!(requires_grad)
+     end
+
+     def type(dtype)
+       enum = DTYPE_TO_ENUM[dtype]
+       raise Error, "Unknown type: #{dtype}" unless enum
+       _type(enum)
+     end
+
+     def add!(value = 1, other)
+       if other.is_a?(Numeric)
+         _add_scalar!(other * value)
+       else
+         # need to use alpha for sparse tensors instead of multiplying
+         _add_alpha!(other, value)
+       end
+     end
+
+     def mul!(other)
+       if other.is_a?(Numeric)
+         _mul_scalar!(other)
+       else
+         _mul!(other)
+       end
+     end
+
+     # operations
+     %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
+       define_method(op) do |*args, **options, &block|
+         if options.any?
+           Torch.send(op, self, *args, **options, &block)
+         else
+           Torch.send(op, self, *args, &block)
+         end
+       end
+     end
+
+     def +(other)
+       add(other)
+     end
+
+     def -(other)
+       sub(other)
+     end
+
+     def *(other)
+       mul(other)
+     end
+
+     def /(other)
+       div(other)
+     end
+
+     def %(other)
+       remainder(other)
+     end
+
+     def **(other)
+       pow(other)
+     end
+
+     def -@
+       neg
+     end
+
+     def <=>(other)
+       item <=> other
+     end
+
+     # based on python_variable_indexing.cpp
+     def [](*indexes)
+       result = self
+       dim = 0
+       indexes.each do |index|
+         if index.is_a?(Numeric)
+           result = result._select(dim, index)
+         elsif index.is_a?(Range)
+           finish = index.end
+           finish += 1 unless index.exclude_end?
+           result = result._slice(dim, index.begin, finish, 1)
+           dim += 1
+         else
+           raise Error, "Unsupported index type"
+         end
+       end
+       result
+     end
+
+     # TODO
+     # based on python_variable_indexing.cpp
+     # def []=(index, value)
+     # end
+
+     private
+
+     def reshape_arr(arr, dims)
+       if dims.empty?
+         arr
+       else
+         arr = arr.flatten
+         dims[1..-1].reverse.each do |dim|
+           arr = arr.each_slice(dim)
+         end
+         arr.to_a
+       end
+     end
+   end
+ end
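A short sketch of the operators and indexing defined above (assumes `Torch.tensor` from lib/torch.rb for construction; the outputs shown are what the methods above are expected to produce):

  x = Torch.tensor([[1, 2, 3], [4, 5, 6]])
  x.shape           # => [2, 3]
  x[0]              # first row, via _select
  x[0, 1..2]        # elements 2 and 3 of the first row, via _select then _slice
  (x + x).sum.item  # => 42, using the + operator, the generated sum method, and item
  x.to_a            # => [[1, 2, 3], [4, 5, 6]], via reshape_arr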