torch-rb 0.3.3 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -14,25 +14,11 @@ module Torch
14
14
  _normal!(tensor, mean, std)
15
15
  end
16
16
 
17
- def constant!(tensor, val)
18
- _constant!(tensor, val)
19
- end
20
-
21
- def ones!(tensor)
22
- _ones!(tensor)
23
- end
24
-
25
- def zeros!(tensor)
26
- _zeros!(tensor)
27
- end
28
-
29
- def eye!(tensor)
30
- _eye!(tensor)
31
- end
32
-
33
- def dirac!(tensor)
34
- _dirac!(tensor)
35
- end
17
+ alias_method :constant!, :_constant!
18
+ alias_method :ones!, :_ones!
19
+ alias_method :zeros!, :_zeros!
20
+ alias_method :eye!, :_eye!
21
+ alias_method :dirac!, :_dirac!
36
22
 
37
23
  def xavier_uniform!(tensor, gain: 1.0)
38
24
  _xavier_uniform!(tensor, gain)
@@ -1,14 +1,14 @@
1
1
  module Torch
2
2
  module NN
3
3
  class LeakyReLU < Module
4
- def initialize(negative_slope: 1e-2) #, inplace: false)
4
+ def initialize(negative_slope: 1e-2, inplace: false)
5
5
  super()
6
6
  @negative_slope = negative_slope
7
- # @inplace = inplace
7
+ @inplace = inplace
8
8
  end
9
9
 
10
10
  def forward(input)
11
- F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
11
+ F.leaky_relu(input, @negative_slope, inplace: @inplace)
12
12
  end
13
13
 
14
14
  def extra_inspect
@@ -55,7 +55,15 @@ module Torch
55
55
  end
56
56
  end
57
57
  end
58
- # TODO apply to more objects
58
+
59
+ @buffers.each_key do |k|
60
+ buf = @buffers[k]
61
+ unless buf.nil?
62
+ @buffers[k] = fn.call(buf)
63
+ instance_variable_set("@#{k}", @buffers[k])
64
+ end
65
+ end
66
+
59
67
  self
60
68
  end
61
69
 
@@ -0,0 +1,31 @@
1
+ module Torch
2
+ module NN
3
+ class Upsample < Module
4
+ def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
5
+ super()
6
+ @size = size
7
+ if scale_factor.is_a?(Array)
8
+ @scale_factor = scale_factor.map(&:to_f)
9
+ else
10
+ @scale_factor = scale_factor ? scale_factor.to_f : nil
11
+ end
12
+ @mode = mode
13
+ @align_corners = align_corners
14
+ end
15
+
16
+ def forward(input)
17
+ F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
18
+ end
19
+
20
+ def extra_inspect
21
+ if !@scale_factor.nil?
22
+ info = "scale_factor: #{@scale_factor.inspect}"
23
+ else
24
+ info = "size: #{@size.inspect}"
25
+ end
26
+ info += ", mode: #{@mode.inspect}"
27
+ info
28
+ end
29
+ end
30
+ end
31
+ end
@@ -45,7 +45,7 @@ module Torch
45
45
  square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
46
46
  std = square_avg.add(eps).sqrt!
47
47
  delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
48
- p.data.add!(-group[:lr], delta)
48
+ p.data.add!(delta, alpha: -group[:lr])
49
49
  acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
50
50
  end
51
51
  end
@@ -53,11 +53,11 @@ module Torch
53
53
  bias_correction2 = 1 - beta2 ** state[:step]
54
54
 
55
55
  if group[:weight_decay] != 0
56
- grad.add!(group[:weight_decay], p.data)
56
+ grad.add!(p.data, alpha: group[:weight_decay])
57
57
  end
58
58
 
59
59
  # Decay the first and second moment running average coefficient
60
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
60
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
61
61
  exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
62
62
  if amsgrad
63
63
  # Maintains the maximum of all 2nd moment running avg. till now
@@ -46,7 +46,7 @@ module Torch
46
46
  end
47
47
 
48
48
  # Update biased first moment estimate.
49
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
49
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
50
50
  # Update the exponentially weighted infinity norm.
51
51
  norm_buf = Torch.cat([
52
52
  exp_inf.mul!(beta2).unsqueeze(0),
@@ -58,7 +58,7 @@ module Torch
58
58
  bias_correction2 = 1 - beta2 ** state[:step]
59
59
 
60
60
  # Decay the first and second moment running average coefficient
61
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
61
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
62
62
  exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
63
63
  if amsgrad
64
64
  # Maintains the maximum of all 2nd moment running avg. till now
@@ -43,7 +43,7 @@ module Torch
43
43
  p.data.mul!(1 - group[:lambd] * state[:eta])
44
44
 
45
45
  # update parameter
46
- p.data.add!(-state[:eta], grad)
46
+ p.data.add!(grad, alpha: -state[:eta])
47
47
 
48
48
  # averaging
49
49
  if state[:mu] != 1
@@ -32,7 +32,7 @@ module Torch
32
32
  next unless p.grad
33
33
  d_p = p.grad.data
34
34
  if weight_decay != 0
35
- d_p.add!(weight_decay, p.data)
35
+ d_p.add!(p.data, alpha: weight_decay)
36
36
  end
37
37
  if momentum != 0
38
38
  param_state = @state[p]
@@ -40,7 +40,7 @@ module Torch
40
40
  buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
41
41
  else
42
42
  buf = param_state[:momentum_buffer]
43
- buf.mul!(momentum).add!(1 - dampening, d_p)
43
+ buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
44
44
  end
45
45
  if nesterov
46
46
  d_p = d_p.add(momentum, buf)
@@ -49,7 +49,7 @@ module Torch
49
49
  end
50
50
  end
51
51
 
52
- p.data.add!(-group[:lr], d_p)
52
+ p.data.add!(d_p, alpha: -group[:lr])
53
53
  end
54
54
  end
55
55
 
@@ -8,6 +8,18 @@ module Torch
8
8
  alias_method :ndim, :dim
9
9
  alias_method :ndimension, :dim
10
10
 
11
+ # use alias_method for performance
12
+ alias_method :+, :add
13
+ alias_method :-, :sub
14
+ alias_method :*, :mul
15
+ alias_method :/, :div
16
+ alias_method :%, :remainder
17
+ alias_method :**, :pow
18
+ alias_method :-@, :neg
19
+ alias_method :&, :logical_and
20
+ alias_method :|, :logical_or
21
+ alias_method :^, :logical_xor
22
+
11
23
  def self.new(*args)
12
24
  FloatTensor.new(*args)
13
25
  end
@@ -48,6 +60,11 @@ module Torch
48
60
  end
49
61
 
50
62
  def to(device = nil, dtype: nil, non_blocking: false, copy: false)
63
+ if device.is_a?(Symbol) && !dtype
64
+ dtype = device
65
+ device = nil
66
+ end
67
+
51
68
  device ||= self.device
52
69
  device = Device.new(device) if device.is_a?(String)
53
70
 
@@ -68,14 +85,18 @@ module Torch
68
85
 
69
86
  def size(dim = nil)
70
87
  if dim
71
- _size_int(dim)
88
+ _size(dim)
72
89
  else
73
90
  shape
74
91
  end
75
92
  end
76
93
 
77
- def shape
78
- dim.times.map { |i| size(i) }
94
+ def stride(dim = nil)
95
+ if dim
96
+ _stride(dim)
97
+ else
98
+ _strides
99
+ end
79
100
  end
80
101
 
81
102
  # mirror Python len()
@@ -103,11 +124,6 @@ module Torch
103
124
  Torch.empty(0, dtype: dtype)
104
125
  end
105
126
 
106
- def backward(gradient = nil, retain_graph: nil, create_graph: false)
107
- retain_graph = create_graph if retain_graph.nil?
108
- _backward(gradient, retain_graph, create_graph)
109
- end
110
-
111
127
  # TODO read directly from memory
112
128
  def numo
113
129
  cls = Torch._dtype_to_numo[dtype]
@@ -124,60 +140,14 @@ module Torch
124
140
  end
125
141
 
126
142
  def type(dtype)
127
- enum = DTYPE_TO_ENUM[dtype]
128
- raise Error, "Unknown type: #{dtype}" unless enum
129
- _type(enum)
130
- end
131
-
132
- def reshape(*size)
133
- # Python doesn't check if size == 1, just ignores later arguments
134
- size = size.first if size.size == 1 && size.first.is_a?(Array)
135
- _reshape(size)
136
- end
137
-
138
- def view(*size)
139
- size = size.first if size.size == 1 && size.first.is_a?(Array)
140
- _view(size)
141
- end
142
-
143
- def +(other)
144
- add(other)
145
- end
146
-
147
- def -(other)
148
- sub(other)
149
- end
150
-
151
- def *(other)
152
- mul(other)
153
- end
154
-
155
- def /(other)
156
- div(other)
157
- end
158
-
159
- def %(other)
160
- remainder(other)
161
- end
162
-
163
- def **(other)
164
- pow(other)
165
- end
166
-
167
- def -@
168
- neg
169
- end
170
-
171
- def &(other)
172
- logical_and(other)
173
- end
174
-
175
- def |(other)
176
- logical_or(other)
177
- end
178
-
179
- def ^(other)
180
- logical_xor(other)
143
+ if dtype.is_a?(Class)
144
+ raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
145
+ dtype.new(self)
146
+ else
147
+ enum = DTYPE_TO_ENUM[dtype]
148
+ raise Error, "Invalid type: #{dtype}" unless enum
149
+ _type(enum)
150
+ end
181
151
  end
182
152
 
183
153
  # TODO better compare?
@@ -188,7 +158,7 @@ module Torch
188
158
  # based on python_variable_indexing.cpp and
189
159
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
190
160
  def [](*indexes)
191
- _index(tensor_indexes(indexes))
161
+ _index(indexes)
192
162
  end
193
163
 
194
164
  # based on python_variable_indexing.cpp and
@@ -196,62 +166,13 @@ module Torch
196
166
  def []=(*indexes, value)
197
167
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
198
168
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
199
- _index_put_custom(tensor_indexes(indexes), value)
200
- end
201
-
202
- # native functions that need manually defined
203
-
204
- # value and other are swapped for some methods
205
- def add!(value = 1, other)
206
- if other.is_a?(Numeric)
207
- _add__scalar(other, value)
208
- else
209
- _add__tensor(other, value)
210
- end
169
+ _index_put_custom(indexes, value)
211
170
  end
212
171
 
213
172
  # parser can't handle overlap, so need to handle manually
214
173
  def random!(*args)
215
- case args.size
216
- when 1
217
- _random__to(*args)
218
- when 2
219
- _random__from(*args)
220
- else
221
- _random_(*args)
222
- end
223
- end
224
-
225
- def clamp!(min, max)
226
- _clamp_min_(min)
227
- _clamp_max_(max)
228
- end
229
-
230
- private
231
-
232
- def tensor_indexes(indexes)
233
- indexes.map do |index|
234
- case index
235
- when Integer
236
- TensorIndex.integer(index)
237
- when Range
238
- finish = index.end
239
- if finish == -1 && !index.exclude_end?
240
- finish = nil
241
- else
242
- finish += 1 unless index.exclude_end?
243
- end
244
- TensorIndex.slice(index.begin, finish)
245
- when Tensor
246
- TensorIndex.tensor(index)
247
- when nil
248
- TensorIndex.none
249
- when true, false
250
- TensorIndex.boolean(index)
251
- else
252
- raise Error, "Unsupported index type: #{index.class.name}"
253
- end
254
- end
174
+ return _random!(0, *args) if args.size == 1
175
+ _random!(*args)
255
176
  end
256
177
  end
257
178
  end
@@ -45,6 +45,8 @@ module Torch
45
45
  def size
46
46
  (@dataset.size / @batch_size.to_f).ceil
47
47
  end
48
+ alias_method :length, :size
49
+ alias_method :count, :size
48
50
 
49
51
  private
50
52
 
@@ -16,6 +16,8 @@ module Torch
16
16
  def size
17
17
  @tensors[0].size(0)
18
18
  end
19
+ alias_method :length, :size
20
+ alias_method :count, :size
19
21
  end
20
22
  end
21
23
  end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.3"
2
+ VERSION = "0.4.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.3
4
+ version: 0.4.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-26 00:00:00.000000000 Z
11
+ date: 2020-09-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -108,6 +108,20 @@ dependencies:
108
108
  - - ">="
109
109
  - !ruby/object:Gem::Version
110
110
  version: 0.1.1
111
+ - !ruby/object:Gem::Dependency
112
+ name: magro
113
+ requirement: !ruby/object:Gem::Requirement
114
+ requirements:
115
+ - - ">="
116
+ - !ruby/object:Gem::Version
117
+ version: '0'
118
+ type: :development
119
+ prerelease: false
120
+ version_requirements: !ruby/object:Gem::Requirement
121
+ requirements:
122
+ - - ">="
123
+ - !ruby/object:Gem::Version
124
+ version: '0'
111
125
  description:
112
126
  email: andrew@chartkick.com
113
127
  executables: []
@@ -118,19 +132,23 @@ files:
118
132
  - CHANGELOG.md
119
133
  - LICENSE.txt
120
134
  - README.md
135
+ - codegen/function.rb
136
+ - codegen/generate_functions.rb
137
+ - codegen/native_functions.yaml
121
138
  - ext/torch/ext.cpp
122
139
  - ext/torch/extconf.rb
123
- - ext/torch/templates.cpp
124
- - ext/torch/templates.hpp
140
+ - ext/torch/nn_functions.h
141
+ - ext/torch/ruby_arg_parser.cpp
142
+ - ext/torch/ruby_arg_parser.h
143
+ - ext/torch/templates.h
144
+ - ext/torch/tensor_functions.h
145
+ - ext/torch/torch_functions.h
146
+ - ext/torch/utils.h
147
+ - ext/torch/wrap_outputs.h
125
148
  - lib/torch-rb.rb
126
149
  - lib/torch.rb
127
150
  - lib/torch/hub.rb
128
151
  - lib/torch/inspector.rb
129
- - lib/torch/native/dispatcher.rb
130
- - lib/torch/native/function.rb
131
- - lib/torch/native/generator.rb
132
- - lib/torch/native/native_functions.yaml
133
- - lib/torch/native/parser.rb
134
152
  - lib/torch/nn/adaptive_avg_pool1d.rb
135
153
  - lib/torch/nn/adaptive_avg_pool2d.rb
136
154
  - lib/torch/nn/adaptive_avg_pool3d.rb
@@ -238,6 +256,7 @@ files:
238
256
  - lib/torch/nn/tanhshrink.rb
239
257
  - lib/torch/nn/triplet_margin_loss.rb
240
258
  - lib/torch/nn/unfold.rb
259
+ - lib/torch/nn/upsample.rb
241
260
  - lib/torch/nn/utils.rb
242
261
  - lib/torch/nn/weighted_loss.rb
243
262
  - lib/torch/nn/zero_pad2d.rb