torch-rb 0.1.2 → 0.1.3

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
@@ -0,0 +1,33 @@
+ module Torch
+   module Optim
+     module LRScheduler
+       class LRScheduler
+         def initialize(optimizer, last_epoch)
+           @optimizer = optimizer
+           if last_epoch == -1
+             optimizer.param_groups.each do |group|
+               group[:initial_lr] ||= group[:lr]
+             end
+             last_epoch = 0
+           else
+             raise NotImplementedYet
+           end
+           @base_lrs = optimizer.param_groups.map { |group| group[:initial_lr] }
+           @last_epoch = last_epoch
+
+           @step_count = 0
+           step(last_epoch)
+         end
+
+         def step(epoch = nil)
+           @step_count += 1
+           epoch ||= @last_epoch + 1
+           @last_epoch = epoch
+           @optimizer.param_groups.zip(get_lr).each do |param_group, lr|
+             param_group[:lr] = lr
+           end
+         end
+       end
+     end
+   end
+ end
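The base class drives everything through get_lr: step bumps @last_epoch, asks the subclass for fresh rates, and writes them back into optimizer.param_groups. A minimal sketch of a scheduler built on that contract (ExponentialLR here is a hypothetical example, not something shipped in this release):

  module Torch
    module Optim
      module LRScheduler
        # decays every group's learning rate by gamma each epoch
        class ExponentialLR < LRScheduler
          def initialize(optimizer, gamma:, last_epoch: -1)
            @gamma = gamma
            super(optimizer, last_epoch)
          end

          def get_lr
            @base_lrs.map { |base_lr| base_lr * @gamma ** @last_epoch }
          end
        end
      end
    end
  end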
@@ -0,0 +1,17 @@
+ module Torch
+   module Optim
+     module LRScheduler
+       class StepLR < LRScheduler
+         def initialize(optimizer, step_size:, gamma: 0.1, last_epoch: -1)
+           @step_size = step_size
+           @gamma = gamma
+           super(optimizer, last_epoch)
+         end
+
+         def get_lr
+           @base_lrs.map { |base_lr| base_lr * @gamma ** (@last_epoch / @step_size).floor }
+         end
+       end
+     end
+   end
+ end
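A usage sketch for StepLR, assuming an SGD optimizer built from a model's parameters (model and the training step are placeholders):

  optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)
  scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 30, gamma: 0.1)

  100.times do
    # train for one epoch here, then:
    scheduler.step
    # lr is 0.1 for epochs 0-29, 0.01 for 30-59, 0.001 for 60-89, ...
  end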
@@ -1,6 +1,62 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
  module Torch
    module Optim
      class Optimizer
+       attr_reader :param_groups
+
+       def initialize(params, defaults)
+         @defaults = defaults
+         @state = Hash.new { |hash, key| hash[key] = {} }
+         @param_groups = []
+
+         param_groups = params
+         if param_groups.empty?
+           raise ArgumentError, "optimizer got an empty parameter list"
+         end
+         if !param_groups[0].is_a?(Hash)
+           param_groups = [{params: param_groups}]
+         end
+
+         param_groups.each do |param_group|
+           add_param_group(param_group)
+         end
+       end
+
+       def add_param_group(param_group)
+         # TODO more advanced logic
+         @param_groups << @defaults.merge(param_group)
+       end
+
+       def load_state_dict(state_dict)
+         raise NotImplementedYet
+       end
+
+       def state_dict
+         pack_group = lambda do |group|
+           packed = group.select { |k, _| k != :params }.to_h
+           packed[:params] = group[:params].map { |p| p.object_id }
+           packed
+         end
+
+         param_groups = @param_groups.map { |g| pack_group.call(g) }
+         packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
+
+         {
+           state: packed_state,
+           param_groups: param_groups
+         }
+       end
+
+       def zero_grad
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             if p.grad
+               p.grad.detach!
+               p.grad.zero!
+             end
+           end
+         end
+       end
      end
    end
  end
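The constructor accepts either a flat parameter list or an array of param-group hashes, and add_param_group merges @defaults underneath each group, so per-group overrides win. A sketch of two groups with different learning rates (model, base, and classifier are placeholders for whatever modules you already have):

  optimizer = Torch::Optim::SGD.new([
    {params: model.base.parameters},
    {params: model.classifier.parameters, lr: 1e-3}
  ], lr: 1e-2, momentum: 0.9)

  optimizer.param_groups.map { |g| g[:lr] } # => [0.01, 0.001]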
@@ -0,0 +1,76 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
+ module Torch
+   module Optim
+     class RMSprop < Optimizer
+       def initialize(params, lr: 1e-2, alpha: 0.99, eps: 1e-8, weight_decay: 0, momentum: 0, centered: false)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
+         raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+         raise ArgumentError, "Invalid momentum alpha: #{alpha}" if alpha < 0
+
+         defaults = {lr: lr, momentum: momentum, alpha: alpha, eps: eps, centered: centered, weight_decay: weight_decay}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "RMSprop does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:square_avg] = Torch.zeros_like(p.data)
+               if group[:momentum] > 0
+                 state[:momentum_buffer] = Torch.zeros_like(p.data)
+               end
+               if group[:centered]
+                 state[:grad_avg] = Torch.zeros_like(p.data)
+               end
+             end
+
+             square_avg = state[:square_avg]
+             alpha = group[:alpha]
+
+             state[:step] += 1
+
+             if group[:weight_decay] != 0
+               grad = grad.add(group[:weight_decay], p.data)
+             end
+
+             square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+
+             if group[:centered]
+               grad_avg = state[:grad_avg]
+               grad_avg.mul!(alpha).add!(1 - alpha, grad)
+               avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+             else
+               avg = square_avg.sqrt.add!(group[:eps])
+             end
+
+             if group[:momentum] > 0
+               buf = state[:momentum_buffer]
+               buf.mul!(group[:momentum]).addcdiv!(grad, avg)
+               p.data.add!(-group[:lr], buf)
+             else
+               p.data.addcdiv!(-group[:lr], grad, avg)
+             end
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
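Ignoring weight decay, momentum, and centering, the loop above keeps a per-parameter running average square_avg = alpha * square_avg + (1 - alpha) * grad * grad and then applies p.data -= lr * grad / (sqrt(square_avg) + eps). A usage sketch (model is a placeholder):

  optimizer = Torch::Optim::RMSprop.new(model.parameters, lr: 1e-3, alpha: 0.99, centered: true)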
@@ -0,0 +1,68 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
+ module Torch
+   module Optim
+     class Rprop < Optimizer
+       def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}" if etas[0] < 0 || etas[0] >= 1 || etas[1] < 1
+
+         defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         # TODO implement []=
+         raise NotImplementedYet
+
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "Rprop does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:prev] = Torch.zeros_like(p.data)
+               state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
+             end
+
+             etaminus, etaplus = group[:etas]
+             step_size_min, step_size_max = group[:step_sizes]
+             step_size = state[:step_size]
+
+             state[:step] += 1
+
+             sign = grad.mul(state[:prev]).sign
+             sign[sign.gt(0)] = etaplus
+             sign[sign.lt(0)] = etaminus
+             sign[sign.eq(0)] = 1
+
+             # update stepsizes with step size updates
+             step_size.mul!(sign).clamp!(step_size_min, step_size_max)
+
+             # for dir<0, dfdx=0
+             # for dir>=0 dfdx=dfdx
+             grad = grad.clone
+             grad[sign.eq(etaminus)] = 0
+
+             # update parameters
+             p.data.addcmul!(-1, grad.sign, step_size)
+
+             state[:prev].copy!(grad)
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
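Note that step raises NotImplementedYet before doing any work: the masked assignments above (sign[sign.gt(0)] = etaplus, grad[sign.eq(etaminus)] = 0) need Tensor#[]=, which is still a TODO in tensor.rb below. The intended element-wise rule is the usual Rprop one: multiply the step size by etaplus while the gradient keeps its sign, by etaminus when it flips, clamp to [step_size_min, step_size_max], then move the parameter by -sign(grad) * step_size. With the default etas: [0.5, 1.2], a step size of 0.01 grows to 0.012 on agreement and halves to 0.005 on a sign flip.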
@@ -1,27 +1,59 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
  module Torch
    module Optim
      class SGD < Optimizer
-       def initialize(params, lr:)
-         @params = params
-         @lr = lr
-       end
+       def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
+         raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0

-       def zero_grad
-         @params.each do |param|
-           if param.grad
-             param.grad.detach!
-             param.grad.zero!
-           end
+         defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}
+
+         if nesterov && (momentum <= 0 || dampening != 0)
+           raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
          end
+
+         super(params, defaults)
        end

-       def step
-         @params.each do |param|
-           next unless param.grad
-           d_p = param.grad.data
-           # same as param.data.add!(-@lr, d_p)
-           param.data.sub!(d_p * @lr)
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           weight_decay = group[:weight_decay]
+           momentum = group[:momentum]
+           dampening = group[:dampening]
+           nesterov = group[:nesterov]
+
+           group[:params].each do |p|
+             next unless p.grad
+             d_p = p.grad.data
+             if weight_decay != 0
+               d_p.add!(weight_decay, p.data)
+             end
+             if momentum != 0
+               param_state = @state[p]
+               if !param_state.key?(:momentum_buffer)
+                 buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
+               else
+                 buf = param_state[:momentum_buffer]
+                 buf.mul!(momentum).add!(1 - dampening, d_p)
+               end
+               if nesterov
+                 d_p = d_p.add(momentum, buf)
+               else
+                 d_p = buf
+               end
+             end
+
+             p.data.add!(-group[:lr], d_p)
+           end
          end
+
+         loss
        end
      end
    end
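A training-loop sketch with the new momentum support; net, criterion, inputs, and targets are placeholders for whatever model and data you already have:

  optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01, momentum: 0.9)

  optimizer.zero_grad                              # clears p.grad across all param groups
  loss = criterion.call(net.call(inputs), targets) # placeholder forward pass
  loss.backward
  optimizer.step                                   # p.data.add!(-lr, d_p) per parameter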
data/lib/torch/tensor.rb CHANGED
@@ -6,10 +6,10 @@ module Torch
      alias_method :requires_grad?, :requires_grad

      def self.new(*size)
-       if size.first.is_a?(Tensor)
+       if size.length == 1 && size.first.is_a?(Tensor)
          size.first
        else
-         Torch.rand(*size)
+         Torch.empty(*size)
        end
      end

@@ -31,6 +31,12 @@ module Torch
        reshape_arr(_data, shape)
      end

+     # TODO support dtype
+     def to(device, non_blocking: false, copy: false)
+       device = Device.new(device) if device.is_a?(String)
+       _to(device, _dtype, non_blocking, copy)
+     end
+
      def size(dim = nil)
        if dim
          _size(dim)
@@ -54,6 +60,11 @@ module Torch
        _data.first
      end

+     # unsure if this is correct
+     def new
+       Torch.empty(0, dtype: dtype)
+     end
+
      def backward(gradient = nil)
        if gradient
          _backward_gradient(gradient)
@@ -84,8 +95,25 @@ module Torch
        _type(enum)
      end

+     def add!(value = 1, other)
+       if other.is_a?(Numeric)
+         _add_scalar!(other * value)
+       else
+         # need to use alpha for sparse tensors instead of multiplying
+         _add_alpha!(other, value)
+       end
+     end
+
+     def mul!(other)
+       if other.is_a?(Numeric)
+         _mul_scalar!(other)
+       else
+         _mul!(other)
+       end
+     end
+
      # operations
-     %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze reshape argmax eq).each do |op|
+     %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
        define_method(op) do |*args, **options, &block|
          if options.any?
            Torch.send(op, self, *args, **options, &block)
@@ -127,10 +155,11 @@ module Torch
        item <=> other
      end

+     # based on python_variable_indexing.cpp
      def [](*indexes)
        result = self
        dim = 0
-       indexes.each_with_index do |index|
+       indexes.each do |index|
          if index.is_a?(Numeric)
            result = result._select(dim, index)
          elsif index.is_a?(Range)
@@ -145,6 +174,11 @@ module Torch
        result
      end

+     # TODO
+     # based on python_variable_indexing.cpp
+     # def []=(index, value)
+     # end
+
      private

      def reshape_arr(arr, dims)
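The new Tensor#add! and Tensor#mul! are the two-argument in-place forms the optimizers above rely on: when the second argument is a tensor, add!(alpha, other) adds alpha * other (passing alpha through to the native call so sparse tensors keep working), and a numeric argument is treated as a scalar. A small sketch of the added pieces; Torch.ones is just a stand-in for any tensor constructor:

  x = Torch.ones(2, 3)
  x.add!(0.5, Torch.ones(2, 3)) # in place: x += 0.5 * other
  x.mul!(2)                     # in place: x *= 2
  x.to("cpu")                   # strings are wrapped in Device automatically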
@@ -2,19 +2,25 @@ module Torch
    module Utils
      module Data
        class DataLoader
+         include Enumerable
+
+         attr_reader :dataset
+
          def initialize(dataset, batch_size: 1)
            @dataset = dataset
            @batch_size = batch_size
          end

          def each
-           size = @dataset.size
-           start_index = 0
-           while start_index < size
+           size.times do |i|
+             start_index = i * @batch_size
              yield @dataset[start_index...(start_index + @batch_size)]
-             start_index += @batch_size
            end
          end
+
+         def size
+           (@dataset.size / @batch_size.to_f).ceil
+         end
        end
      end
    end
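With Enumerable included, the loader picks up each_with_index, map, and friends on top of the batched each, and size now reports the number of batches rather than examples. A sketch where a plain array stands in for a real dataset:

  loader = Torch::Utils::Data::DataLoader.new((1..10).to_a, batch_size: 4)
  loader.size # => 3
  loader.each_with_index do |batch, i|
    # batch is dataset[start...(start + batch_size)], e.g. [1, 2, 3, 4] when i == 0
  end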