torch-rb 0.1.2 → 0.1.3

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,33 @@
1
module Torch
  module Optim
    module LRScheduler
      # Base class for learning-rate schedulers, ported from PyTorch's _LRScheduler.
      #
      # Subclasses must implement #get_lr, returning one learning rate per param
      # group of the wrapped optimizer for the current @last_epoch. The constructor
      # already calls #step, so subclasses must set any state #get_lr reads
      # before calling super.
      class LRScheduler
        # @param optimizer [Torch::Optim::Optimizer] optimizer whose param groups' :lr is updated
        # @param last_epoch [Integer] -1 starts a fresh schedule; any other value resumes
        #   one at that epoch and requires every param group to already carry :initial_lr
        # @raise [KeyError] when resuming and a param group has no :initial_lr
        def initialize(optimizer, last_epoch)
          @optimizer = optimizer
          if last_epoch == -1
            # Fresh schedule: record each group's starting lr (an existing value is kept).
            optimizer.param_groups.each do |group|
              group[:initial_lr] ||= group[:lr]
            end
            last_epoch = 0
          else
            # Resuming: the base lrs must have been recorded by a previous schedule.
            optimizer.param_groups.each_with_index do |group, i|
              next if group[:initial_lr]

              raise KeyError, "param :initial_lr is not specified in param_groups[#{i}] when resuming an optimizer"
            end
          end
          @base_lrs = optimizer.param_groups.map { |group| group[:initial_lr] }
          @last_epoch = last_epoch

          @step_count = 0
          # Apply the learning rate for the starting epoch right away.
          step(last_epoch)
        end

        # Advances the schedule and writes the new learning rates into the optimizer.
        #
        # @param epoch [Integer, nil] epoch to jump to; defaults to the previous epoch + 1
        def step(epoch = nil)
          @step_count += 1
          epoch ||= @last_epoch + 1
          @last_epoch = epoch
          @optimizer.param_groups.zip(get_lr).each do |param_group, lr|
            param_group[:lr] = lr
          end
        end

        # Learning rates for @last_epoch, one per param group.
        #
        # @abstract Subclasses must override this.
        # @raise [NotImplementedError] when not overridden
        def get_lr
          raise NotImplementedError, "#{self.class.name} must implement get_lr"
        end
      end
    end
  end
end
@@ -0,0 +1,17 @@
1
module Torch
  module Optim
    module LRScheduler
      # Decays every param group's learning rate by +gamma+ once per +step_size+ epochs:
      #   lr = base_lr * gamma ** floor(last_epoch / step_size)
      class StepLR < LRScheduler
        # @param optimizer [Torch::Optim::Optimizer] optimizer to schedule
        # @param step_size [Numeric] number of epochs between decays
        # @param gamma [Numeric] multiplicative decay factor
        # @param last_epoch [Integer] -1 to start a new schedule
        def initialize(optimizer, step_size:, gamma: 0.1, last_epoch: -1)
          # Set before super: the base constructor calls step, which calls get_lr.
          @step_size = step_size
          @gamma = gamma
          super(optimizer, last_epoch)
        end

        # @return [Array<Numeric>] learning rate for each param group at @last_epoch
        def get_lr
          decays = (@last_epoch / @step_size).floor
          factor = @gamma**decays
          @base_lrs.map { |base_lr| base_lr * factor }
        end
      end
    end
  end
end
@@ -1,6 +1,62 @@
1
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
module Torch
  module Optim
    # Base class for optimizers.
    #
    # Parameters are organised into param groups: hashes holding a :params list of
    # tensors plus that group's hyperparameters (missing ones are filled from the
    # defaults given by the subclass). Per-parameter state lives in @state.
    class Optimizer
      attr_reader :param_groups

      # @param params [Enumerable<Tensor>, Enumerable<Hash>] a flat collection of
      #   tensors (becomes one group) or a collection of param-group hashes
      # @param defaults [Hash] hyperparameters used when a group does not set them
      # @raise [TypeError] if params is a single Tensor rather than a collection
      # @raise [ArgumentError] if params is empty or a tensor appears in two groups
      def initialize(params, defaults)
        @defaults = defaults
        # Per-parameter state hash, created on first access.
        @state = Hash.new { |hash, key| hash[key] = {} }
        @param_groups = []

        if params.is_a?(Tensor)
          raise TypeError, "params argument given to the optimizer should be an iterable of Tensors or Hashes, but got #{params.class.name}"
        end

        param_groups = params.to_a
        raise ArgumentError, "optimizer got an empty parameter list" if param_groups.empty?

        # A flat list of tensors forms a single group.
        param_groups = [{params: param_groups}] unless param_groups.first.is_a?(Hash)

        param_groups.each { |param_group| add_param_group(param_group) }
      end

      # Appends a param group, filling unspecified hyperparameters from the defaults.
      # The caller's hash is not modified.
      #
      # @param param_group [Hash] must contain :params (a Tensor or a collection of Tensors)
      # @raise [KeyError] if :params is missing
      # @raise [ArgumentError] if any tensor already belongs to an existing group
      def add_param_group(param_group)
        # TODO more advanced logic (e.g. PyTorch also rejects non-leaf tensors)
        params = param_group.fetch(:params)
        params = params.is_a?(Tensor) ? [params] : params.to_a

        # Compared by object identity: a parameter registered in two groups would be
        # updated twice per step.
        existing = {}.compare_by_identity
        @param_groups.each do |group|
          group[:params].each { |p| existing[p] = true }
        end
        if params.any? { |p| existing.key?(p) }
          raise ArgumentError, "some parameters appear in more than one parameter group"
        end

        @param_groups << @defaults.merge(param_group).merge(params: params)
      end

      # @raise [NotImplementedYet] loading is not supported yet
      def load_state_dict(state_dict)
        raise NotImplementedYet
      end

      # Snapshot of the optimizer; parameters are referenced by object_id.
      #
      # @return [Hash] {state: Hash, param_groups: Array<Hash>}
      def state_dict
        packed_groups = @param_groups.map do |group|
          packed = group.reject { |key, _| key == :params }
          packed[:params] = group[:params].map(&:object_id)
          packed
        end
        packed_state = @state.map { |key, value| [key.is_a?(Tensor) ? key.object_id : key, value] }.to_h

        {
          state: packed_state,
          param_groups: packed_groups
        }
      end

      # Detaches and zeroes (in place) the gradient of every parameter that has one.
      def zero_grad
        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            p.grad.detach!
            p.grad.zero!
          end
        end
      end
    end
  end
end
@@ -0,0 +1,76 @@
1
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
module Torch
  module Optim
    # RMSprop optimizer: divides each update by the root of a running average of
    # squared gradients, optionally centered and with momentum.
    class RMSprop < Optimizer
      # @param params [Array] tensors or param-group hashes (see Optimizer)
      # @param lr [Numeric] learning rate
      # @param alpha [Numeric] smoothing constant for the squared-gradient average
      # @param eps [Numeric] added to the denominator
      # @param weight_decay [Numeric] weight-decay coefficient
      # @param momentum [Numeric] momentum factor (0 disables the momentum buffer)
      # @param centered [Boolean] also track the mean gradient and subtract its square
      # @raise [ArgumentError] if any numeric hyperparameter is negative
      def initialize(params, lr: 1e-2, alpha: 0.99, eps: 1e-8, weight_decay: 0, momentum: 0, centered: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
        raise ArgumentError, "Invalid alpha value: #{alpha}" if alpha < 0

        defaults = {lr: lr, momentum: momentum, alpha: alpha, eps: eps, centered: centered, weight_decay: weight_decay}
        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] re-evaluates the model and returns the loss
      # @return the closure's result, or nil without a closure
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            raise Error, "RMSprop does not support sparse gradients" if grad.sparse?

            state = @state[p]

            # State initialization
            if state.empty?
              state[:step] = 0
              state[:square_avg] = Torch.zeros_like(p.data)
              state[:momentum_buffer] = Torch.zeros_like(p.data) if group[:momentum] > 0
              state[:grad_avg] = Torch.zeros_like(p.data) if group[:centered]
            end

            square_avg = state[:square_avg]
            alpha = group[:alpha]

            state[:step] += 1

            grad = grad.add(group[:weight_decay], p.data) if group[:weight_decay] != 0

            # square_avg = alpha * square_avg + (1 - alpha) * grad * grad
            square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)

            avg =
              if group[:centered]
                grad_avg = state[:grad_avg]
                grad_avg.mul!(alpha).add!(1 - alpha, grad)
                # Subtract the squared running mean before taking the root.
                square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
              else
                square_avg.sqrt.add!(group[:eps])
              end

            if group[:momentum] > 0
              buf = state[:momentum_buffer]
              buf.mul!(group[:momentum]).addcdiv!(grad, avg)
              p.data.add!(-group[:lr], buf)
            else
              p.data.addcdiv!(-group[:lr], grad, avg)
            end
          end
        end

        loss
      end
    end
  end
end
@@ -0,0 +1,68 @@
1
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
module Torch
  module Optim
    # Resilient backpropagation (Rprop): per-element step sizes grow (etaplus) while
    # the gradient keeps its sign and shrink (etaminus) when the sign flips.
    #
    # NOTE: #step currently raises NotImplementedYet because it needs masked
    # assignment (Tensor#[]=), which is not implemented yet.
    class Rprop < Optimizer
      # @param params [Array] tensors or param-group hashes (see Optimizer)
      # @param lr [Numeric] initial per-element step size
      # @param etas [Array(Numeric, Numeric)] (etaminus, etaplus) multiplicative factors;
      #   must satisfy 0 < etaminus < 1 < etaplus
      # @param step_sizes [Array(Numeric, Numeric)] (min, max) allowed step sizes
      # @raise [ArgumentError] on a negative lr or out-of-range etas
      def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        # Strict inequalities, matching PyTorch's `0 < etas[0] < 1 < etas[1]` check.
        unless 0 < etas[0] && etas[0] < 1 && 1 < etas[1]
          raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}"
        end

        defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] re-evaluates the model and returns the loss
      # @raise [NotImplementedYet] always, until Tensor#[]= exists
      def step(closure = nil)
        # TODO implement []=
        raise NotImplementedYet

        loss = nil
        if closure
          loss = closure.call
        end

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad
            grad = p.grad.data
            if grad.sparse?
              raise Error, "Rprop does not support sparse gradients"
            end
            state = @state[p]

            # State initialization
            if state.empty?
              state[:step] = 0
              state[:prev] = Torch.zeros_like(p.data)
              state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
            end

            etaminus, etaplus = group[:etas]
            step_size_min, step_size_max = group[:step_sizes]
            step_size = state[:step_size]

            state[:step] += 1

            # Multiplier per element: etaplus where the sign held, etaminus where it flipped.
            sign = grad.mul(state[:prev]).sign
            sign[sign.gt(0)] = etaplus
            sign[sign.lt(0)] = etaminus
            sign[sign.eq(0)] = 1

            # update stepsizes with step size updates
            step_size.mul!(sign).clamp!(step_size_min, step_size_max)

            # for dir<0, dfdx=0
            # for dir>=0 dfdx=dfdx
            grad = grad.clone
            grad[sign.eq(etaminus)] = 0

            # update parameters
            p.data.addcmul!(-1, grad.sign, step_size)

            state[:prev].copy!(grad)
          end
        end

        loss
      end
    end
  end
end
@@ -1,27 +1,59 @@
1
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
module Torch
  module Optim
    # Stochastic gradient descent, optionally with momentum (classic or Nesterov),
    # dampening, and weight decay.
    class SGD < Optimizer
      # @param params [Array] tensors or param-group hashes (see Optimizer)
      # @param lr [Numeric] learning rate (required)
      # @param momentum [Numeric] momentum factor; 0 disables momentum
      # @param dampening [Numeric] dampening applied to the momentum update
      # @param weight_decay [Numeric] weight-decay coefficient
      # @param nesterov [Boolean] Nesterov momentum (needs momentum > 0 and dampening == 0)
      # @raise [ArgumentError] on invalid hyperparameters
      def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0

        defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}

        if nesterov && (momentum <= 0 || dampening != 0)
          raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
        end

        super(params, defaults)
      end

      # Performs a single optimization step.
      #
      # @param closure [#call, nil] re-evaluates the model and returns the loss
      # @return the closure's result, or nil without a closure
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          weight_decay = group[:weight_decay]
          momentum = group[:momentum]
          dampening = group[:dampening]
          nesterov = group[:nesterov]

          group[:params].each do |p|
            next unless p.grad

            d_p = p.grad.data
            # Non-bang add so the weight-decay term is not written into p.grad
            # (add! would modify the gradient in place).
            d_p = d_p.add(weight_decay, p.data) if weight_decay != 0

            if momentum != 0
              param_state = @state[p]
              # key? tests key presence; Hash#key would search values and always
              # return nil here, resetting the buffer every step.
              if param_state.key?(:momentum_buffer)
                buf = param_state[:momentum_buffer]
                buf.mul!(momentum).add!(1 - dampening, d_p)
              else
                buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
              end
              d_p = nesterov ? d_p.add(momentum, buf) : buf
            end

            p.data.add!(-group[:lr], d_p)
          end
        end

        loss
      end
    end
  end
end
data/lib/torch/tensor.rb CHANGED
@@ -6,10 +6,10 @@ module Torch
6
6
  alias_method :requires_grad?, :requires_grad
7
7
 
8
8
  def self.new(*size)
9
- if size.first.is_a?(Tensor)
9
+ if size.length == 1 && size.first.is_a?(Tensor)
10
10
  size.first
11
11
  else
12
- Torch.rand(*size)
12
+ Torch.empty(*size)
13
13
  end
14
14
  end
15
15
 
@@ -31,6 +31,12 @@ module Torch
31
31
  reshape_arr(_data, shape)
32
32
  end
33
33
 
34
+ # TODO support dtype
35
+ def to(device, non_blocking: false, copy: false)
36
+ device = Device.new(device) if device.is_a?(String)
37
+ _to(device, _dtype, non_blocking, copy)
38
+ end
39
+
34
40
  def size(dim = nil)
35
41
  if dim
36
42
  _size(dim)
@@ -54,6 +60,11 @@ module Torch
54
60
  _data.first
55
61
  end
56
62
 
63
+ # unsure if this is correct
64
+ def new
65
+ Torch.empty(0, dtype: dtype)
66
+ end
67
+
57
68
  def backward(gradient = nil)
58
69
  if gradient
59
70
  _backward_gradient(gradient)
@@ -84,8 +95,25 @@ module Torch
84
95
  _type(enum)
85
96
  end
86
97
 
98
+ def add!(value = 1, other)
99
+ if other.is_a?(Numeric)
100
+ _add_scalar!(other * value)
101
+ else
102
+ # need to use alpha for sparse tensors instead of multiplying
103
+ _add_alpha!(other, value)
104
+ end
105
+ end
106
+
107
+ def mul!(other)
108
+ if other.is_a?(Numeric)
109
+ _mul_scalar!(other)
110
+ else
111
+ _mul!(other)
112
+ end
113
+ end
114
+
87
115
  # operations
88
- %w(add sub mul div remainder pow neg sum mean num norm min max dot matmul exp log unsqueeze reshape argmax eq).each do |op|
116
+ %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
89
117
  define_method(op) do |*args, **options, &block|
90
118
  if options.any?
91
119
  Torch.send(op, self, *args, **options, &block)
@@ -127,10 +155,11 @@ module Torch
127
155
  item <=> other
128
156
  end
129
157
 
158
+ # based on python_variable_indexing.cpp
130
159
  def [](*indexes)
131
160
  result = self
132
161
  dim = 0
133
- indexes.each_with_index do |index|
162
+ indexes.each do |index|
134
163
  if index.is_a?(Numeric)
135
164
  result = result._select(dim, index)
136
165
  elsif index.is_a?(Range)
@@ -145,6 +174,11 @@ module Torch
145
174
  result
146
175
  end
147
176
 
177
+ # TODO
178
+ # based on python_variable_indexing.cpp
179
+ # def []=(index, value)
180
+ # end
181
+
148
182
  private
149
183
 
150
184
  def reshape_arr(arr, dims)
@@ -2,19 +2,25 @@ module Torch
2
2
  module Utils
3
3
  module Data
4
4
  class DataLoader
5
+ include Enumerable
6
+
7
+ attr_reader :dataset
8
+
5
9
  def initialize(dataset, batch_size: 1)
6
10
  @dataset = dataset
7
11
  @batch_size = batch_size
8
12
  end
9
13
 
10
14
  def each
11
- size = @dataset.size
12
- start_index = 0
13
- while start_index < size
15
+ size.times do |i|
16
+ start_index = i * @batch_size
14
17
  yield @dataset[start_index...(start_index + @batch_size)]
15
- start_index += @batch_size
16
18
  end
17
19
  end
20
+
21
+ def size
22
+ (@dataset.size / @batch_size.to_f).ceil
23
+ end
18
24
  end
19
25
  end
20
26
  end