torch-rb 0.1.3 → 0.1.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (48) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/README.md +1 -0
  4. data/ext/torch/ext.cpp +375 -124
  5. data/lib/torch.rb +101 -20
  6. data/lib/torch/ext.bundle +0 -0
  7. data/lib/torch/inspector.rb +23 -19
  8. data/lib/torch/nn/avg_pool2d.rb +14 -0
  9. data/lib/torch/nn/avg_poolnd.rb +9 -0
  10. data/lib/torch/nn/bce_loss.rb +13 -0
  11. data/lib/torch/nn/bilinear.rb +38 -0
  12. data/lib/torch/nn/conv2d.rb +2 -2
  13. data/lib/torch/nn/convnd.rb +3 -3
  14. data/lib/torch/nn/cosine_similarity.rb +15 -0
  15. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  16. data/lib/torch/nn/ctc_loss.rb +15 -0
  17. data/lib/torch/nn/dropoutnd.rb +2 -2
  18. data/lib/torch/nn/embedding_bag.rb +34 -0
  19. data/lib/torch/nn/functional.rb +101 -13
  20. data/lib/torch/nn/identity.rb +13 -0
  21. data/lib/torch/nn/init.rb +58 -1
  22. data/lib/torch/nn/kl_div_loss.rb +13 -0
  23. data/lib/torch/nn/l1_loss.rb +13 -0
  24. data/lib/torch/nn/leaky_relu.rb +20 -0
  25. data/lib/torch/nn/linear.rb +12 -11
  26. data/lib/torch/nn/log_softmax.rb +14 -0
  27. data/lib/torch/nn/loss.rb +10 -0
  28. data/lib/torch/nn/max_pool2d.rb +9 -0
  29. data/lib/torch/nn/max_poolnd.rb +19 -0
  30. data/lib/torch/nn/module.rb +120 -31
  31. data/lib/torch/nn/mse_loss.rb +2 -2
  32. data/lib/torch/nn/nll_loss.rb +14 -0
  33. data/lib/torch/nn/pairwise_distance.rb +16 -0
  34. data/lib/torch/nn/parameter.rb +0 -4
  35. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  36. data/lib/torch/nn/prelu.rb +19 -0
  37. data/lib/torch/nn/relu.rb +8 -3
  38. data/lib/torch/nn/sequential.rb +1 -10
  39. data/lib/torch/nn/sigmoid.rb +9 -0
  40. data/lib/torch/nn/softmax.rb +18 -0
  41. data/lib/torch/nn/softmax2d.rb +10 -0
  42. data/lib/torch/nn/softmin.rb +14 -0
  43. data/lib/torch/nn/softplus.rb +19 -0
  44. data/lib/torch/nn/weighted_loss.rb +10 -0
  45. data/lib/torch/random.rb +10 -0
  46. data/lib/torch/tensor.rb +28 -10
  47. data/lib/torch/version.rb +1 -1
  48. metadata +29 -2
@@ -1,8 +1,8 @@
1
1
  module Torch
2
2
  module NN
3
- class MSELoss < Module
3
+ class MSELoss < Loss
4
4
  def initialize(reduction: "mean")
5
- @reduction = reduction
5
+ super(reduction)
6
6
  end
7
7
 
8
8
  def forward(input, target)
@@ -0,0 +1,14 @@
1
+ module Torch
2
+ module NN
3
+ class NLLLoss < WeightedLoss
4
+ def initialize(weight: nil, ignore_index: -100, reduction: "mean")
5
+ super(weight, reduction)
6
+ @ignore_index = ignore_index
7
+ end
8
+
9
+ def forward(input, target)
10
+ F.nll_loss(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
11
+ end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,16 @@
1
+ module Torch
2
+ module NN
3
+ class PairwiseDistance < Module
4
+ def initialize(p: 2.0, eps: 1e-6, keepdim: false)
5
+ super()
6
+ @norm = p
7
+ @eps = eps
8
+ @keepdim = keepdim
9
+ end
10
+
11
+ def forward(x1, x2)
12
+ F.pairwise_distance(x1, x2, p: @norm, eps: @eps, keepdim: @keepdim)
13
+ end
14
+ end
15
+ end
16
+ end
@@ -5,10 +5,6 @@ module Torch
5
5
  data = Tensor.new unless data
6
6
  Tensor._make_subclass(data, requires_grad)
7
7
  end
8
-
9
- def grad
10
- _grad if _grad_defined
11
- end
12
8
  end
13
9
  end
14
10
  end
@@ -0,0 +1,16 @@
1
+ module Torch
2
+ module NN
3
+ class PoissonNLLLoss < Loss
4
+ def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
5
+ super(reduction)
6
+ @log_input = log_input
7
+ @full = full
8
+ @eps = eps
9
+ end
10
+
11
+ def forward(log_input, target)
12
+ F.poisson_nll_loss(log_input, target, log_input: @log_input, full: @full, eps: @eps, reduction: @reduction)
13
+ end
14
+ end
15
+ end
16
+ end
@@ -0,0 +1,19 @@
1
+ module Torch
2
+ module NN
3
+ class PReLU < Module
4
+ def initialize(num_parameters: 1, init: 0.25)
5
+ @num_parameters = num_parameters
6
+ super()
7
+ @weight = Parameter.new(Tensor.new(num_parameters).fill!(init))
8
+ end
9
+
10
+ def forward(input)
11
+ F.prelu(input, @weight)
12
+ end
13
+
14
+ def extra_inspect
15
+ format("num_parameters: %s", @num_parameters)
16
+ end
17
+ end
18
+ end
19
+ end
data/lib/torch/nn/relu.rb CHANGED
@@ -1,12 +1,17 @@
1
1
  module Torch
2
2
  module NN
3
3
  class ReLU < Module
4
- def initialize #(inplace: false)
5
- # @inplace = inplace
4
+ def initialize(inplace: false)
5
+ super()
6
+ @inplace = inplace
6
7
  end
7
8
 
8
9
  def forward(input)
9
- F.relu(input) #, inplace: @inplace)
10
+ F.relu(input, inplace: @inplace)
11
+ end
12
+
13
+ def extra_inspect
14
+ @inplace ? "inplace: true" : ""
10
15
  end
11
16
  end
12
17
  end
@@ -2,28 +2,19 @@ module Torch
2
2
  module NN
3
3
  class Sequential < Module
4
4
  def initialize(*args)
5
- @modules = {}
5
+ super()
6
6
  # TODO support hash arg (named modules)
7
7
  args.each_with_index do |mod, idx|
8
8
  add_module(idx.to_s, mod)
9
9
  end
10
10
  end
11
11
 
12
- def add_module(name, mod)
13
- # TODO add checks
14
- @modules[name] = mod
15
- end
16
-
17
12
  def forward(input)
18
13
  @modules.values.each do |mod|
19
14
  input = mod.call(input)
20
15
  end
21
16
  input
22
17
  end
23
-
24
- def parameters
25
- @modules.flat_map { |_, mod| mod.parameters }
26
- end
27
18
  end
28
19
  end
29
20
  end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class Sigmoid < Module
4
+ def forward(input)
5
+ Torch.sigmoid(input)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,18 @@
1
+ module Torch
2
+ module NN
3
+ class Softmax < Module
4
+ def initialize(dim: nil)
5
+ super()
6
+ @dim = dim
7
+ end
8
+
9
+ def forward(input)
10
+ F.softmax(input, dim: @dim)
11
+ end
12
+
13
+ def extra_inspect
14
+ format("dim: %s", @dim)
15
+ end
16
+ end
17
+ end
18
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module NN
3
+ class Softmax2d < Module
4
+ def forward(input)
5
+ raise ArgumentError, "Softmax2d requires a 4D tensor as input" unless input.dim == 4
6
+ F.softmax(input, dim: 1)
7
+ end
8
+ end
9
+ end
10
+ end
@@ -0,0 +1,14 @@
1
+ module Torch
2
+ module NN
3
+ class Softmin < Module
4
+ def initialize(dim: nil)
5
+ super()
6
+ @dim = dim
7
+ end
8
+
9
+ def forward(input)
10
+ F.softmin(input, dim: @dim)
11
+ end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,19 @@
1
+ module Torch
2
+ module NN
3
+ class Softplus < Module
4
+ def initialize(beta: 1, threshold: 20)
5
+ super()
6
+ @beta = beta
7
+ @threshold = threshold
8
+ end
9
+
10
+ def forward(input)
11
+ F.softplus(input, beta: @beta, threshold: @threshold)
12
+ end
13
+
14
+ def extra_inspect
15
+ format("beta: %s, threshold: %s", @beta, @threshold)
16
+ end
17
+ end
18
+ end
19
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module NN
3
+ class WeightedLoss < Loss
4
+ def initialize(weight, reduction)
5
+ super(reduction)
6
+ register_buffer("weight", weight)
7
+ end
8
+ end
9
+ end
10
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module Random
3
+ class << self
4
+ # not available through LibTorch
5
+ def initial_seed
6
+ raise NotImplementedYet
7
+ end
8
+ end
9
+ end
10
+ end
data/lib/torch/tensor.rb CHANGED
@@ -66,16 +66,11 @@ module Torch
66
66
  end
67
67
 
68
68
  def backward(gradient = nil)
69
- if gradient
70
- _backward_gradient(gradient)
71
- else
72
- _backward
73
- end
69
+ _backward(gradient)
74
70
  end
75
71
 
76
72
  # TODO read directly from memory
77
73
  def numo
78
- raise Error, "Numo not found" unless defined?(Numo::NArray)
79
74
  cls = Torch._dtype_to_numo[dtype]
80
75
  raise Error, "Cannot convert #{dtype} to Numo" unless cls
81
76
  cls.cast(_data).reshape(*shape)
@@ -113,7 +108,7 @@ module Torch
113
108
  end
114
109
 
115
110
  # operations
116
- %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
111
+ %w(abs add argmax div dot eq exp gt log log_softmax lt matmul max mean min mul neg norm num numel pow relu remainder reshape sign softmax sqrt sub sum unsqueeze topk).each do |op|
117
112
  define_method(op) do |*args, **options, &block|
118
113
  if options.any?
119
114
  Torch.send(op, self, *args, **options, &block)
@@ -167,8 +162,14 @@ module Torch
167
162
  finish += 1 unless index.exclude_end?
168
163
  result = result._slice(dim, index.begin, finish, 1)
169
164
  dim += 1
165
+ elsif index.nil?
166
+ result = result.unsqueeze(dim)
167
+ dim += 1
168
+ elsif index == true
169
+ result = result.unsqueeze(dim)
170
+ # TODO handle false
170
171
  else
171
- raise Error, "Unsupported index type"
172
+ raise Error, "Unsupported index type: #{index.class.name}"
172
173
  end
173
174
  end
174
175
  result
@@ -176,11 +177,28 @@ module Torch
176
177
 
177
178
  # TODO
178
179
  # based on python_variable_indexing.cpp
179
- # def []=(index, value)
180
- # end
180
+ def []=(index, value)
181
+ raise ArgumentError, "Tensor does not support deleting items" if value.nil?
182
+
183
+ value = Torch.tensor(value) unless value.is_a?(Tensor)
184
+
185
+ if index.is_a?(Numeric)
186
+ copy_to(_select(0, index), value)
187
+ elsif index.is_a?(Range)
188
+ finish = index.end
189
+ finish += 1 unless index.exclude_end?
190
+ copy_to(_slice(0, index.begin, finish, 1), value)
191
+ else
192
+ raise Error, "Unsupported index type: #{index.class.name}"
193
+ end
194
+ end
181
195
 
182
196
  private
183
197
 
198
+ def copy_to(dst, src)
199
+ dst.copy!(src)
200
+ end
201
+
184
202
  def reshape_arr(arr, dims)
185
203
  if dims.empty?
186
204
  arr
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.1.3"
2
+ VERSION = "0.1.4"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.3
4
+ version: 0.1.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-11-30 00:00:00.000000000 Z
11
+ date: 2019-12-01 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -111,22 +111,48 @@ files:
111
111
  - lib/torch/ext.bundle
112
112
  - lib/torch/inspector.rb
113
113
  - lib/torch/nn/alpha_dropout.rb
114
+ - lib/torch/nn/avg_pool2d.rb
115
+ - lib/torch/nn/avg_poolnd.rb
116
+ - lib/torch/nn/bce_loss.rb
117
+ - lib/torch/nn/bilinear.rb
114
118
  - lib/torch/nn/conv2d.rb
115
119
  - lib/torch/nn/convnd.rb
120
+ - lib/torch/nn/cosine_similarity.rb
121
+ - lib/torch/nn/cross_entropy_loss.rb
122
+ - lib/torch/nn/ctc_loss.rb
116
123
  - lib/torch/nn/dropout.rb
117
124
  - lib/torch/nn/dropout2d.rb
118
125
  - lib/torch/nn/dropout3d.rb
119
126
  - lib/torch/nn/dropoutnd.rb
120
127
  - lib/torch/nn/embedding.rb
128
+ - lib/torch/nn/embedding_bag.rb
121
129
  - lib/torch/nn/feature_alpha_dropout.rb
122
130
  - lib/torch/nn/functional.rb
131
+ - lib/torch/nn/identity.rb
123
132
  - lib/torch/nn/init.rb
133
+ - lib/torch/nn/kl_div_loss.rb
134
+ - lib/torch/nn/l1_loss.rb
135
+ - lib/torch/nn/leaky_relu.rb
124
136
  - lib/torch/nn/linear.rb
137
+ - lib/torch/nn/log_softmax.rb
138
+ - lib/torch/nn/loss.rb
139
+ - lib/torch/nn/max_pool2d.rb
140
+ - lib/torch/nn/max_poolnd.rb
125
141
  - lib/torch/nn/module.rb
126
142
  - lib/torch/nn/mse_loss.rb
143
+ - lib/torch/nn/nll_loss.rb
144
+ - lib/torch/nn/pairwise_distance.rb
127
145
  - lib/torch/nn/parameter.rb
146
+ - lib/torch/nn/poisson_nll_loss.rb
147
+ - lib/torch/nn/prelu.rb
128
148
  - lib/torch/nn/relu.rb
129
149
  - lib/torch/nn/sequential.rb
150
+ - lib/torch/nn/sigmoid.rb
151
+ - lib/torch/nn/softmax.rb
152
+ - lib/torch/nn/softmax2d.rb
153
+ - lib/torch/nn/softmin.rb
154
+ - lib/torch/nn/softplus.rb
155
+ - lib/torch/nn/weighted_loss.rb
130
156
  - lib/torch/optim/adadelta.rb
131
157
  - lib/torch/optim/adagrad.rb
132
158
  - lib/torch/optim/adam.rb
@@ -139,6 +165,7 @@ files:
139
165
  - lib/torch/optim/rmsprop.rb
140
166
  - lib/torch/optim/rprop.rb
141
167
  - lib/torch/optim/sgd.rb
168
+ - lib/torch/random.rb
142
169
  - lib/torch/tensor.rb
143
170
  - lib/torch/utils/data/data_loader.rb
144
171
  - lib/torch/utils/data/tensor_dataset.rb