torch-rb 0.3.7 → 0.5.1

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
@@ -53,12 +53,12 @@ module Torch
53
53
  bias_correction2 = 1 - beta2 ** state[:step]
54
54
 
55
55
  if group[:weight_decay] != 0
56
- grad.add!(group[:weight_decay], p.data)
56
+ grad.add!(p.data, alpha: group[:weight_decay])
57
57
  end
58
58
 
59
59
  # Decay the first and second moment running average coefficient
60
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
61
- exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
60
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
61
+ exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
62
62
  if amsgrad
63
63
  # Maintains the maximum of all 2nd moment running avg. till now
64
64
  Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch
70
70
 
71
71
  step_size = group[:lr] / bias_correction1
72
72
 
73
- p.data.addcdiv!(-step_size, exp_avg, denom)
73
+ p.data.addcdiv!(exp_avg, denom, value: -step_size)
74
74
  end
75
75
  end
76
76
 
@@ -42,11 +42,11 @@ module Torch
42
42
  state[:step] += 1
43
43
 
44
44
  if group[:weight_decay] != 0
45
- grad = grad.add(group[:weight_decay], p.data)
45
+ grad = grad.add(p.data, alpha: group[:weight_decay])
46
46
  end
47
47
 
48
48
  # Update biased first moment estimate.
49
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
49
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
50
50
  # Update the exponentially weighted infinity norm.
51
51
  norm_buf = Torch.cat([
52
52
  exp_inf.mul!(beta2).unsqueeze(0),
@@ -57,7 +57,7 @@ module Torch
57
57
  bias_correction = 1 - beta1 ** state[:step]
58
58
  clr = group[:lr] / bias_correction
59
59
 
60
- p.data.addcdiv!(-clr, exp_avg, exp_inf)
60
+ p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
61
61
  end
62
62
  end
63
63
 
@@ -58,8 +58,8 @@ module Torch
58
58
  bias_correction2 = 1 - beta2 ** state[:step]
59
59
 
60
60
  # Decay the first and second moment running average coefficient
61
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
62
- exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
61
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
62
+ exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
63
63
  if amsgrad
64
64
  # Maintains the maximum of all 2nd moment running avg. till now
65
65
  Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch
71
71
 
72
72
  step_size = group[:lr] / bias_correction1
73
73
 
74
- p.data.addcdiv!(-step_size, exp_avg, denom)
74
+ p.data.addcdiv!(exp_avg, denom, value: -step_size)
75
75
  end
76
76
  end
77
77
 
@@ -36,14 +36,14 @@ module Torch
36
36
  state[:step] += 1
37
37
 
38
38
  if group[:weight_decay] != 0
39
- grad = grad.add(group[:weight_decay], p.data)
39
+ grad = grad.add(p.data, alpha: group[:weight_decay])
40
40
  end
41
41
 
42
42
  # decay term
43
43
  p.data.mul!(1 - group[:lambd] * state[:eta])
44
44
 
45
45
  # update parameter
46
- p.data.add!(-state[:eta], grad)
46
+ p.data.add!(grad, alpha: -state[:eta])
47
47
 
48
48
  # averaging
49
49
  if state[:mu] != 1
@@ -46,25 +46,25 @@ module Torch
46
46
  state[:step] += 1
47
47
 
48
48
  if group[:weight_decay] != 0
49
- grad = grad.add(group[:weight_decay], p.data)
49
+ grad = grad.add(p.data, alpha: group[:weight_decay])
50
50
  end
51
51
 
52
- square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
52
+ square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)
53
53
 
54
54
  if group[:centered]
55
55
  grad_avg = state[:grad_avg]
56
- grad_avg.mul!(alpha).add!(1 - alpha, grad)
57
- avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
56
+ grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
57
+ avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
58
58
  else
59
59
  avg = square_avg.sqrt.add!(group[:eps])
60
60
  end
61
61
 
62
62
  if group[:momentum] > 0
63
63
  buf = state[:momentum_buffer]
64
- buf.mul!(group[:momentum]).addcdiv!(grad, avg)
65
- p.data.add!(-group[:lr], buf)
64
+ buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
65
+ p.data.add!(buf, alpha: -group[:lr])
66
66
  else
67
- p.data.addcdiv!(-group[:lr], grad, avg)
67
+ p.data.addcdiv!(grad, avg, value: -group[:lr])
68
68
  end
69
69
  end
70
70
  end
@@ -52,7 +52,7 @@ module Torch
52
52
  grad[sign.eq(etaminus)] = 0
53
53
 
54
54
  # update parameters
55
- p.data.addcmul!(-1, grad.sign, step_size)
55
+ p.data.addcmul!(grad.sign, step_size, value: -1)
56
56
 
57
57
  state[:prev].copy!(grad)
58
58
  end
@@ -32,24 +32,24 @@ module Torch
32
32
  next unless p.grad
33
33
  d_p = p.grad.data
34
34
  if weight_decay != 0
35
- d_p.add!(weight_decay, p.data)
35
+ d_p.add!(p.data, alpha: weight_decay)
36
36
  end
37
37
  if momentum != 0
38
38
  param_state = @state[p]
39
- if !param_state.key(:momentum_buffer)
39
+ if !param_state.key?(:momentum_buffer)
40
40
  buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
41
41
  else
42
42
  buf = param_state[:momentum_buffer]
43
- buf.mul!(momentum).add!(1 - dampening, d_p)
43
+ buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
44
44
  end
45
45
  if nesterov
46
- d_p = d_p.add(momentum, buf)
46
+ d_p = d_p.add(buf, alpha: momentum)
47
47
  else
48
48
  d_p = buf
49
49
  end
50
50
  end
51
51
 
52
- p.data.add!(-group[:lr], d_p)
52
+ p.data.add!(d_p, alpha: -group[:lr])
53
53
  end
54
54
  end
55
55
 
@@ -8,6 +8,18 @@ module Torch
8
8
  alias_method :ndim, :dim
9
9
  alias_method :ndimension, :dim
10
10
 
11
+ # use alias_method for performance
12
+ alias_method :+, :add
13
+ alias_method :-, :sub
14
+ alias_method :*, :mul
15
+ alias_method :/, :div
16
+ alias_method :%, :remainder
17
+ alias_method :**, :pow
18
+ alias_method :-@, :neg
19
+ alias_method :&, :logical_and
20
+ alias_method :|, :logical_or
21
+ alias_method :^, :logical_xor
22
+
11
23
  def self.new(*args)
12
24
  FloatTensor.new(*args)
13
25
  end
@@ -73,12 +85,20 @@ module Torch
73
85
 
74
86
  def size(dim = nil)
75
87
  if dim
76
- _size_int(dim)
88
+ _size(dim)
77
89
  else
78
90
  shape
79
91
  end
80
92
  end
81
93
 
94
+ def stride(dim = nil)
95
+ if dim
96
+ _stride(dim)
97
+ else
98
+ _strides
99
+ end
100
+ end
101
+
82
102
  # mirror Python len()
83
103
  def length
84
104
  size(0)
@@ -130,57 +150,6 @@ module Torch
130
150
  end
131
151
  end
132
152
 
133
- def reshape(*size)
134
- # Python doesn't check if size == 1, just ignores later arguments
135
- size = size.first if size.size == 1 && size.first.is_a?(Array)
136
- _reshape(size)
137
- end
138
-
139
- def view(*size)
140
- size = size.first if size.size == 1 && size.first.is_a?(Array)
141
- _view(size)
142
- end
143
-
144
- def +(other)
145
- add(other)
146
- end
147
-
148
- def -(other)
149
- sub(other)
150
- end
151
-
152
- def *(other)
153
- mul(other)
154
- end
155
-
156
- def /(other)
157
- div(other)
158
- end
159
-
160
- def %(other)
161
- remainder(other)
162
- end
163
-
164
- def **(other)
165
- pow(other)
166
- end
167
-
168
- def -@
169
- neg
170
- end
171
-
172
- def &(other)
173
- logical_and(other)
174
- end
175
-
176
- def |(other)
177
- logical_or(other)
178
- end
179
-
180
- def ^(other)
181
- logical_xor(other)
182
- end
183
-
184
153
  # TODO better compare?
185
154
  def <=>(other)
186
155
  item <=> other
@@ -189,7 +158,7 @@ module Torch
189
158
  # based on python_variable_indexing.cpp and
190
159
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
191
160
  def [](*indexes)
192
- _index(tensor_indexes(indexes))
161
+ _index(indexes)
193
162
  end
194
163
 
195
164
  # based on python_variable_indexing.cpp and
@@ -197,62 +166,18 @@ module Torch
197
166
  def []=(*indexes, value)
198
167
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
199
168
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
200
- _index_put_custom(tensor_indexes(indexes), value)
201
- end
202
-
203
- # native functions that need manually defined
204
-
205
- # value and other are swapped for some methods
206
- def add!(value = 1, other)
207
- if other.is_a?(Numeric)
208
- _add__scalar(other, value)
209
- else
210
- _add__tensor(other, value)
211
- end
169
+ _index_put_custom(indexes, value)
212
170
  end
213
171
 
214
172
  # parser can't handle overlap, so need to handle manually
215
173
  def random!(*args)
216
- case args.size
217
- when 1
218
- _random__to(*args)
219
- when 2
220
- _random__from(*args)
221
- else
222
- _random_(*args)
223
- end
174
+ return _random!(0, *args) if args.size == 1
175
+ _random!(*args)
224
176
  end
225
177
 
226
- def clamp!(min, max)
227
- _clamp_min_(min)
228
- _clamp_max_(max)
229
- end
230
-
231
- private
232
-
233
- def tensor_indexes(indexes)
234
- indexes.map do |index|
235
- case index
236
- when Integer
237
- TensorIndex.integer(index)
238
- when Range
239
- finish = index.end || -1
240
- if finish == -1 && !index.exclude_end?
241
- finish = nil
242
- else
243
- finish += 1 unless index.exclude_end?
244
- end
245
- TensorIndex.slice(index.begin, finish)
246
- when Tensor
247
- TensorIndex.tensor(index)
248
- when nil
249
- TensorIndex.none
250
- when true, false
251
- TensorIndex.boolean(index)
252
- else
253
- raise Error, "Unsupported index type: #{index.class.name}"
254
- end
255
- end
178
+ # center option
179
+ def stft(*args)
180
+ Torch.stft(*args)
256
181
  end
257
182
  end
258
183
  end
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.7"
2
+ VERSION = "0.5.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.7
4
+ version: 0.5.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-09-23 00:00:00.000000000 Z
11
+ date: 2020-10-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -108,7 +108,7 @@ dependencies:
108
108
  - - ">="
109
109
  - !ruby/object:Gem::Version
110
110
  version: 0.1.1
111
- description:
111
+ description:
112
112
  email: andrew@chartkick.com
113
113
  executables: []
114
114
  extensions:
@@ -118,19 +118,23 @@ files:
118
118
  - CHANGELOG.md
119
119
  - LICENSE.txt
120
120
  - README.md
121
+ - codegen/function.rb
122
+ - codegen/generate_functions.rb
123
+ - codegen/native_functions.yaml
121
124
  - ext/torch/ext.cpp
122
125
  - ext/torch/extconf.rb
123
- - ext/torch/templates.cpp
124
- - ext/torch/templates.hpp
126
+ - ext/torch/nn_functions.h
127
+ - ext/torch/ruby_arg_parser.cpp
128
+ - ext/torch/ruby_arg_parser.h
129
+ - ext/torch/templates.h
130
+ - ext/torch/tensor_functions.h
131
+ - ext/torch/torch_functions.h
132
+ - ext/torch/utils.h
133
+ - ext/torch/wrap_outputs.h
125
134
  - lib/torch-rb.rb
126
135
  - lib/torch.rb
127
136
  - lib/torch/hub.rb
128
137
  - lib/torch/inspector.rb
129
- - lib/torch/native/dispatcher.rb
130
- - lib/torch/native/function.rb
131
- - lib/torch/native/generator.rb
132
- - lib/torch/native/native_functions.yaml
133
- - lib/torch/native/parser.rb
134
138
  - lib/torch/nn/adaptive_avg_pool1d.rb
135
139
  - lib/torch/nn/adaptive_avg_pool2d.rb
136
140
  - lib/torch/nn/adaptive_avg_pool3d.rb
@@ -270,7 +274,7 @@ homepage: https://github.com/ankane/torch.rb
270
274
  licenses:
271
275
  - BSD-3-Clause
272
276
  metadata: {}
273
- post_install_message:
277
+ post_install_message:
274
278
  rdoc_options: []
275
279
  require_paths:
276
280
  - lib
@@ -285,8 +289,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
285
289
  - !ruby/object:Gem::Version
286
290
  version: '0'
287
291
  requirements: []
288
- rubygems_version: 3.1.2
289
- signing_key:
292
+ rubygems_version: 3.1.4
293
+ signing_key:
290
294
  specification_version: 4
291
295
  summary: Deep learning for Ruby, powered by LibTorch
292
296
  test_files: []
@@ -1,70 +0,0 @@
1
- # We use a generic interface for methods (*args, **options)
2
- # and this class to determine the C++ method to call
3
- #
4
- # This is needed since LibTorch uses function overloading,
5
- # which isn't available in Ruby or Python
6
- #
7
- # PyTorch uses this approach, but the parser/dispatcher is written in C++
8
- #
9
- # We could generate Ruby methods directly, but an advantage of this approach is
10
- # arguments and keyword arguments can be used interchangably like in Python,
11
- # making it easier to port code
12
-
13
- module Torch
14
- module Native
15
- module Dispatcher
16
- class << self
17
- def bind
18
- functions = Generator.grouped_functions
19
- bind_functions(::Torch, :define_singleton_method, functions[:torch])
20
- bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
21
- bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
22
- end
23
-
24
- def bind_functions(context, def_method, functions)
25
- instance_method = def_method == :define_method
26
- functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
27
- if instance_method
28
- funcs.map! { |f| Function.new(f.function) }
29
- funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
30
- end
31
-
32
- defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
33
- next if defined && name != "clone"
34
-
35
- # skip parser when possible for performance
36
- if funcs.size == 1 && funcs.first.args.size == 0
37
- # functions with no arguments
38
- if instance_method
39
- context.send(:alias_method, name, funcs.first.cpp_name)
40
- else
41
- context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
42
- end
43
- elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
44
- # functions that take a tensor or scalar
45
- scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
46
- context.send(def_method, name) do |other|
47
- case other
48
- when Tensor
49
- send(tensor_name, other)
50
- else
51
- send(scalar_name, other)
52
- end
53
- end
54
- else
55
- parser = Parser.new(funcs)
56
-
57
- context.send(def_method, name) do |*args, **options|
58
- result = parser.parse(args, options)
59
- raise ArgumentError, result[:error] if result[:error]
60
- send(result[:name], *result[:args])
61
- end
62
- end
63
- end
64
- end
65
- end
66
- end
67
- end
68
- end
69
-
70
- Torch::Native::Dispatcher.bind