torch-rb 0.3.6 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -14,25 +14,11 @@ module Torch
       _normal!(tensor, mean, std)
     end
 
-    def constant!(tensor, val)
-      _constant!(tensor, val)
-    end
-
-    def ones!(tensor)
-      _ones!(tensor)
-    end
-
-    def zeros!(tensor)
-      _zeros!(tensor)
-    end
-
-    def eye!(tensor)
-      _eye!(tensor)
-    end
-
-    def dirac!(tensor)
-      _dirac!(tensor)
-    end
+    alias_method :constant!, :_constant!
+    alias_method :ones!, :_ones!
+    alias_method :zeros!, :_zeros!
+    alias_method :eye!, :_eye!
+    alias_method :dirac!, :_dirac!
 
     def xavier_uniform!(tensor, gain: 1.0)
       _xavier_uniform!(tensor, gain)
@@ -58,7 +58,10 @@ module Torch
 
       @buffers.each_key do |k|
         buf = @buffers[k]
-        @buffers[k] = fn.call(buf) unless buf.nil?
+        unless buf.nil?
+          @buffers[k] = fn.call(buf)
+          instance_variable_set("@#{k}", @buffers[k])
+        end
       end
 
       self
@@ -0,0 +1,31 @@
+module Torch
+  module NN
+    class Upsample < Module
+      def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
+        super()
+        @size = size
+        if scale_factor.is_a?(Array)
+          @scale_factor = scale_factor.map(&:to_f)
+        else
+          @scale_factor = scale_factor ? scale_factor.to_f : nil
+        end
+        @mode = mode
+        @align_corners = align_corners
+      end
+
+      def forward(input)
+        F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
+      end
+
+      def extra_inspect
+        if !@scale_factor.nil?
+          info = "scale_factor: #{@scale_factor.inspect}"
+        else
+          info = "size: #{@size.inspect}"
+        end
+        info += ", mode: #{@mode.inspect}"
+        info
+      end
+    end
+  end
+end
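The new Upsample module is a thin wrapper around F.interpolate. A minimal usage sketch based on the initializer above (the input values and scale factor are illustrative, not taken from the diff):

    require "torch"

    # 1x1x2x2 input, upsampled by a factor of 2 with nearest-neighbor interpolation
    input = Torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
    up = Torch::NN::Upsample.new(scale_factor: 2, mode: "nearest")
    output = up.call(input)
    output.shape # => [1, 1, 4, 4]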
@@ -39,14 +39,14 @@ module Torch
       state[:step] += 1
 
       if group[:weight_decay] != 0
-        grad = grad.add(group[:weight_decay], p.data)
+        grad = grad.add(p.data, alpha: group[:weight_decay])
       end
 
-      square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
+      square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
       std = square_avg.add(eps).sqrt!
       delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
-      p.data.add!(-group[:lr], delta)
-      acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
+      p.data.add!(delta, alpha: -group[:lr])
+      acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
     end
   end
 
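Across the optimizer updates, the scalar multiplier moves from the first positional argument to a keyword argument: alpha: for add/add!, value: for addcmul!/addcdiv!. This matches the newer PyTorch calling convention. A before/after sketch with illustrative tensors:

    a = Torch.ones(3)
    b = Torch.ones(3)

    # torch-rb 0.3.x style: scalar passed positionally first
    #   a.add!(0.5, b)
    #   a.addcmul!(0.25, b, b)

    # 0.5.0 style used throughout this diff: tensor args first, scalar as keyword
    a.add!(b, alpha: 0.5)         # a += 0.5 * b
    a.addcmul!(b, b, value: 0.25) # a += 0.25 * (b * b)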
@@ -49,7 +49,7 @@ module Torch
         if p.grad.data.sparse?
           raise Error, "weight_decay option is not compatible with sparse gradients"
         end
-        grad = grad.add(group[:weight_decay], p.data)
+        grad = grad.add(p.data, alpha: group[:weight_decay])
       end
 
       clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
       if grad.sparse?
         raise NotImplementedYet
       else
-        state[:sum].addcmul!(1, grad, grad)
+        state[:sum].addcmul!(grad, grad, value: 1)
         std = state[:sum].sqrt.add!(group[:eps])
-        p.data.addcdiv!(-clr, grad, std)
+        p.data.addcdiv!(grad, std, value: -clr)
       end
     end
   end
@@ -53,12 +53,12 @@ module Torch
       bias_correction2 = 1 - beta2 ** state[:step]
 
       if group[:weight_decay] != 0
-        grad.add!(group[:weight_decay], p.data)
+        grad.add!(p.data, alpha: group[:weight_decay])
       end
 
       # Decay the first and second moment running average coefficient
-      exp_avg.mul!(beta1).add!(1 - beta1, grad)
-      exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+      exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
+      exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
       if amsgrad
         # Maintains the maximum of all 2nd moment running avg. till now
         Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch
 
       step_size = group[:lr] / bias_correction1
 
-      p.data.addcdiv!(-step_size, exp_avg, denom)
+      p.data.addcdiv!(exp_avg, denom, value: -step_size)
     end
   end
 
@@ -42,11 +42,11 @@ module Torch
       state[:step] += 1
 
       if group[:weight_decay] != 0
-        grad = grad.add(group[:weight_decay], p.data)
+        grad = grad.add(p.data, alpha: group[:weight_decay])
       end
 
       # Update biased first moment estimate.
-      exp_avg.mul!(beta1).add!(1 - beta1, grad)
+      exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
       # Update the exponentially weighted infinity norm.
       norm_buf = Torch.cat([
         exp_inf.mul!(beta2).unsqueeze(0),
@@ -57,7 +57,7 @@ module Torch
       bias_correction = 1 - beta1 ** state[:step]
       clr = group[:lr] / bias_correction
 
-      p.data.addcdiv!(-clr, exp_avg, exp_inf)
+      p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
     end
   end
 
@@ -58,8 +58,8 @@ module Torch
       bias_correction2 = 1 - beta2 ** state[:step]
 
       # Decay the first and second moment running average coefficient
-      exp_avg.mul!(beta1).add!(1 - beta1, grad)
-      exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+      exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
+      exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
       if amsgrad
         # Maintains the maximum of all 2nd moment running avg. till now
         Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch
 
       step_size = group[:lr] / bias_correction1
 
-      p.data.addcdiv!(-step_size, exp_avg, denom)
+      p.data.addcdiv!(exp_avg, denom, value: -step_size)
     end
   end
 
@@ -36,14 +36,14 @@ module Torch
       state[:step] += 1
 
       if group[:weight_decay] != 0
-        grad = grad.add(group[:weight_decay], p.data)
+        grad = grad.add(p.data, alpha: group[:weight_decay])
       end
 
       # decay term
       p.data.mul!(1 - group[:lambd] * state[:eta])
 
       # update parameter
-      p.data.add!(-state[:eta], grad)
+      p.data.add!(grad, alpha: -state[:eta])
 
       # averaging
       if state[:mu] != 1
@@ -46,25 +46,25 @@ module Torch
       state[:step] += 1
 
       if group[:weight_decay] != 0
-        grad = grad.add(group[:weight_decay], p.data)
+        grad = grad.add(p.data, alpha: group[:weight_decay])
       end
 
-      square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+      square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)
 
       if group[:centered]
         grad_avg = state[:grad_avg]
-        grad_avg.mul!(alpha).add!(1 - alpha, grad)
-        avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+        grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
+        avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
       else
         avg = square_avg.sqrt.add!(group[:eps])
       end
 
       if group[:momentum] > 0
         buf = state[:momentum_buffer]
-        buf.mul!(group[:momentum]).addcdiv!(grad, avg)
-        p.data.add!(-group[:lr], buf)
+        buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
+        p.data.add!(buf, alpha: -group[:lr])
       else
-        p.data.addcdiv!(-group[:lr], grad, avg)
+        p.data.addcdiv!(grad, avg, value: -group[:lr])
       end
     end
   end
@@ -52,7 +52,7 @@ module Torch
       grad[sign.eq(etaminus)] = 0
 
       # update parameters
-      p.data.addcmul!(-1, grad.sign, step_size)
+      p.data.addcmul!(grad.sign, step_size, value: -1)
 
       state[:prev].copy!(grad)
     end
@@ -32,24 +32,24 @@ module Torch
       next unless p.grad
       d_p = p.grad.data
       if weight_decay != 0
-        d_p.add!(weight_decay, p.data)
+        d_p.add!(p.data, alpha: weight_decay)
       end
       if momentum != 0
         param_state = @state[p]
-        if !param_state.key(:momentum_buffer)
+        if !param_state.key?(:momentum_buffer)
          buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
         else
           buf = param_state[:momentum_buffer]
-          buf.mul!(momentum).add!(1 - dampening, d_p)
+          buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
         end
         if nesterov
-          d_p = d_p.add(momentum, buf)
+          d_p = d_p.add(buf, alpha: momentum)
         else
           d_p = buf
         end
       end
 
-      p.data.add!(-group[:lr], d_p)
+      p.data.add!(d_p, alpha: -group[:lr])
     end
   end
 
@@ -8,6 +8,18 @@ module Torch
     alias_method :ndim, :dim
     alias_method :ndimension, :dim
 
+    # use alias_method for performance
+    alias_method :+, :add
+    alias_method :-, :sub
+    alias_method :*, :mul
+    alias_method :/, :div
+    alias_method :%, :remainder
+    alias_method :**, :pow
+    alias_method :-@, :neg
+    alias_method :&, :logical_and
+    alias_method :|, :logical_or
+    alias_method :^, :logical_xor
+
     def self.new(*args)
       FloatTensor.new(*args)
     end
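The arithmetic and logical operators on Tensor are now defined with alias_method instead of wrapper defs, dropping one method call per operation without changing behavior. A quick sanity sketch (values illustrative; assumes the existing Tensor#to_a helper):

    x = Torch.tensor([1, 2, 3])
    y = Torch.tensor([10, 20, 30])
    (x + y).to_a # => [11, 22, 33], same result as x.add(y).to_a
    (-x).to_a    # => [-1, -2, -3], via the :-@ alias for neg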
@@ -48,6 +60,11 @@ module Torch
     end
 
     def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+      if device.is_a?(Symbol) && !dtype
+        dtype = device
+        device = nil
+      end
+
       device ||= self.device
       device = Device.new(device) if device.is_a?(String)
 
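With the new Symbol check, to can be called with just a dtype and the device is left unchanged. A small sketch, assuming the usual dtype symbols (e.g. :float64, :float32) are present in DTYPE_TO_ENUM:

    x = Torch.tensor([1, 2, 3])
    x.to(:float64)               # dtype-only call enabled by this change
    x.to("cpu", dtype: :float32) # device (plus optional dtype) still works as before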
@@ -68,14 +85,18 @@ module Torch
 
     def size(dim = nil)
       if dim
-        _size_int(dim)
+        _size(dim)
       else
         shape
       end
     end
 
-    def shape
-      dim.times.map { |i| size(i) }
+    def stride(dim = nil)
+      if dim
+        _stride(dim)
+      else
+        _strides
+      end
     end
 
     # mirror Python len()
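stride now mirrors size: with no argument it returns all strides (via the _strides binding added above), with a dim it returns a single stride. Illustrative values for a contiguous tensor:

    x = Torch.ones(2, 3)
    x.size      # => [2, 3] (delegates to shape)
    x.stride    # => [3, 1] for a contiguous row-major tensor
    x.stride(0) # => 3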
@@ -119,60 +140,14 @@ module Torch
     end
 
     def type(dtype)
-      enum = DTYPE_TO_ENUM[dtype]
-      raise Error, "Unknown type: #{dtype}" unless enum
-      _type(enum)
-    end
-
-    def reshape(*size)
-      # Python doesn't check if size == 1, just ignores later arguments
-      size = size.first if size.size == 1 && size.first.is_a?(Array)
-      _reshape(size)
-    end
-
-    def view(*size)
-      size = size.first if size.size == 1 && size.first.is_a?(Array)
-      _view(size)
-    end
-
-    def +(other)
-      add(other)
-    end
-
-    def -(other)
-      sub(other)
-    end
-
-    def *(other)
-      mul(other)
-    end
-
-    def /(other)
-      div(other)
-    end
-
-    def %(other)
-      remainder(other)
-    end
-
-    def **(other)
-      pow(other)
-    end
-
-    def -@
-      neg
-    end
-
-    def &(other)
-      logical_and(other)
-    end
-
-    def |(other)
-      logical_or(other)
-    end
-
-    def ^(other)
-      logical_xor(other)
+      if dtype.is_a?(Class)
+        raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
+        dtype.new(self)
+      else
+        enum = DTYPE_TO_ENUM[dtype]
+        raise Error, "Invalid type: #{dtype}" unless enum
+        _type(enum)
+      end
     end
 
     # TODO better compare?
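type now accepts either a dtype symbol or a tensor class from TENSOR_TYPE_CLASSES. A sketch of both forms, assuming the tensor classes follow the FloatTensor pattern seen earlier in the diff:

    x = Torch.tensor([1, 2, 3])
    x.type(:float32)           # symbol path: looked up in DTYPE_TO_ENUM
    x.type(Torch::FloatTensor) # class path: returns dtype.new(self)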
@@ -183,7 +158,7 @@ module Torch
     # based on python_variable_indexing.cpp and
     # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def [](*indexes)
-      _index(tensor_indexes(indexes))
+      _index(indexes)
     end
 
     # based on python_variable_indexing.cpp and
@@ -191,62 +166,13 @@ module Torch
     def []=(*indexes, value)
       raise ArgumentError, "Tensor does not support deleting items" if value.nil?
       value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-      _index_put_custom(tensor_indexes(indexes), value)
-    end
-
-    # native functions that need manually defined
-
-    # value and other are swapped for some methods
-    def add!(value = 1, other)
-      if other.is_a?(Numeric)
-        _add__scalar(other, value)
-      else
-        _add__tensor(other, value)
-      end
+      _index_put_custom(indexes, value)
     end
 
     # parser can't handle overlap, so need to handle manually
     def random!(*args)
-      case args.size
-      when 1
-        _random__to(*args)
-      when 2
-        _random__from(*args)
-      else
-        _random_(*args)
-      end
-    end
-
-    def clamp!(min, max)
-      _clamp_min_(min)
-      _clamp_max_(max)
-    end
-
-    private
-
-    def tensor_indexes(indexes)
-      indexes.map do |index|
-        case index
-        when Integer
-          TensorIndex.integer(index)
-        when Range
-          finish = index.end || -1
-          if finish == -1 && !index.exclude_end?
-            finish = nil
-          else
-            finish += 1 unless index.exclude_end?
-          end
-          TensorIndex.slice(index.begin, finish)
-        when Tensor
-          TensorIndex.tensor(index)
-        when nil
-          TensorIndex.none
-        when true, false
-          TensorIndex.boolean(index)
-        else
-          raise Error, "Unsupported index type: #{index.class.name}"
-        end
-      end
+      return _random!(0, *args) if args.size == 1
+      _random!(*args)
     end
   end
 end