torch-rb 0.3.7 → 0.5.1

The recurring change across the optimizer files is the same: the scalar multiplier moves from the first positional argument to an `alpha:`/`value:` keyword, matching current PyTorch signatures.

Adam optimizer:

```diff
@@ -53,12 +53,12 @@ module Torch
             bias_correction2 = 1 - beta2 ** state[:step]

             if group[:weight_decay] != 0
-              grad.add!(group[:weight_decay], p.data)
+              grad.add!(p.data, alpha: group[:weight_decay])
             end

             # Decay the first and second moment running average coefficient
-            exp_avg.mul!(beta1).add!(1 - beta1, grad)
-            exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+            exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
+            exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
             if amsgrad
               # Maintains the maximum of all 2nd moment running avg. till now
               Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch

             step_size = group[:lr] / bias_correction1

-            p.data.addcdiv!(-step_size, exp_avg, denom)
+            p.data.addcdiv!(exp_avg, denom, value: -step_size)
           end
         end
```
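These signatures follow PyTorch's current API: the scalar multiplier is no longer the first positional argument but an `alpha:`/`value:` keyword. A minimal sketch of the new in-place calls and their semantics (values are illustrative; assumes torch-rb >= 0.5):

```ruby
require "torch"

grad    = Torch.ones(3)
exp_avg = Torch.zeros(3)

# add!(tensor, alpha: s)      => self += s * tensor
exp_avg.mul!(0.9).add!(grad, alpha: 0.1)

# addcmul!(t1, t2, value: s)  => self += s * t1 * t2
exp_avg_sq = Torch.zeros(3)
exp_avg_sq.mul!(0.999).addcmul!(grad, grad, value: 0.001)

# addcdiv!(t1, t2, value: s)  => self += s * t1 / t2
p_data = Torch.ones(3)
denom  = exp_avg_sq.sqrt.add!(1e-8)
p_data.addcdiv!(exp_avg, denom, value: -0.001)
```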
 
Adamax optimizer:

```diff
@@ -42,11 +42,11 @@ module Torch
             state[:step] += 1

             if group[:weight_decay] != 0
-              grad = grad.add(group[:weight_decay], p.data)
+              grad = grad.add(p.data, alpha: group[:weight_decay])
             end

             # Update biased first moment estimate.
-            exp_avg.mul!(beta1).add!(1 - beta1, grad)
+            exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
             # Update the exponentially weighted infinity norm.
             norm_buf = Torch.cat([
               exp_inf.mul!(beta2).unsqueeze(0),
@@ -57,7 +57,7 @@ module Torch
             bias_correction = 1 - beta1 ** state[:step]
             clr = group[:lr] / bias_correction

-            p.data.addcdiv!(-clr, exp_avg, exp_inf)
+            p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
           end
         end
```
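A detail that is easy to miss: Adam applies weight decay with the in-place `grad.add!`, mutating the gradient tensor itself, while Adamax (like ASGD and RMSprop below) uses the out-of-place `grad = grad.add(...)`, which rebinds the local and leaves the stored gradient untouched. The difference in plain form (illustrative values):

```ruby
w = Torch.ones(2)
g = Torch.ones(2)

decayed = g.add(w, alpha: 0.1)  # out of place: returns a new tensor, g unchanged
g.add!(w, alpha: 0.1)           # in place: mutates g itself
```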
 
AdamW optimizer (same moment updates as Adam, without the weight-decay-on-gradient block):

```diff
@@ -58,8 +58,8 @@ module Torch
             bias_correction2 = 1 - beta2 ** state[:step]

             # Decay the first and second moment running average coefficient
-            exp_avg.mul!(beta1).add!(1 - beta1, grad)
-            exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+            exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
+            exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
             if amsgrad
               # Maintains the maximum of all 2nd moment running avg. till now
               Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch

             step_size = group[:lr] / bias_correction1

-            p.data.addcdiv!(-step_size, exp_avg, denom)
+            p.data.addcdiv!(exp_avg, denom, value: -step_size)
           end
         end
```
ASGD optimizer:

```diff
@@ -36,14 +36,14 @@ module Torch
             state[:step] += 1

             if group[:weight_decay] != 0
-              grad = grad.add(group[:weight_decay], p.data)
+              grad = grad.add(p.data, alpha: group[:weight_decay])
             end

             # decay term
             p.data.mul!(1 - group[:lambd] * state[:eta])

             # update parameter
-            p.data.add!(-state[:eta], grad)
+            p.data.add!(grad, alpha: -state[:eta])

             # averaging
             if state[:mu] != 1
```
RMSprop optimizer:

```diff
@@ -46,25 +46,25 @@ module Torch
             state[:step] += 1

             if group[:weight_decay] != 0
-              grad = grad.add(group[:weight_decay], p.data)
+              grad = grad.add(p.data, alpha: group[:weight_decay])
             end

-            square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+            square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)

             if group[:centered]
               grad_avg = state[:grad_avg]
-              grad_avg.mul!(alpha).add!(1 - alpha, grad)
-              avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+              grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
+              avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
             else
               avg = square_avg.sqrt.add!(group[:eps])
             end

             if group[:momentum] > 0
               buf = state[:momentum_buffer]
-              buf.mul!(group[:momentum]).addcdiv!(grad, avg)
-              p.data.add!(-group[:lr], buf)
+              buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
+              p.data.add!(buf, alpha: -group[:lr])
             else
-              p.data.addcdiv!(-group[:lr], grad, avg)
+              p.data.addcdiv!(grad, avg, value: -group[:lr])
             end
           end
         end
```
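In the centered branch, the out-of-place `addcmul(grad_avg, grad_avg, value: -1)` computes `E[g^2] - E[g]^2`, so `avg` estimates the standard deviation of the gradient rather than its RMS. A sketch with illustrative numbers:

```ruby
square_avg = Torch.tensor([0.50, 0.80]) # running E[g^2]
grad_avg   = Torch.tensor([0.20, 0.10]) # running E[g]

# sqrt(E[g^2] - E[g]^2) + eps, element-wise
avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(1e-8)
# => approximately [0.678, 0.889]
```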
Rprop optimizer:

```diff
@@ -52,7 +52,7 @@ module Torch
             grad[sign.eq(etaminus)] = 0

             # update parameters
-            p.data.addcmul!(-1, grad.sign, step_size)
+            p.data.addcmul!(grad.sign, step_size, value: -1)

             state[:prev].copy!(grad)
           end
```
SGD optimizer (note the `key` → `key?` fix, discussed after the hunk):

```diff
@@ -32,24 +32,24 @@ module Torch
             next unless p.grad
             d_p = p.grad.data
             if weight_decay != 0
-              d_p.add!(weight_decay, p.data)
+              d_p.add!(p.data, alpha: weight_decay)
             end
             if momentum != 0
               param_state = @state[p]
-              if !param_state.key(:momentum_buffer)
+              if !param_state.key?(:momentum_buffer)
                 buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
               else
                 buf = param_state[:momentum_buffer]
-                buf.mul!(momentum).add!(1 - dampening, d_p)
+                buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
               end
               if nesterov
-                d_p = d_p.add(momentum, buf)
+                d_p = d_p.add(buf, alpha: momentum)
               else
                 d_p = buf
               end
             end

-            p.data.add!(-group[:lr], d_p)
+            p.data.add!(d_p, alpha: -group[:lr])
           end
         end
```
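Beyond the keyword-argument updates, this hunk fixes a genuine bug: `Hash#key` looks a key up *by value* and returns nil when no entry matches, whereas `Hash#key?` tests whether the key exists. The old check was therefore always truthy, so the momentum buffer was re-cloned from the gradient on every step instead of accumulating. In plain Ruby:

```ruby
state = {}
state.key(:momentum_buffer)  # => nil   (searches for an entry whose VALUE is :momentum_buffer)
state.key?(:momentum_buffer) # => false (tests whether the KEY exists)

state[:momentum_buffer] = [1.0, 2.0]
state.key(:momentum_buffer)  # => nil -- still nil, so !state.key(...) stayed true
state.key?(:momentum_buffer) # => true
```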
 
Tensor class:

```diff
@@ -8,6 +8,18 @@ module Torch
     alias_method :ndim, :dim
     alias_method :ndimension, :dim

+    # use alias_method for performance
+    alias_method :+, :add
+    alias_method :-, :sub
+    alias_method :*, :mul
+    alias_method :/, :div
+    alias_method :%, :remainder
+    alias_method :**, :pow
+    alias_method :-@, :neg
+    alias_method :&, :logical_and
+    alias_method :|, :logical_or
+    alias_method :^, :logical_xor
+
     def self.new(*args)
       FloatTensor.new(*args)
     end
```
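The operators gained here were previously defined as wrapper methods (removed in a later hunk), so every `a + b` paid two Ruby method dispatches: one for `+`, one for `add`. `alias_method` points the operator at the same method entry as the named method, removing the extra hop, which matters for operators called in tight loops. The pattern in isolation:

```ruby
class Wrapped
  def add(other)
    :result
  end

  # wrapper: + dispatches to this method, which dispatches again to add
  def +(other)
    add(other)
  end
end

class Aliased
  def add(other)
    :result
  end

  # alias: + and add name the same method body, one dispatch total
  alias_method :+, :add
end
```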
```diff
@@ -73,12 +85,20 @@ module Torch

     def size(dim = nil)
       if dim
-        _size_int(dim)
+        _size(dim)
       else
         shape
       end
     end

+    def stride(dim = nil)
+      if dim
+        _stride(dim)
+      else
+        _strides
+      end
+    end
+
     # mirror Python len()
     def length
       size(0)
```
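`stride` mirrors the PyTorch API: no argument returns the stride for every dimension, an integer argument returns one. A usage sketch (assumes torch-rb >= 0.5):

```ruby
x = Torch.zeros(2, 3)
x.size       # => [2, 3]
x.size(1)    # => 3
x.stride     # => [3, 1]  (row-major: 3 elements to the next row, 1 to the next column)
x.stride(0)  # => 3
```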
The wrapper methods replaced by the aliases above are deleted, `reshape`/`view`/`add!`/`clamp!` move to the generated bindings, indexing moves to the native implementation, `random!` is simplified, and `stft` is added as a thin wrapper:

```diff
@@ -130,57 +150,6 @@ module Torch
       end
     end

-    def reshape(*size)
-      # Python doesn't check if size == 1, just ignores later arguments
-      size = size.first if size.size == 1 && size.first.is_a?(Array)
-      _reshape(size)
-    end
-
-    def view(*size)
-      size = size.first if size.size == 1 && size.first.is_a?(Array)
-      _view(size)
-    end
-
-    def +(other)
-      add(other)
-    end
-
-    def -(other)
-      sub(other)
-    end
-
-    def *(other)
-      mul(other)
-    end
-
-    def /(other)
-      div(other)
-    end
-
-    def %(other)
-      remainder(other)
-    end
-
-    def **(other)
-      pow(other)
-    end
-
-    def -@
-      neg
-    end
-
-    def &(other)
-      logical_and(other)
-    end
-
-    def |(other)
-      logical_or(other)
-    end
-
-    def ^(other)
-      logical_xor(other)
-    end
-
     # TODO better compare?
     def <=>(other)
       item <=> other
@@ -189,7 +158,7 @@ module Torch
     # based on python_variable_indexing.cpp and
     # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def [](*indexes)
-      _index(tensor_indexes(indexes))
+      _index(indexes)
     end

     # based on python_variable_indexing.cpp and
@@ -197,62 +166,18 @@ module Torch
     def []=(*indexes, value)
       raise ArgumentError, "Tensor does not support deleting items" if value.nil?
       value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-      _index_put_custom(tensor_indexes(indexes), value)
-    end
-
-    # native functions that need manually defined
-
-    # value and other are swapped for some methods
-    def add!(value = 1, other)
-      if other.is_a?(Numeric)
-        _add__scalar(other, value)
-      else
-        _add__tensor(other, value)
-      end
+      _index_put_custom(indexes, value)
     end

     # parser can't handle overlap, so need to handle manually
     def random!(*args)
-      case args.size
-      when 1
-        _random__to(*args)
-      when 2
-        _random__from(*args)
-      else
-        _random_(*args)
-      end
+      return _random!(0, *args) if args.size == 1
+      _random!(*args)
     end

-    def clamp!(min, max)
-      _clamp_min_(min)
-      _clamp_max_(max)
-    end
-
-    private
-
-    def tensor_indexes(indexes)
-      indexes.map do |index|
-        case index
-        when Integer
-          TensorIndex.integer(index)
-        when Range
-          finish = index.end || -1
-          if finish == -1 && !index.exclude_end?
-            finish = nil
-          else
-            finish += 1 unless index.exclude_end?
-          end
-          TensorIndex.slice(index.begin, finish)
-        when Tensor
-          TensorIndex.tensor(index)
-        when nil
-          TensorIndex.none
-        when true, false
-          TensorIndex.boolean(index)
-        else
-          raise Error, "Unsupported index type: #{index.class.name}"
-        end
-      end
+    # center option
+    def stft(*args)
+      Torch.stft(*args)
     end
   end
 end
```
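Callers should see no behavioral change from the indexing rewrite: the Range-to-TensorIndex translation that `tensor_indexes` performed in Ruby now happens inside the native `_index`/`_index_put_custom`, and Ruby's inclusive vs. exclusive ranges keep their meaning. `random!` with a single argument still means a uniform draw from `[0, n)`. A usage sketch (assumes torch-rb >= 0.5):

```ruby
x = Torch.arange(0, 12).reshape([3, 4])

x[0]       # first row
x[0, 1]    # single element
x[0..1]    # rows 0 and 1 (inclusive range)
x[0...1]   # row 0 only (exclusive range)
x[x > 5]   # boolean mask

x[0, 0] = 99                # routed through _index_put_custom
Torch.zeros(5).random!(10)  # uniform integers from [0, 10)
```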
Version:

```diff
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.3.7"
+  VERSION = "0.5.1"
 end
```
metadata CHANGED
The gemspec bumps the version and date, adds the new `codegen/` tree and C++ headers, and drops the `lib/torch/native/` Ruby dispatcher files. Several `-`/`+` pairs below (`autorequire`, `description`, `post_install_message`, `signing_key`) appear to differ only in trailing whitespace, an artifact of regenerating the gemspec with a newer RubyGems.

```diff
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.3.7
+  version: 0.5.1
 platform: ruby
 authors:
 - Andrew Kane
-autorequire:
+autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-09-23 00:00:00.000000000 Z
+date: 2020-10-29 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -108,7 +108,7 @@ dependencies:
   - - ">="
     - !ruby/object:Gem::Version
       version: 0.1.1
-description:
+description:
 email: andrew@chartkick.com
 executables: []
 extensions:
@@ -118,19 +118,23 @@ files:
 - CHANGELOG.md
 - LICENSE.txt
 - README.md
+- codegen/function.rb
+- codegen/generate_functions.rb
+- codegen/native_functions.yaml
 - ext/torch/ext.cpp
 - ext/torch/extconf.rb
-- ext/torch/templates.cpp
-- ext/torch/templates.hpp
+- ext/torch/nn_functions.h
+- ext/torch/ruby_arg_parser.cpp
+- ext/torch/ruby_arg_parser.h
+- ext/torch/templates.h
+- ext/torch/tensor_functions.h
+- ext/torch/torch_functions.h
+- ext/torch/utils.h
+- ext/torch/wrap_outputs.h
 - lib/torch-rb.rb
 - lib/torch.rb
 - lib/torch/hub.rb
 - lib/torch/inspector.rb
-- lib/torch/native/dispatcher.rb
-- lib/torch/native/function.rb
-- lib/torch/native/generator.rb
-- lib/torch/native/native_functions.yaml
-- lib/torch/native/parser.rb
 - lib/torch/nn/adaptive_avg_pool1d.rb
 - lib/torch/nn/adaptive_avg_pool2d.rb
 - lib/torch/nn/adaptive_avg_pool3d.rb
@@ -270,7 +274,7 @@ homepage: https://github.com/ankane/torch.rb
 licenses:
 - BSD-3-Clause
 metadata: {}
-post_install_message:
+post_install_message:
 rdoc_options: []
 require_paths:
 - lib
@@ -285,8 +289,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.1.2
-signing_key:
+rubygems_version: 3.1.4
+signing_key:
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch
 test_files: []
```
Deleted: `lib/torch/native/dispatcher.rb` (removed from the gemspec file list above, along with `function.rb`, `generator.rb`, `parser.rb`, and `native_functions.yaml`):

```diff
@@ -1,70 +0,0 @@
-# We use a generic interface for methods (*args, **options)
-# and this class to determine the C++ method to call
-#
-# This is needed since LibTorch uses function overloading,
-# which isn't available in Ruby or Python
-#
-# PyTorch uses this approach, but the parser/dispatcher is written in C++
-#
-# We could generate Ruby methods directly, but an advantage of this approach is
-# arguments and keyword arguments can be used interchangably like in Python,
-# making it easier to port code
-
-module Torch
-  module Native
-    module Dispatcher
-      class << self
-        def bind
-          functions = Generator.grouped_functions
-          bind_functions(::Torch, :define_singleton_method, functions[:torch])
-          bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
-          bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
-        end
-
-        def bind_functions(context, def_method, functions)
-          instance_method = def_method == :define_method
-          functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
-            if instance_method
-              funcs.map! { |f| Function.new(f.function) }
-              funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
-            end
-
-            defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
-            next if defined && name != "clone"
-
-            # skip parser when possible for performance
-            if funcs.size == 1 && funcs.first.args.size == 0
-              # functions with no arguments
-              if instance_method
-                context.send(:alias_method, name, funcs.first.cpp_name)
-              else
-                context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
-              end
-            elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
-              # functions that take a tensor or scalar
-              scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
-              context.send(def_method, name) do |other|
-                case other
-                when Tensor
-                  send(tensor_name, other)
-                else
-                  send(scalar_name, other)
-                end
-              end
-            else
-              parser = Parser.new(funcs)
-
-              context.send(def_method, name) do |*args, **options|
-                result = parser.parse(args, options)
-                raise ArgumentError, result[:error] if result[:error]
-                send(result[:name], *result[:args])
-              end
-            end
-          end
-        end
-      end
-    end
-  end
-end
-
-Torch::Native::Dispatcher.bind
```
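This file was the heart of the 0.3.x design: a Ruby-side parser that resolved LibTorch's C++ overloads at call time. The gemspec file list above shows its replacement: `codegen/generate_functions.rb` now emits C++ bindings (`ext/torch/ruby_arg_parser.*` and the `*_functions.h` headers) that parse arguments natively, the way PyTorch itself does. The problem both designs solve, sketched with hypothetical method names:

```ruby
# LibTorch overloads add(Tensor, Tensor) and add(Tensor, Scalar); Ruby has no
# overloading, so one Ruby-visible method must inspect its argument and route
# the call. _add_tensor/_add_scalar are hypothetical stand-ins for the
# generated bindings.
def add(other)
  if other.is_a?(Torch::Tensor)
    _add_tensor(other)
  else
    _add_scalar(other)
  end
end
```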