torch-rb 0.1.3 → 0.1.4

Files changed (48)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/README.md +1 -0
  4. data/ext/torch/ext.cpp +375 -124
  5. data/lib/torch.rb +101 -20
  6. data/lib/torch/ext.bundle +0 -0
  7. data/lib/torch/inspector.rb +23 -19
  8. data/lib/torch/nn/avg_pool2d.rb +14 -0
  9. data/lib/torch/nn/avg_poolnd.rb +9 -0
  10. data/lib/torch/nn/bce_loss.rb +13 -0
  11. data/lib/torch/nn/bilinear.rb +38 -0
  12. data/lib/torch/nn/conv2d.rb +2 -2
  13. data/lib/torch/nn/convnd.rb +3 -3
  14. data/lib/torch/nn/cosine_similarity.rb +15 -0
  15. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  16. data/lib/torch/nn/ctc_loss.rb +15 -0
  17. data/lib/torch/nn/dropoutnd.rb +2 -2
  18. data/lib/torch/nn/embedding_bag.rb +34 -0
  19. data/lib/torch/nn/functional.rb +101 -13
  20. data/lib/torch/nn/identity.rb +13 -0
  21. data/lib/torch/nn/init.rb +58 -1
  22. data/lib/torch/nn/kl_div_loss.rb +13 -0
  23. data/lib/torch/nn/l1_loss.rb +13 -0
  24. data/lib/torch/nn/leaky_relu.rb +20 -0
  25. data/lib/torch/nn/linear.rb +12 -11
  26. data/lib/torch/nn/log_softmax.rb +14 -0
  27. data/lib/torch/nn/loss.rb +10 -0
  28. data/lib/torch/nn/max_pool2d.rb +9 -0
  29. data/lib/torch/nn/max_poolnd.rb +19 -0
  30. data/lib/torch/nn/module.rb +120 -31
  31. data/lib/torch/nn/mse_loss.rb +2 -2
  32. data/lib/torch/nn/nll_loss.rb +14 -0
  33. data/lib/torch/nn/pairwise_distance.rb +16 -0
  34. data/lib/torch/nn/parameter.rb +0 -4
  35. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  36. data/lib/torch/nn/prelu.rb +19 -0
  37. data/lib/torch/nn/relu.rb +8 -3
  38. data/lib/torch/nn/sequential.rb +1 -10
  39. data/lib/torch/nn/sigmoid.rb +9 -0
  40. data/lib/torch/nn/softmax.rb +18 -0
  41. data/lib/torch/nn/softmax2d.rb +10 -0
  42. data/lib/torch/nn/softmin.rb +14 -0
  43. data/lib/torch/nn/softplus.rb +19 -0
  44. data/lib/torch/nn/weighted_loss.rb +10 -0
  45. data/lib/torch/random.rb +10 -0
  46. data/lib/torch/tensor.rb +28 -10
  47. data/lib/torch/version.rb +1 -1
  48. metadata +29 -2
data/lib/torch/nn/mse_loss.rb CHANGED
@@ -1,8 +1,8 @@
 module Torch
   module NN
-    class MSELoss < Module
+    class MSELoss < Loss
       def initialize(reduction: "mean")
-        @reduction = reduction
+        super(reduction)
       end
 
       def forward(input, target)
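With the new Loss base class handling reduction, usage is unchanged. A minimal sketch (values illustrative, assuming the 0.1.4 API):

```ruby
require "torch"

input  = Torch.tensor([1.0, 2.0, 3.0])
target = Torch.tensor([1.5, 2.0, 2.5])

loss_fn = Torch::NN::MSELoss.new(reduction: "mean")
puts loss_fn.call(input, target) # scalar tensor: mean of squared differences
```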
data/lib/torch/nn/nll_loss.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class NLLLoss < WeightedLoss
+      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
+        super(weight, reduction)
+        @ignore_index = ignore_index
+      end
+
+      def forward(input, target)
+        F.nll_loss(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
+      end
+    end
+  end
+end
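NLLLoss expects log-probabilities as input. A hedged sketch (values illustrative; each row sums to 1 before taking the log):

```ruby
input  = Torch.tensor([[0.1, 0.2, 0.7], [0.8, 0.1, 0.1]]).log
target = Torch.tensor([2, 0]) # one class index per sample

loss_fn = Torch::NN::NLLLoss.new
puts loss_fn.call(input, target)
```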
data/lib/torch/nn/pairwise_distance.rb ADDED
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class PairwiseDistance < Module
+      def initialize(p: 2.0, eps: 1e-6, keepdim: false)
+        super()
+        @norm = p
+        @eps = eps
+        @keepdim = keepdim
+      end
+
+      def forward(x1, x2)
+        F.pairwise_distance(x1, x2, p: @norm, eps: @eps, keepdim: @keepdim)
+      end
+    end
+  end
+end
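A minimal sketch of the new module (values chosen so the distances come out essentially exact):

```ruby
x1 = Torch.tensor([[1.0, 2.0], [3.0, 4.0]])
x2 = Torch.tensor([[1.0, 0.0], [0.0, 4.0]])

pdist = Torch::NN::PairwiseDistance.new(p: 2.0)
puts pdist.call(x1, x2) # one Euclidean distance per row: ~[2.0, 3.0]
```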
data/lib/torch/nn/parameter.rb CHANGED
@@ -5,10 +5,6 @@ module Torch
         data = Tensor.new unless data
         Tensor._make_subclass(data, requires_grad)
       end
-
-      def grad
-        _grad if _grad_defined
-      end
     end
   end
 end
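The Ruby-side grad wrapper is removed; gradient access presumably now lives in the rewritten native extension (ext.cpp changed +375/-124 in this release). The familiar pattern, sketched on the assumption that Tensor#grad remains available:

```ruby
w = Torch::NN::Parameter.new(Torch.tensor([1.0, 2.0]))
loss = w.dot(w)  # scalar: 1*1 + 2*2 = 5
loss.backward
puts w.grad      # d(w.w)/dw = 2w => [2.0, 4.0]
```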
data/lib/torch/nn/poisson_nll_loss.rb ADDED
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class PoissonNLLLoss < Loss
+      def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
+        super(reduction)
+        @log_input = log_input
+        @full = full
+        @eps = eps
+      end
+
+      def forward(log_input, target)
+        F.poisson_nll_loss(log_input, target, log_input: @log_input, full: @full, eps: @eps, reduction: @reduction)
+      end
+    end
+  end
+end
data/lib/torch/nn/prelu.rb ADDED
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class PReLU < Module
+      def initialize(num_parameters: 1, init: 0.25)
+        @num_parameters = num_parameters
+        super()
+        @weight = Parameter.new(Tensor.new(num_parameters).fill!(init))
+      end
+
+      def forward(input)
+        F.prelu(input, @weight)
+      end
+
+      def extra_inspect
+        format("num_parameters: %s", @num_parameters)
+      end
+    end
+  end
+end
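Unlike ReLU, the negative-side slope here is a learnable Parameter. A minimal sketch with the default init of 0.25:

```ruby
act = Torch::NN::PReLU.new
x = Torch.tensor([-2.0, 0.0, 3.0])
puts act.call(x) # => [-0.5, 0.0, 3.0] with the initial 0.25 slope
```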
data/lib/torch/nn/relu.rb CHANGED
@@ -1,12 +1,17 @@
 module Torch
   module NN
     class ReLU < Module
-      def initialize #(inplace: false)
-        # @inplace = inplace
+      def initialize(inplace: false)
+        super()
+        @inplace = inplace
       end
 
       def forward(input)
-        F.relu(input) #, inplace: @inplace)
+        F.relu(input, inplace: @inplace)
+      end
+
+      def extra_inspect
+        @inplace ? "inplace: true" : ""
       end
     end
   end
 end
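With inplace: true the input tensor itself is clamped, mirroring PyTorch's ReLU. Sketch (assumes F.relu honors the new inplace option, as the change above indicates):

```ruby
x = Torch.tensor([-1.0, 2.0, -3.0])
relu = Torch::NN::ReLU.new(inplace: true)
relu.call(x)
puts x # => [0.0, 2.0, 0.0]; x was modified in place
```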
data/lib/torch/nn/sequential.rb CHANGED
@@ -2,28 +2,19 @@ module Torch
   module NN
     class Sequential < Module
       def initialize(*args)
-        @modules = {}
+        super()
         # TODO support hash arg (named modules)
         args.each_with_index do |mod, idx|
           add_module(idx.to_s, mod)
         end
       end
 
-      def add_module(name, mod)
-        # TODO add checks
-        @modules[name] = mod
-      end
-
      def forward(input)
         @modules.values.each do |mod|
           input = mod.call(input)
         end
         input
       end
-
-      def parameters
-        @modules.flat_map { |_, mod| mod.parameters }
-      end
     end
   end
 end
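add_module and parameters are no longer reimplemented here; they presumably come from the expanded Module base class (module.rb changed +120/-31 in this release). Building a model reads as before (Linear's positional in/out arguments assumed):

```ruby
model = Torch::NN::Sequential.new(
  Torch::NN::Linear.new(4, 8),
  Torch::NN::ReLU.new,
  Torch::NN::Linear.new(8, 2)
)
puts model.parameters.size # => 4: weight and bias per Linear, collected by Module
```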
data/lib/torch/nn/sigmoid.rb ADDED
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class Sigmoid < Module
+      def forward(input)
+        Torch.sigmoid(input)
+      end
+    end
+  end
+end
data/lib/torch/nn/softmax.rb ADDED
@@ -0,0 +1,18 @@
+module Torch
+  module NN
+    class Softmax < Module
+      def initialize(dim: nil)
+        super()
+        @dim = dim
+      end
+
+      def forward(input)
+        F.softmax(input, dim: @dim)
+      end
+
+      def extra_inspect
+        format("dim: %s", @dim)
+      end
+    end
+  end
+end
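A minimal sketch; dim selects the axis whose entries are normalized to sum to 1:

```ruby
softmax = Torch::NN::Softmax.new(dim: 1)
x = Torch.tensor([[1.0, 2.0, 3.0]])
puts softmax.call(x) # each row sums to 1
```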
data/lib/torch/nn/softmax2d.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class Softmax2d < Module
+      def forward(input)
+        raise ArgumentError, "Softmax2d requires a 4D tensor as input" unless input.dim == 4
+        F.softmax(input, dim: 1)
+      end
+    end
+  end
+end
data/lib/torch/nn/softmin.rb ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class Softmin < Module
+      def initialize(dim: nil)
+        super()
+        @dim = dim
+      end
+
+      def forward(input)
+        F.softmin(input, dim: @dim)
+      end
+    end
+  end
+end
data/lib/torch/nn/softplus.rb ADDED
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class Softplus < Module
+      def initialize(beta: 1, threshold: 20)
+        super()
+        @beta = beta
+        @threshold = threshold
+      end
+
+      def forward(input)
+        F.softplus(input, beta: @beta, threshold: @threshold)
+      end
+
+      def extra_inspect
+        format("beta: %s, threshold: %s", @beta, @threshold)
+      end
+    end
+  end
+end
data/lib/torch/nn/weighted_loss.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class WeightedLoss < Loss
+      def initialize(weight, reduction)
+        super(reduction)
+        register_buffer("weight", weight)
+      end
+    end
+  end
+end
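WeightedLoss stores the class weights via register_buffer, so subclasses like NLLLoss and BCELoss can read @weight in forward. Per-class weighting, sketched:

```ruby
# upweight the rare class 2; one weight per class
weights = Torch.tensor([1.0, 1.0, 5.0])
loss_fn = Torch::NN::NLLLoss.new(weight: weights)
```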
data/lib/torch/random.rb ADDED
@@ -0,0 +1,10 @@
+module Torch
+  module Random
+    class << self
+      # not available through LibTorch
+      def initial_seed
+        raise NotImplementedYet
+      end
+    end
+  end
+end
data/lib/torch/tensor.rb CHANGED
@@ -66,16 +66,11 @@ module Torch
     end
 
     def backward(gradient = nil)
-      if gradient
-        _backward_gradient(gradient)
-      else
-        _backward
-      end
+      _backward(gradient)
     end
 
     # TODO read directly from memory
     def numo
-      raise Error, "Numo not found" unless defined?(Numo::NArray)
       cls = Torch._dtype_to_numo[dtype]
       raise Error, "Cannot convert #{dtype} to Numo" unless cls
       cls.cast(_data).reshape(*shape)
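backward now hands the optional gradient straight to the native _backward, so non-scalar outputs can seed backpropagation. Sketch (assumes Torch.tensor accepts requires_grad: here):

```ruby
x = Torch.tensor([1.0, 2.0, 3.0], requires_grad: true)
y = x.pow(2)                               # non-scalar output
y.backward(Torch.tensor([1.0, 1.0, 1.0]))  # seed gradient, one entry per element
puts x.grad                                # => [2.0, 4.0, 6.0]
```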
@@ -113,7 +108,7 @@
     end
 
     # operations
-    %w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
+    %w(abs add argmax div dot eq exp gt log log_softmax lt matmul max mean min mul neg norm num numel pow relu remainder reshape sign softmax sqrt sub sum unsqueeze topk).each do |op|
       define_method(op) do |*args, **options, &block|
         if options.any?
           Torch.send(op, self, *args, **options, &block)
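log_softmax, relu, softmax, and topk now read as instance methods via the delegation above. A hedged sketch (positional argument style assumed):

```ruby
x = Torch.tensor([[1.0, 3.0, 2.0]])
puts x.relu    # unchanged here: all entries already non-negative
puts x.topk(2) # values and indices of the two largest entries
```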
@@ -167,8 +162,14 @@
           finish += 1 unless index.exclude_end?
           result = result._slice(dim, index.begin, finish, 1)
           dim += 1
+        elsif index.nil?
+          result = result.unsqueeze(dim)
+          dim += 1
+        elsif index == true
+          result = result.unsqueeze(dim)
+          # TODO handle false
         else
-          raise Error, "Unsupported index type"
+          raise Error, "Unsupported index type: #{index.class.name}"
         end
       end
       result
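nil (and true) in an index now inserts a new axis, like NumPy's newaxis, via the unsqueeze branch above:

```ruby
x = Torch.tensor([1.0, 2.0, 3.0])
p x.shape      # => [3]
p x[nil].shape # => [1, 3]
```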
@@ -176,11 +177,28 @@
 
     # TODO
     # based on python_variable_indexing.cpp
-    # def []=(index, value)
-    # end
+    def []=(index, value)
+      raise ArgumentError, "Tensor does not support deleting items" if value.nil?
+
+      value = Torch.tensor(value) unless value.is_a?(Tensor)
+
+      if index.is_a?(Numeric)
+        copy_to(_select(0, index), value)
+      elsif index.is_a?(Range)
+        finish = index.end
+        finish += 1 unless index.exclude_end?
+        copy_to(_slice(0, index.begin, finish, 1), value)
+      else
+        raise Error, "Unsupported index type: #{index.class.name}"
+      end
+    end
 
     private
 
+    def copy_to(dst, src)
+      dst.copy!(src)
+    end
+
     def reshape_arr(arr, dims)
       if dims.empty?
         arr
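Element and range assignment land on Tensor; plain Ruby values are wrapped with Torch.tensor automatically. Grounded in the code above:

```ruby
x = Torch.tensor([1.0, 2.0, 3.0, 4.0])
x[0] = 9      # single index goes through _select + copy!
x[1..2] = 0.5 # inclusive range goes through _slice + copy!
puts x        # => [9.0, 0.5, 0.5, 4.0]
```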
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.1.3"
+  VERSION = "0.1.4"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.1.3
+  version: 0.1.4
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2019-11-30 00:00:00.000000000 Z
+date: 2019-12-01 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -111,22 +111,48 @@ files:
 - lib/torch/ext.bundle
 - lib/torch/inspector.rb
 - lib/torch/nn/alpha_dropout.rb
+- lib/torch/nn/avg_pool2d.rb
+- lib/torch/nn/avg_poolnd.rb
+- lib/torch/nn/bce_loss.rb
+- lib/torch/nn/bilinear.rb
 - lib/torch/nn/conv2d.rb
 - lib/torch/nn/convnd.rb
+- lib/torch/nn/cosine_similarity.rb
+- lib/torch/nn/cross_entropy_loss.rb
+- lib/torch/nn/ctc_loss.rb
 - lib/torch/nn/dropout.rb
 - lib/torch/nn/dropout2d.rb
 - lib/torch/nn/dropout3d.rb
 - lib/torch/nn/dropoutnd.rb
 - lib/torch/nn/embedding.rb
+- lib/torch/nn/embedding_bag.rb
 - lib/torch/nn/feature_alpha_dropout.rb
 - lib/torch/nn/functional.rb
+- lib/torch/nn/identity.rb
 - lib/torch/nn/init.rb
+- lib/torch/nn/kl_div_loss.rb
+- lib/torch/nn/l1_loss.rb
+- lib/torch/nn/leaky_relu.rb
 - lib/torch/nn/linear.rb
+- lib/torch/nn/log_softmax.rb
+- lib/torch/nn/loss.rb
+- lib/torch/nn/max_pool2d.rb
+- lib/torch/nn/max_poolnd.rb
 - lib/torch/nn/module.rb
 - lib/torch/nn/mse_loss.rb
+- lib/torch/nn/nll_loss.rb
+- lib/torch/nn/pairwise_distance.rb
 - lib/torch/nn/parameter.rb
+- lib/torch/nn/poisson_nll_loss.rb
+- lib/torch/nn/prelu.rb
 - lib/torch/nn/relu.rb
 - lib/torch/nn/sequential.rb
+- lib/torch/nn/sigmoid.rb
+- lib/torch/nn/softmax.rb
+- lib/torch/nn/softmax2d.rb
+- lib/torch/nn/softmin.rb
+- lib/torch/nn/softplus.rb
+- lib/torch/nn/weighted_loss.rb
 - lib/torch/optim/adadelta.rb
 - lib/torch/optim/adagrad.rb
 - lib/torch/optim/adam.rb
@@ -139,6 +165,7 @@ files:
 - lib/torch/optim/rmsprop.rb
 - lib/torch/optim/rprop.rb
 - lib/torch/optim/sgd.rb
+- lib/torch/random.rb
 - lib/torch/tensor.rb
 - lib/torch/utils/data/data_loader.rb
 - lib/torch/utils/data/tensor_dataset.rb