torch-rb 0.3.6 → 0.5.0

@@ -14,25 +14,11 @@ module Torch
     _normal!(tensor, mean, std)
   end
 
-  def constant!(tensor, val)
-    _constant!(tensor, val)
-  end
-
-  def ones!(tensor)
-    _ones!(tensor)
-  end
-
-  def zeros!(tensor)
-    _zeros!(tensor)
-  end
-
-  def eye!(tensor)
-    _eye!(tensor)
-  end
-
-  def dirac!(tensor)
-    _dirac!(tensor)
-  end
+  alias_method :constant!, :_constant!
+  alias_method :ones!, :_ones!
+  alias_method :zeros!, :_zeros!
+  alias_method :eye!, :_eye!
+  alias_method :dirac!, :_dirac!
 
   def xavier_uniform!(tensor, gain: 1.0)
     _xavier_uniform!(tensor, gain)
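
The removed wrappers only forwarded to the native _*! bindings, so alias_method keeps the public interface while dropping a Ruby call frame. A rough usage sketch (shapes and values are arbitrary, and it assumes these functions live under Torch::NN::Init as in the published gem):

    require "torch"

    w = Torch.empty(3, 3)
    Torch::NN::Init.ones!(w)          # still available, now aliased directly to _ones!
    Torch::NN::Init.constant!(w, 0.5) # fill in place with a constant
    Torch::NN::Init.eye!(w)           # identity initialization for a 2-D tensor
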
@@ -58,7 +58,10 @@ module Torch
 
   @buffers.each_key do |k|
     buf = @buffers[k]
-    @buffers[k] = fn.call(buf) unless buf.nil?
+    unless buf.nil?
+      @buffers[k] = fn.call(buf)
+      instance_variable_set("@#{k}", @buffers[k])
+    end
   end
 
   self
@@ -0,0 +1,31 @@
+module Torch
+  module NN
+    class Upsample < Module
+      def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
+        super()
+        @size = size
+        if scale_factor.is_a?(Array)
+          @scale_factor = scale_factor.map(&:to_f)
+        else
+          @scale_factor = scale_factor ? scale_factor.to_f : nil
+        end
+        @mode = mode
+        @align_corners = align_corners
+      end
+
+      def forward(input)
+        F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
+      end
+
+      def extra_inspect
+        if !@scale_factor.nil?
+          info = "scale_factor: #{@scale_factor.inspect}"
+        else
+          info = "size: #{@size.inspect}"
+        end
+        info += ", mode: #{@mode.inspect}"
+        info
+      end
+    end
+  end
+end
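
The new Torch::NN::Upsample layer is a thin wrapper around F.interpolate. A minimal usage sketch (input shape chosen arbitrarily; it assumes modules are invoked with call, as elsewhere in torch-rb):

    x = Torch.rand(1, 1, 2, 2)   # N x C x H x W
    up = Torch::NN::Upsample.new(scale_factor: 2, mode: "nearest")
    y = up.call(x)               # interpolated to shape [1, 1, 4, 4]
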
@@ -39,14 +39,14 @@ module Torch
   state[:step] += 1
 
   if group[:weight_decay] != 0
-    grad = grad.add(group[:weight_decay], p.data)
+    grad = grad.add(p.data, alpha: group[:weight_decay])
   end
 
-  square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
+  square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
   std = square_avg.add(eps).sqrt!
   delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
-  p.data.add!(-group[:lr], delta)
-  acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
+  p.data.add!(delta, alpha: -group[:lr])
+  acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
 end
 end
 
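This hunk and the optimizer hunks below switch the in-place arithmetic helpers from positional scalars to keyword arguments, matching the newer LibTorch signatures: add!(other, alpha:), addcmul!(t1, t2, value:), addcdiv!(t1, t2, value:). A small sketch of the equivalences (values arbitrary):

    x = Torch.ones(2)
    g = Torch.tensor([3.0, 4.0])
    x.add!(g, alpha: 0.5)          # x += 0.5 * g       (was x.add!(0.5, g))
    x.addcmul!(g, g, value: 0.1)   # x += 0.1 * g * g   (was x.addcmul!(0.1, g, g))
    x.addcdiv!(g, g, value: 2.0)   # x += 2.0 * g / g   (was x.addcdiv!(2.0, g, g))
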
@@ -49,7 +49,7 @@ module Torch
   if p.grad.data.sparse?
     raise Error, "weight_decay option is not compatible with sparse gradients"
   end
-  grad = grad.add(group[:weight_decay], p.data)
+  grad = grad.add(p.data, alpha: group[:weight_decay])
 end
 
 clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
   if grad.sparse?
     raise NotImplementedYet
   else
-    state[:sum].addcmul!(1, grad, grad)
+    state[:sum].addcmul!(grad, grad, value: 1)
     std = state[:sum].sqrt.add!(group[:eps])
-    p.data.addcdiv!(-clr, grad, std)
+    p.data.addcdiv!(grad, std, value: -clr)
   end
 end
 end
@@ -53,12 +53,12 @@ module Torch
   bias_correction2 = 1 - beta2 ** state[:step]
 
   if group[:weight_decay] != 0
-    grad.add!(group[:weight_decay], p.data)
+    grad.add!(p.data, alpha: group[:weight_decay])
   end
 
   # Decay the first and second moment running average coefficient
-  exp_avg.mul!(beta1).add!(1 - beta1, grad)
-  exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+  exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
+  exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
   if amsgrad
     # Maintains the maximum of all 2nd moment running avg. till now
     Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch
 
   step_size = group[:lr] / bias_correction1
 
-  p.data.addcdiv!(-step_size, exp_avg, denom)
+  p.data.addcdiv!(exp_avg, denom, value: -step_size)
 end
 end
 
@@ -42,11 +42,11 @@ module Torch
   state[:step] += 1
 
   if group[:weight_decay] != 0
-    grad = grad.add(group[:weight_decay], p.data)
+    grad = grad.add(p.data, alpha: group[:weight_decay])
   end
 
   # Update biased first moment estimate.
-  exp_avg.mul!(beta1).add!(1 - beta1, grad)
+  exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
   # Update the exponentially weighted infinity norm.
   norm_buf = Torch.cat([
     exp_inf.mul!(beta2).unsqueeze(0),
@@ -57,7 +57,7 @@ module Torch
   bias_correction = 1 - beta1 ** state[:step]
   clr = group[:lr] / bias_correction
 
-  p.data.addcdiv!(-clr, exp_avg, exp_inf)
+  p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
 end
 end
 
@@ -58,8 +58,8 @@ module Torch
   bias_correction2 = 1 - beta2 ** state[:step]
 
   # Decay the first and second moment running average coefficient
-  exp_avg.mul!(beta1).add!(1 - beta1, grad)
-  exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+  exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
+  exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
   if amsgrad
     # Maintains the maximum of all 2nd moment running avg. till now
     Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch
 
   step_size = group[:lr] / bias_correction1
 
-  p.data.addcdiv!(-step_size, exp_avg, denom)
+  p.data.addcdiv!(exp_avg, denom, value: -step_size)
 end
 end
 
@@ -36,14 +36,14 @@ module Torch
   state[:step] += 1
 
   if group[:weight_decay] != 0
-    grad = grad.add(group[:weight_decay], p.data)
+    grad = grad.add(p.data, alpha: group[:weight_decay])
   end
 
   # decay term
   p.data.mul!(1 - group[:lambd] * state[:eta])
 
   # update parameter
-  p.data.add!(-state[:eta], grad)
+  p.data.add!(grad, alpha: -state[:eta])
 
   # averaging
   if state[:mu] != 1
@@ -46,25 +46,25 @@ module Torch
   state[:step] += 1
 
   if group[:weight_decay] != 0
-    grad = grad.add(group[:weight_decay], p.data)
+    grad = grad.add(p.data, alpha: group[:weight_decay])
   end
 
-  square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+  square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)
 
   if group[:centered]
     grad_avg = state[:grad_avg]
-    grad_avg.mul!(alpha).add!(1 - alpha, grad)
-    avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+    grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
+    avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
   else
     avg = square_avg.sqrt.add!(group[:eps])
   end
 
   if group[:momentum] > 0
     buf = state[:momentum_buffer]
-    buf.mul!(group[:momentum]).addcdiv!(grad, avg)
-    p.data.add!(-group[:lr], buf)
+    buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
+    p.data.add!(buf, alpha: -group[:lr])
   else
-    p.data.addcdiv!(-group[:lr], grad, avg)
+    p.data.addcdiv!(grad, avg, value: -group[:lr])
   end
 end
 end
@@ -52,7 +52,7 @@ module Torch
   grad[sign.eq(etaminus)] = 0
 
   # update parameters
-  p.data.addcmul!(-1, grad.sign, step_size)
+  p.data.addcmul!(grad.sign, step_size, value: -1)
 
   state[:prev].copy!(grad)
 end
@@ -32,24 +32,24 @@ module Torch
   next unless p.grad
   d_p = p.grad.data
   if weight_decay != 0
-    d_p.add!(weight_decay, p.data)
+    d_p.add!(p.data, alpha: weight_decay)
   end
   if momentum != 0
     param_state = @state[p]
-    if !param_state.key(:momentum_buffer)
+    if !param_state.key?(:momentum_buffer)
      buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
     else
      buf = param_state[:momentum_buffer]
-     buf.mul!(momentum).add!(1 - dampening, d_p)
+     buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
     end
     if nesterov
-     d_p = d_p.add(momentum, buf)
+     d_p = d_p.add(buf, alpha: momentum)
     else
      d_p = buf
     end
   end
 
-  p.data.add!(-group[:lr], d_p)
+  p.data.add!(d_p, alpha: -group[:lr])
 end
 end
 
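Besides the keyword-argument updates, this hunk fixes a real bug in the momentum-buffer check: Hash#key looks a key up by value (returning nil here), so the old condition was truthy on every step and the buffer was recreated instead of accumulated; key? is the membership test that was intended. In plain Ruby:

    state = { momentum_buffer: Torch.zeros(2) }
    !state.key(:momentum_buffer)    # => true even though the buffer exists (key looks up by value)
    !state.key?(:momentum_buffer)   # => false, the check that was intended
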
@@ -8,6 +8,18 @@ module Torch
   alias_method :ndim, :dim
   alias_method :ndimension, :dim
 
+  # use alias_method for performance
+  alias_method :+, :add
+  alias_method :-, :sub
+  alias_method :*, :mul
+  alias_method :/, :div
+  alias_method :%, :remainder
+  alias_method :**, :pow
+  alias_method :-@, :neg
+  alias_method :&, :logical_and
+  alias_method :|, :logical_or
+  alias_method :^, :logical_xor
+
   def self.new(*args)
     FloatTensor.new(*args)
   end
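
Behaviour is unchanged: the removed operator methods (see the tensor hunk further down) were one-line wrappers around the named methods, so aliasing only removes a call frame per operation. For example (values arbitrary):

    a = Torch.tensor([1.0, 2.0])
    b = Torch.tensor([3.0, 4.0])
    a + b   # same result as a.add(b), now via alias_method
    -a      # same as a.neg
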
@@ -48,6 +60,11 @@ module Torch
   end
 
   def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+    if device.is_a?(Symbol) && !dtype
+      dtype = device
+      device = nil
+    end
+
     device ||= self.device
     device = Device.new(device) if device.is_a?(String)
 
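With this change, to accepts a dtype symbol as its first positional argument, so the device can be left out when only the dtype changes. A sketch (dtype symbols per DTYPE_TO_ENUM; treat the exact symbols as assumptions):

    t = Torch.tensor([1, 2, 3])
    t.to(:float64)                 # symbol is treated as a dtype, device unchanged
    t.to("cpu", dtype: :float32)   # explicit device plus dtype still works
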
@@ -68,14 +85,18 @@ module Torch
 
   def size(dim = nil)
     if dim
-      _size_int(dim)
+      _size(dim)
     else
       shape
     end
   end
 
-  def shape
-    dim.times.map { |i| size(i) }
+  def stride(dim = nil)
+    if dim
+      _stride(dim)
+    else
+      _strides
+    end
   end
 
   # mirror Python len()
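
size(dim) now calls the generated _size binding, the hand-written shape method is dropped (size still calls shape, so it is presumably generated elsewhere), and a stride reader is added. Expected behaviour for a small contiguous tensor (a sketch, outputs not verified against this exact version):

    x = Torch.zeros(2, 3)
    x.size       # => [2, 3]
    x.size(1)    # => 3
    x.stride     # => [3, 1]   row-major strides for a contiguous 2x3 tensor
    x.stride(0)  # => 3
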
@@ -119,60 +140,14 @@ module Torch
   end
 
   def type(dtype)
-    enum = DTYPE_TO_ENUM[dtype]
-    raise Error, "Unknown type: #{dtype}" unless enum
-    _type(enum)
-  end
-
-  def reshape(*size)
-    # Python doesn't check if size == 1, just ignores later arguments
-    size = size.first if size.size == 1 && size.first.is_a?(Array)
-    _reshape(size)
-  end
-
-  def view(*size)
-    size = size.first if size.size == 1 && size.first.is_a?(Array)
-    _view(size)
-  end
-
-  def +(other)
-    add(other)
-  end
-
-  def -(other)
-    sub(other)
-  end
-
-  def *(other)
-    mul(other)
-  end
-
-  def /(other)
-    div(other)
-  end
-
-  def %(other)
-    remainder(other)
-  end
-
-  def **(other)
-    pow(other)
-  end
-
-  def -@
-    neg
-  end
-
-  def &(other)
-    logical_and(other)
-  end
-
-  def |(other)
-    logical_or(other)
-  end
-
-  def ^(other)
-    logical_xor(other)
+    if dtype.is_a?(Class)
+      raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
+      dtype.new(self)
+    else
+      enum = DTYPE_TO_ENUM[dtype]
+      raise Error, "Invalid type: #{dtype}" unless enum
+      _type(enum)
+    end
   end
 
   # TODO better compare?
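
type now also accepts one of the tensor classes listed in TENSOR_TYPE_CLASSES in addition to a dtype symbol. A sketch (the class name is an assumption based on the torch-rb tensor classes):

    t = Torch.tensor([1, 2, 3])
    t.type(:float32)            # dtype symbol, as before
    t.type(Torch::FloatTensor)  # tensor class, newly accepted
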
@@ -183,7 +158,7 @@ module Torch
   # based on python_variable_indexing.cpp and
   # https://pytorch.org/cppdocs/notes/tensor_indexing.html
   def [](*indexes)
-    _index(tensor_indexes(indexes))
+    _index(indexes)
   end
 
   # based on python_variable_indexing.cpp and
@@ -191,62 +166,13 @@ module Torch
   def []=(*indexes, value)
     raise ArgumentError, "Tensor does not support deleting items" if value.nil?
     value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-    _index_put_custom(tensor_indexes(indexes), value)
-  end
-
-  # native functions that need manually defined
-
-  # value and other are swapped for some methods
-  def add!(value = 1, other)
-    if other.is_a?(Numeric)
-      _add__scalar(other, value)
-    else
-      _add__tensor(other, value)
-    end
+    _index_put_custom(indexes, value)
   end
 
   # parser can't handle overlap, so need to handle manually
   def random!(*args)
-    case args.size
-    when 1
-      _random__to(*args)
-    when 2
-      _random__from(*args)
-    else
-      _random_(*args)
-    end
-  end
-
-  def clamp!(min, max)
-    _clamp_min_(min)
-    _clamp_max_(max)
-  end
-
-  private
-
-  def tensor_indexes(indexes)
-    indexes.map do |index|
-      case index
-      when Integer
-        TensorIndex.integer(index)
-      when Range
-        finish = index.end || -1
-        if finish == -1 && !index.exclude_end?
-          finish = nil
-        else
-          finish += 1 unless index.exclude_end?
-        end
-        TensorIndex.slice(index.begin, finish)
-      when Tensor
-        TensorIndex.tensor(index)
-      when nil
-        TensorIndex.none
-      when true, false
-        TensorIndex.boolean(index)
-      else
-        raise Error, "Unsupported index type: #{index.class.name}"
-      end
-    end
+    return _random!(0, *args) if args.size == 1
+    _random!(*args)
   end
 end
 end
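
Index conversion now happens in the native extension, so [] and []= pass the Ruby indexes straight through, and random! collapses onto a single overloaded _random! binding. Call sites should not need to change; a quick sketch (values arbitrary):

    t = Torch.arange(10)
    t[2..5]          # range indexing, now converted on the native side
    t[t.gt(7)]       # boolean mask indexing
    z = Torch.zeros(5)
    z.random!(10)    # one argument: equivalent to random!(0, 10)
    z.random!(2, 8)  # explicit from/to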