torch-rb 0.3.4 → 0.4.1

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -14,25 +14,11 @@ module Torch
  _normal!(tensor, mean, std)
  end

- def constant!(tensor, val)
- _constant!(tensor, val)
- end
-
- def ones!(tensor)
- _ones!(tensor)
- end
-
- def zeros!(tensor)
- _zeros!(tensor)
- end
-
- def eye!(tensor)
- _eye!(tensor)
- end
-
- def dirac!(tensor)
- _dirac!(tensor)
- end
+ alias_method :constant!, :_constant!
+ alias_method :ones!, :_ones!
+ alias_method :zeros!, :_zeros!
+ alias_method :eye!, :_eye!
+ alias_method :dirac!, :_dirac!

  def xavier_uniform!(tensor, gain: 1.0)
  _xavier_uniform!(tensor, gain)
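Note on the hunk above: the hand-written wrappers for constant!, ones!, zeros!, eye!, and dirac! are replaced by alias_method calls onto their native counterparts, so the public names are unchanged. A minimal usage sketch, assuming these helpers live on Torch::NN::Init as in earlier releases:

    w = Torch.empty(3, 3)
    Torch::NN::Init.ones!(w)          # fill with 1s in place
    Torch::NN::Init.constant!(w, 0.5) # fill with a scalar in place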
@@ -1,14 +1,14 @@
  module Torch
  module NN
  class LeakyReLU < Module
- def initialize(negative_slope: 1e-2) #, inplace: false)
+ def initialize(negative_slope: 1e-2, inplace: false)
  super()
  @negative_slope = negative_slope
- # @inplace = inplace
+ @inplace = inplace
  end

  def forward(input)
- F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
+ F.leaky_relu(input, @negative_slope, inplace: @inplace)
  end

  def extra_inspect
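Note: LeakyReLU now accepts the previously commented-out inplace: option and forwards it to F.leaky_relu. A small sketch, assuming modules are invoked with #call as elsewhere in torch-rb:

    act = Torch::NN::LeakyReLU.new(negative_slope: 0.01, inplace: true)
    x = Torch.randn(2, 3)
    y = act.call(x)   # with inplace: true the activation writes into x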
@@ -55,7 +55,15 @@ module Torch
  end
  end
  end
- # TODO apply to more objects
+
+ @buffers.each_key do |k|
+ buf = @buffers[k]
+ unless buf.nil?
+ @buffers[k] = fn.call(buf)
+ instance_variable_set("@#{k}", @buffers[k])
+ end
+ end
+
  self
  end
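Note: _apply now maps the conversion function over registered buffers as well, and refreshes the matching instance variables. A hedged sketch of what this enables, assuming register_buffer works as in PyTorch:

    class RunningStats < Torch::NN::Module
      def initialize
        super()
        register_buffer("mean", Torch.zeros(10))
      end
    end

    m = RunningStats.new
    # conversions routed through _apply now update the @mean buffer too,
    # not just the module's parameters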
@@ -0,0 +1,31 @@
+ module Torch
+ module NN
+ class Upsample < Module
+ def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
+ super()
+ @size = size
+ if scale_factor.is_a?(Array)
+ @scale_factor = scale_factor.map(&:to_f)
+ else
+ @scale_factor = scale_factor ? scale_factor.to_f : nil
+ end
+ @mode = mode
+ @align_corners = align_corners
+ end
+
+ def forward(input)
+ F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
+ end
+
+ def extra_inspect
+ if !@scale_factor.nil?
+ info = "scale_factor: #{@scale_factor.inspect}"
+ else
+ info = "size: #{@size.inspect}"
+ end
+ info += ", mode: #{@mode.inspect}"
+ info
+ end
+ end
+ end
+ end
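Usage sketch for the new Torch::NN::Upsample layer (shapes are illustrative):

    up = Torch::NN::Upsample.new(scale_factor: 2, mode: "nearest")
    x = Torch.rand(1, 3, 8, 8)
    y = up.call(x)   # spatial dims doubled, roughly [1, 3, 16, 16]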
@@ -45,7 +45,7 @@ module Torch
  square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
  std = square_avg.add(eps).sqrt!
  delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
- p.data.add!(-group[:lr], delta)
+ p.data.add!(delta, alpha: -group[:lr])
  acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
  end
  end
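Note: the in-place add! calls in the optimizers switch from the old scalar-first form to a tensor-first form with an alpha: keyword; the same pattern repeats in the optimizer hunks below. A minimal sketch of the new call shape:

    grad = Torch.ones(3)
    buf = Torch.zeros(3)
    # before: buf.add!(0.1, grad)
    buf.add!(grad, alpha: 0.1)   # buf += 0.1 * grad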
@@ -53,11 +53,11 @@ module Torch
  bias_correction2 = 1 - beta2 ** state[:step]

  if group[:weight_decay] != 0
- grad.add!(group[:weight_decay], p.data)
+ grad.add!(p.data, alpha: group[:weight_decay])
  end

  # Decay the first and second moment running average coefficient
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
  exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
  if amsgrad
  # Maintains the maximum of all 2nd moment running avg. till now
@@ -46,7 +46,7 @@ module Torch
  end

  # Update biased first moment estimate.
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
  # Update the exponentially weighted infinity norm.
  norm_buf = Torch.cat([
  exp_inf.mul!(beta2).unsqueeze(0),
@@ -58,7 +58,7 @@ module Torch
  bias_correction2 = 1 - beta2 ** state[:step]

  # Decay the first and second moment running average coefficient
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
  exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
  if amsgrad
  # Maintains the maximum of all 2nd moment running avg. till now
@@ -43,7 +43,7 @@ module Torch
  p.data.mul!(1 - group[:lambd] * state[:eta])

  # update parameter
- p.data.add!(-state[:eta], grad)
+ p.data.add!(grad, alpha: -state[:eta])

  # averaging
  if state[:mu] != 1
@@ -32,7 +32,7 @@ module Torch
  next unless p.grad
  d_p = p.grad.data
  if weight_decay != 0
- d_p.add!(weight_decay, p.data)
+ d_p.add!(p.data, alpha: weight_decay)
  end
  if momentum != 0
  param_state = @state[p]
@@ -40,7 +40,7 @@ module Torch
  buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
  else
  buf = param_state[:momentum_buffer]
- buf.mul!(momentum).add!(1 - dampening, d_p)
+ buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
  end
  if nesterov
  d_p = d_p.add(momentum, buf)
@@ -49,7 +49,7 @@ module Torch
  end
  end

- p.data.add!(-group[:lr], d_p)
+ p.data.add!(d_p, alpha: -group[:lr])
  end
  end
@@ -8,6 +8,18 @@ module Torch
  alias_method :ndim, :dim
  alias_method :ndimension, :dim

+ # use alias_method for performance
+ alias_method :+, :add
+ alias_method :-, :sub
+ alias_method :*, :mul
+ alias_method :/, :div
+ alias_method :%, :remainder
+ alias_method :**, :pow
+ alias_method :-@, :neg
+ alias_method :&, :logical_and
+ alias_method :|, :logical_or
+ alias_method :^, :logical_xor
+
  def self.new(*args)
  FloatTensor.new(*args)
  end
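Note: the arithmetic and logical operators become direct aliases of their named methods instead of small wrapper definitions; behaviour is unchanged. For example:

    a = Torch.tensor([1, 2, 3])
    b = Torch.tensor([10, 20, 30])
    a + b    # same as a.add(b)
    -a       # same as a.neg
    a ** 2   # same as a.pow(2)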
@@ -48,6 +60,11 @@ module Torch
  end

  def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+ if device.is_a?(Symbol) && !dtype
+ dtype = device
+ device = nil
+ end
+
  device ||= self.device
  device = Device.new(device) if device.is_a?(String)
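Note: Tensor#to now treats a bare Symbol argument as a dtype, so the device can be omitted. A sketch, assuming the usual dtype symbols such as :float64:

    x = Torch.tensor([1, 2, 3])
    x.to(:float64)                 # symbol is taken as a dtype, not a device
    x.to("cpu", dtype: :float32)   # explicit device string still works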
@@ -68,14 +85,18 @@ module Torch

  def size(dim = nil)
  if dim
- _size_int(dim)
+ _size(dim)
  else
  shape
  end
  end

- def shape
- dim.times.map { |i| size(i) }
+ def stride(dim = nil)
+ if dim
+ _stride(dim)
+ else
+ _strides
+ end
  end

  # mirror Python len()
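Note: size(dim) now calls the generated _size binding, and a stride method is added alongside it (all strides when no dimension is given). For a contiguous tensor:

    x = Torch.zeros(2, 3)
    x.size       # => [2, 3]
    x.size(1)    # => 3
    x.stride     # => [3, 1]
    x.stride(0)  # => 3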
@@ -103,11 +124,6 @@ module Torch
  Torch.empty(0, dtype: dtype)
  end

- def backward(gradient = nil, retain_graph: nil, create_graph: false)
- retain_graph = create_graph if retain_graph.nil?
- _backward(gradient, retain_graph, create_graph)
- end
-
  # TODO read directly from memory
  def numo
  cls = Torch._dtype_to_numo[dtype]
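Note: the Ruby-level backward wrapper is removed here; the method appears to be provided by the generated native bindings instead, so the public API stays the same:

    x = Torch.tensor([3.0], requires_grad: true)
    y = (x * x).sum
    y.backward
    x.grad   # => tensor([6.0])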
@@ -124,60 +140,14 @@ module Torch
  end

  def type(dtype)
- enum = DTYPE_TO_ENUM[dtype]
- raise Error, "Unknown type: #{dtype}" unless enum
- _type(enum)
- end
-
- def reshape(*size)
- # Python doesn't check if size == 1, just ignores later arguments
- size = size.first if size.size == 1 && size.first.is_a?(Array)
- _reshape(size)
- end
-
- def view(*size)
- size = size.first if size.size == 1 && size.first.is_a?(Array)
- _view(size)
- end
-
- def +(other)
- add(other)
- end
-
- def -(other)
- sub(other)
- end
-
- def *(other)
- mul(other)
- end
-
- def /(other)
- div(other)
- end
-
- def %(other)
- remainder(other)
- end
-
- def **(other)
- pow(other)
- end
-
- def -@
- neg
- end
-
- def &(other)
- logical_and(other)
- end
-
- def |(other)
- logical_or(other)
- end
-
- def ^(other)
- logical_xor(other)
+ if dtype.is_a?(Class)
+ raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
+ dtype.new(self)
+ else
+ enum = DTYPE_TO_ENUM[dtype]
+ raise Error, "Invalid type: #{dtype}" unless enum
+ _type(enum)
+ end
  end

  # TODO better compare?
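Note: Tensor#type still accepts dtype symbols and now also accepts a tensor class from TENSOR_TYPE_CLASSES, constructing that class from the receiver. A hedged sketch, assuming Torch::LongTensor is one of those classes:

    x = Torch.tensor([1.5, 2.5])
    x.type(:int64)             # dtype symbol, as before
    x.type(Torch::LongTensor)  # new: tensor type class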
@@ -188,7 +158,7 @@ module Torch
  # based on python_variable_indexing.cpp and
  # https://pytorch.org/cppdocs/notes/tensor_indexing.html
  def [](*indexes)
- _index(tensor_indexes(indexes))
+ _index(indexes)
  end

  # based on python_variable_indexing.cpp and
@@ -196,62 +166,13 @@ module Torch
  def []=(*indexes, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
- _index_put_custom(tensor_indexes(indexes), value)
- end
-
- # native functions that need manually defined
-
- # value and other are swapped for some methods
- def add!(value = 1, other)
- if other.is_a?(Numeric)
- _add__scalar(other, value)
- else
- _add__tensor(other, value)
- end
+ _index_put_custom(indexes, value)
  end

  # parser can't handle overlap, so need to handle manually
  def random!(*args)
- case args.size
- when 1
- _random__to(*args)
- when 2
- _random__from(*args)
- else
- _random_(*args)
- end
- end
-
- def clamp!(min, max)
- _clamp_min_(min)
- _clamp_max_(max)
- end
-
- private
-
- def tensor_indexes(indexes)
- indexes.map do |index|
- case index
- when Integer
- TensorIndex.integer(index)
- when Range
- finish = index.end
- if finish == -1 && !index.exclude_end?
- finish = nil
- else
- finish += 1 unless index.exclude_end?
- end
- TensorIndex.slice(index.begin, finish)
- when Tensor
- TensorIndex.tensor(index)
- when nil
- TensorIndex.none
- when true, false
- TensorIndex.boolean(index)
- else
- raise Error, "Unsupported index type: #{index.class.name}"
- end
- end
+ return _random!(0, *args) if args.size == 1
+ _random!(*args)
  end
  end
  end
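Note: index handling for [] and []= moves to the native argument parser, along with the old add!, clamp!, and tensor_indexes helpers; random! keeps a small Ruby shim so a single argument is treated as an exclusive upper bound. A usage sketch:

    x = Torch.arange(12).reshape(3, 4)
    x[0]           # first row
    x[0..1, 2]     # ranges and integers are now converted natively
    x.random!(10)  # fills with integers in 0...10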
@@ -45,6 +45,8 @@ module Torch
  def size
  (@dataset.size / @batch_size.to_f).ceil
  end
+ alias_method :length, :size
+ alias_method :count, :size

  private
@@ -16,6 +16,8 @@ module Torch
  def size
  @tensors[0].size(0)
  end
+ alias_method :length, :size
+ alias_method :count, :size
  end
  end
  end
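Note: both DataLoader and TensorDataset gain length and count as aliases for size, matching common Ruby collection naming. A sketch:

    x = Torch.tensor([[1, 2], [3, 4], [5, 6]])
    y = Torch.tensor([0, 1, 0])
    dataset = Torch::Utils::Data::TensorDataset.new(x, y)
    loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 2)
    dataset.length  # => 3 samples
    loader.count    # => 2 batches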
@@ -1,3 +1,3 @@
  module Torch
- VERSION = "0.3.4"
+ VERSION = "0.4.1"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
- version: 0.3.4
+ version: 0.4.1
  platform: ruby
  authors:
  - Andrew Kane
- autorequire:
+ autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-08-26 00:00:00.000000000 Z
+ date: 2020-10-13 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: rice
@@ -108,7 +108,7 @@ dependencies:
  - - ">="
  - !ruby/object:Gem::Version
  version: 0.1.1
- description:
+ description:
  email: andrew@chartkick.com
  executables: []
  extensions:
@@ -118,19 +118,23 @@ files:
  - CHANGELOG.md
  - LICENSE.txt
  - README.md
+ - codegen/function.rb
+ - codegen/generate_functions.rb
+ - codegen/native_functions.yaml
  - ext/torch/ext.cpp
  - ext/torch/extconf.rb
- - ext/torch/templates.cpp
- - ext/torch/templates.hpp
+ - ext/torch/nn_functions.h
+ - ext/torch/ruby_arg_parser.cpp
+ - ext/torch/ruby_arg_parser.h
+ - ext/torch/templates.h
+ - ext/torch/tensor_functions.h
+ - ext/torch/torch_functions.h
+ - ext/torch/utils.h
+ - ext/torch/wrap_outputs.h
  - lib/torch-rb.rb
  - lib/torch.rb
  - lib/torch/hub.rb
  - lib/torch/inspector.rb
- - lib/torch/native/dispatcher.rb
- - lib/torch/native/function.rb
- - lib/torch/native/generator.rb
- - lib/torch/native/native_functions.yaml
- - lib/torch/native/parser.rb
  - lib/torch/nn/adaptive_avg_pool1d.rb
  - lib/torch/nn/adaptive_avg_pool2d.rb
  - lib/torch/nn/adaptive_avg_pool3d.rb
@@ -238,6 +242,7 @@ files:
  - lib/torch/nn/tanhshrink.rb
  - lib/torch/nn/triplet_margin_loss.rb
  - lib/torch/nn/unfold.rb
+ - lib/torch/nn/upsample.rb
  - lib/torch/nn/utils.rb
  - lib/torch/nn/weighted_loss.rb
  - lib/torch/nn/zero_pad2d.rb
@@ -269,7 +274,7 @@ homepage: https://github.com/ankane/torch.rb
  licenses:
  - BSD-3-Clause
  metadata: {}
- post_install_message:
+ post_install_message:
  rdoc_options: []
  require_paths:
  - lib
@@ -284,8 +289,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
  - !ruby/object:Gem::Version
  version: '0'
  requirements: []
- rubygems_version: 3.1.2
- signing_key:
+ rubygems_version: 3.0.3
+ signing_key:
  specification_version: 4
  summary: Deep learning for Ruby, powered by LibTorch
  test_files: []