torch-rb 0.1.5 → 0.1.6

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +1 -1
  4. data/ext/torch/ext.cpp +0 -170
  5. data/ext/torch/nn_functions.cpp +44 -24
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +48 -0
  8. data/ext/torch/tensor_functions.cpp +76 -16
  9. data/ext/torch/torch_functions.cpp +165 -65
  10. data/lib/torch.rb +51 -42
  11. data/lib/torch/ext.bundle +0 -0
  12. data/lib/torch/native/dispatcher.rb +1 -1
  13. data/lib/torch/native/function.rb +36 -5
  14. data/lib/torch/native/generator.rb +26 -7
  15. data/lib/torch/native/parser.rb +51 -14
  16. data/lib/torch/nn/avg_pool1d.rb +18 -0
  17. data/lib/torch/nn/avg_pool2d.rb +7 -2
  18. data/lib/torch/nn/avg_pool3d.rb +19 -0
  19. data/lib/torch/nn/avg_poolnd.rb +1 -1
  20. data/lib/torch/nn/batch_norm.rb +75 -0
  21. data/lib/torch/nn/batch_norm1d.rb +11 -0
  22. data/lib/torch/nn/batch_norm2d.rb +11 -0
  23. data/lib/torch/nn/batch_norm3d.rb +11 -0
  24. data/lib/torch/nn/constant_pad1d.rb +10 -0
  25. data/lib/torch/nn/constant_pad2d.rb +10 -0
  26. data/lib/torch/nn/constant_pad3d.rb +10 -0
  27. data/lib/torch/nn/constant_padnd.rb +18 -0
  28. data/lib/torch/nn/conv1d.rb +22 -0
  29. data/lib/torch/nn/conv2d.rb +9 -17
  30. data/lib/torch/nn/conv3d.rb +22 -0
  31. data/lib/torch/nn/fold.rb +20 -0
  32. data/lib/torch/nn/functional.rb +320 -100
  33. data/lib/torch/nn/group_norm.rb +36 -0
  34. data/lib/torch/nn/gru.rb +49 -0
  35. data/lib/torch/nn/hardshrink.rb +18 -0
  36. data/lib/torch/nn/instance_norm.rb +20 -0
  37. data/lib/torch/nn/instance_norm1d.rb +18 -0
  38. data/lib/torch/nn/instance_norm2d.rb +11 -0
  39. data/lib/torch/nn/instance_norm3d.rb +11 -0
  40. data/lib/torch/nn/layer_norm.rb +35 -0
  41. data/lib/torch/nn/local_response_norm.rb +21 -0
  42. data/lib/torch/nn/log_sigmoid.rb +9 -0
  43. data/lib/torch/nn/lp_pool1d.rb +9 -0
  44. data/lib/torch/nn/lp_pool2d.rb +9 -0
  45. data/lib/torch/nn/lp_poolnd.rb +22 -0
  46. data/lib/torch/nn/lstm.rb +66 -0
  47. data/lib/torch/nn/max_pool1d.rb +9 -0
  48. data/lib/torch/nn/max_pool2d.rb +1 -1
  49. data/lib/torch/nn/max_pool3d.rb +9 -0
  50. data/lib/torch/nn/max_poolnd.rb +6 -6
  51. data/lib/torch/nn/max_unpool1d.rb +16 -0
  52. data/lib/torch/nn/max_unpool2d.rb +16 -0
  53. data/lib/torch/nn/max_unpool3d.rb +16 -0
  54. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  55. data/lib/torch/nn/module.rb +7 -0
  56. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  57. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  58. data/lib/torch/nn/reflection_padnd.rb +13 -0
  59. data/lib/torch/nn/replication_pad1d.rb +10 -0
  60. data/lib/torch/nn/replication_pad2d.rb +10 -0
  61. data/lib/torch/nn/replication_pad3d.rb +10 -0
  62. data/lib/torch/nn/replication_padnd.rb +13 -0
  63. data/lib/torch/nn/rnn_base.rb +48 -4
  64. data/lib/torch/nn/softshrink.rb +18 -0
  65. data/lib/torch/nn/softsign.rb +9 -0
  66. data/lib/torch/nn/tanh.rb +9 -0
  67. data/lib/torch/nn/tanhshrink.rb +9 -0
  68. data/lib/torch/nn/unfold.rb +19 -0
  69. data/lib/torch/nn/utils.rb +25 -0
  70. data/lib/torch/nn/zero_pad2d.rb +9 -0
  71. data/lib/torch/tensor.rb +14 -25
  72. data/lib/torch/version.rb +1 -1
  73. metadata +50 -2
@@ -0,0 +1,16 @@
1
+ module Torch
2
+ module NN
3
+ class MaxUnpool3d < MaxUnpoolNd
4
+ def initialize(kernel_size, stride: nil, padding: 0)
5
+ super()
6
+ @kernel_size = _triple(kernel_size)
7
+ @stride = _triple(stride || kernel_size)
8
+ @padding = _triple(padding)
9
+ end
10
+
11
+ def forward(input, indices, output_size: nil)
12
+ F.max_unpool3d(input, indices, @kernel_size, @stride, @padding, output_size)
13
+ end
14
+ end
15
+ end
16
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class MaxUnpoolNd < Module
4
+ def extra_inspect
5
+ format("kernel_size: %s, stride: %s, padding: %s", @kernel_size, @stride, @padding)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -1,6 +1,8 @@
1
1
  module Torch
2
2
  module NN
3
3
  class Module
4
+ include Utils
5
+
4
6
  def initialize
5
7
  @training = true
6
8
  @parameters = {}
@@ -15,6 +17,7 @@ module Torch
15
17
  def register_buffer(name, tensor)
16
18
  # TODO add checks
17
19
  @buffers[name] = tensor
20
+ instance_variable_set("@#{name}", tensor)
18
21
  end
19
22
 
20
23
  def register_parameter(name, param)
@@ -216,6 +219,10 @@ module Torch
216
219
  end
217
220
  str % vars
218
221
  end
222
+
223
+ def dict
224
+ instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
225
+ end
219
226
  end
220
227
  end
221
228
  end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module NN
3
+ class ReflectionPad1d < ReflectionPadNd
4
+ def initialize(padding)
5
+ super()
6
+ @padding = _pair(padding)
7
+ end
8
+ end
9
+ end
10
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module NN
3
+ class ReflectionPad2d < ReflectionPadNd
4
+ def initialize(padding)
5
+ super()
6
+ @padding = _quadrupal(padding)
7
+ end
8
+ end
9
+ end
10
+ end
@@ -0,0 +1,13 @@
1
+ module Torch
2
+ module NN
3
+ class ReflectionPadNd < Module
4
+ def forward(input)
5
+ F.pad(input, @padding, mode: "reflect")
6
+ end
7
+
8
+ def extra_inspect
9
+ @padding.inspect
10
+ end
11
+ end
12
+ end
13
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module NN
3
+ class ReplicationPad1d < ReplicationPadNd
4
+ def initialize(padding)
5
+ super()
6
+ @padding = _pair(padding)
7
+ end
8
+ end
9
+ end
10
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module NN
3
+ class ReplicationPad2d < ReplicationPadNd
4
+ def initialize(padding)
5
+ super()
6
+ @padding = _quadrupal(padding)
7
+ end
8
+ end
9
+ end
10
+ end
@@ -0,0 +1,10 @@
1
+ module Torch
2
+ module NN
3
+ class ReplicationPad3d < ReplicationPadNd
4
+ def initialize(padding)
5
+ super()
6
+ @padding = _ntuple(6, padding)
7
+ end
8
+ end
9
+ end
10
+ end
@@ -0,0 +1,13 @@
1
+ module Torch
2
+ module NN
3
+ class ReplicationPadNd < Module
4
+ def forward(input)
5
+ F.pad(input, @padding, mode: "replicate")
6
+ end
7
+
8
+ def extra_inspect
9
+ @padding.inspect
10
+ end
11
+ end
12
+ end
13
+ end
@@ -90,12 +90,13 @@ module Torch
90
90
  end
91
91
 
92
92
  def permute_hidden(hx, permutation)
93
+ if permutation.nil?
94
+ return hx
95
+ end
93
96
  raise NotImplementedYet
94
97
  end
95
98
 
96
99
  def forward(input, hx: nil)
97
- raise NotImplementedYet
98
-
99
100
  is_packed = false # TODO isinstance(input, PackedSequence)
100
101
  if is_packed
101
102
  input, batch_sizes, sorted_indices, unsorted_indices = input
@@ -120,8 +121,8 @@ module Torch
120
121
 
121
122
  check_forward_args(input, hx, batch_sizes)
122
123
  _rnn_impls = {
123
- "RNN_TANH" => Torch.method(:_rnn_tanh),
124
- "RNN_RELU" => Torch.method(:_rnn_relu)
124
+ "RNN_TANH" => Torch.method(:rnn_tanh),
125
+ "RNN_RELU" => Torch.method(:rnn_relu)
125
126
  }
126
127
  _impl = _rnn_impls[@mode]
127
128
  if batch_sizes.nil?
@@ -149,6 +150,49 @@ module Torch
149
150
  end
150
151
  format(s, input_size: @input_size, hidden_size: @hidden_size, num_layers: @num_layers)
151
152
  end
153
+
154
+ private
155
+
156
+ def _flat_weights
157
+ @all_weights.flatten.map { |v| instance_variable_get("@#{v}") }.compact
158
+ end
159
+
160
+ def _get_flat_weights
161
+ _flat_weights
162
+ end
163
+
164
+ def check_input(input, batch_sizes)
165
+ expected_input_dim = !batch_sizes.nil? ? 2 : 3
166
+ if input.dim != expected_input_dim
167
+ raise ArgumentError, "input must have #{expected_input_dim} dimensions, got #{input.dim}"
168
+ end
169
+ if @input_size != input.size(-1)
170
+ raise ArgumentError, "input.size(-1) must be equal to input_size. Expected #{@input_size}, got #{input.size(-1)}"
171
+ end
172
+ end
173
+
174
+ def get_expected_hidden_size(input, batch_sizes)
175
+ if !batch_sizes.nil?
176
+ mini_batch = batch_sizes[0]
177
+ mini_batch = mini_batch.to_i
178
+ else
179
+ mini_batch = @batch_first ? input.size(0) : input.size(1)
180
+ end
181
+ num_directions = @bidirectional ? 2 : 1
182
+ [@num_layers * num_directions, mini_batch, @hidden_size]
183
+ end
184
+
185
+ def check_hidden_size(hx, expected_hidden_size)
186
+ if hx.size != expected_hidden_size
187
+ raise ArgumentError, "Expected hidden size #{expected_hidden_size.inspect}, got #{hx.size.inspect}"
188
+ end
189
+ end
190
+
191
+ def check_forward_args(input, hidden, batch_sizes)
192
+ check_input(input, batch_sizes)
193
+ expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
194
+ check_hidden_size(hidden, expected_hidden_size)
195
+ end
152
196
  end
153
197
  end
154
198
  end
@@ -0,0 +1,18 @@
1
+ module Torch
2
+ module NN
3
+ class Softshrink < Module
4
+ def initialize(lambd: 0.5)
5
+ super()
6
+ @lambd = lambd
7
+ end
8
+
9
+ def forward(input)
10
+ F.softshrink(input, @lambd)
11
+ end
12
+
13
+ def extra_inspect
14
+ @lambd.to_s
15
+ end
16
+ end
17
+ end
18
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class Softsign < Module
4
+ def forward(input)
5
+ F.softsign(input)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class Tanh < Module
4
+ def forward(input)
5
+ Torch.tanh(input)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class Tanhshrink < Module
4
+ def forward(input)
5
+ F.tanhshrink(input)
6
+ end
7
+ end
8
+ end
9
+ end
@@ -0,0 +1,19 @@
1
+ module Torch
2
+ module NN
3
+ class Unfold < Module
4
+ def initialize(kernel_size, dilation: 1, padding: 0, stride: 1)
5
+ super()
6
+ @kernel_size = kernel_size
7
+ @dilation = dilation
8
+ @padding = padding
9
+ @stride = stride
10
+ end
11
+
12
+ def forward(input)
13
+ F.unfold(input, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
14
+ end
15
+
16
+ # TODO add extra_inspect
17
+ end
18
+ end
19
+ end
@@ -0,0 +1,25 @@
1
+ module Torch
2
+ module NN
3
+ module Utils
4
+ def _single(value)
5
+ _ntuple(1, value)
6
+ end
7
+
8
+ def _pair(value)
9
+ _ntuple(2, value)
10
+ end
11
+
12
+ def _triple(value)
13
+ _ntuple(3, value)
14
+ end
15
+
16
+ def _quadrupal(value)
17
+ _ntuple(4, value)
18
+ end
19
+
20
+ def _ntuple(n, value)
21
+ value.is_a?(Array) ? value : [value] * n
22
+ end
23
+ end
24
+ end
25
+ end
@@ -0,0 +1,9 @@
1
+ module Torch
2
+ module NN
3
+ class ZeroPad2d < ConstantPad2d
4
+ def initialize(padding)
5
+ super(padding, 0.0)
6
+ end
7
+ end
8
+ end
9
+ end
data/lib/torch/tensor.rb CHANGED
@@ -45,8 +45,9 @@ module Torch
45
45
  dim.times.map { |i| size(i) }
46
46
  end
47
47
 
48
- def view(*size)
49
- _view(size)
48
+ # mirror Python len()
49
+ def length
50
+ size(0)
50
51
  end
51
52
 
52
53
  def item
@@ -86,38 +87,26 @@ module Torch
86
87
  _type(enum)
87
88
  end
88
89
 
89
- # start temp operations
90
+ def reshape(*size)
91
+ # Python doesn't check if size == 1, just ignores later arguments
92
+ size = size.first if size.size == 1 && size.first.is_a?(Array)
93
+ _reshape(size)
94
+ end
90
95
 
96
+ def view(*size)
97
+ size = size.first if size.size == 1 && size.first.is_a?(Array)
98
+ _view(size)
99
+ end
100
+
101
+ # value and other are swapped for some methods
91
102
  def add!(value = 1, other)
92
103
  if other.is_a?(Numeric)
93
104
  _add__scalar(other, value)
94
105
  else
95
- # need to use alpha for sparse tensors instead of multiplying
96
106
  _add__tensor(other, value)
97
107
  end
98
108
  end
99
109
 
100
- def mul!(other)
101
- if other.is_a?(Numeric)
102
- _mul__scalar(other)
103
- else
104
- _mul__tensor(other)
105
- end
106
- end
107
-
108
- # operations
109
- %w(log_softmax mean softmax sum topk).each do |op|
110
- define_method(op) do |*args, **options, &block|
111
- if options.any?
112
- Torch.send(op, self, *args, **options, &block)
113
- else
114
- Torch.send(op, self, *args, &block)
115
- end
116
- end
117
- end
118
-
119
- # end temp operations
120
-
121
110
  def +(other)
122
111
  add(other)
123
112
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.1.5"
2
+ VERSION = "0.1.6"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.5
4
+ version: 0.1.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-12-07 00:00:00.000000000 Z
11
+ date: 2019-12-10 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -108,6 +108,7 @@ files:
108
108
  - ext/torch/extconf.rb
109
109
  - ext/torch/nn_functions.cpp
110
110
  - ext/torch/nn_functions.hpp
111
+ - ext/torch/templates.cpp
111
112
  - ext/torch/templates.hpp
112
113
  - ext/torch/tensor_functions.cpp
113
114
  - ext/torch/tensor_functions.hpp
@@ -123,12 +124,24 @@ files:
123
124
  - lib/torch/native/native_functions.yaml
124
125
  - lib/torch/native/parser.rb
125
126
  - lib/torch/nn/alpha_dropout.rb
127
+ - lib/torch/nn/avg_pool1d.rb
126
128
  - lib/torch/nn/avg_pool2d.rb
129
+ - lib/torch/nn/avg_pool3d.rb
127
130
  - lib/torch/nn/avg_poolnd.rb
131
+ - lib/torch/nn/batch_norm.rb
132
+ - lib/torch/nn/batch_norm1d.rb
133
+ - lib/torch/nn/batch_norm2d.rb
134
+ - lib/torch/nn/batch_norm3d.rb
128
135
  - lib/torch/nn/bce_loss.rb
129
136
  - lib/torch/nn/bce_with_logits_loss.rb
130
137
  - lib/torch/nn/bilinear.rb
138
+ - lib/torch/nn/constant_pad1d.rb
139
+ - lib/torch/nn/constant_pad2d.rb
140
+ - lib/torch/nn/constant_pad3d.rb
141
+ - lib/torch/nn/constant_padnd.rb
142
+ - lib/torch/nn/conv1d.rb
131
143
  - lib/torch/nn/conv2d.rb
144
+ - lib/torch/nn/conv3d.rb
132
145
  - lib/torch/nn/convnd.rb
133
146
  - lib/torch/nn/cosine_embedding_loss.rb
134
147
  - lib/torch/nn/cosine_similarity.rb
@@ -141,19 +154,40 @@ files:
141
154
  - lib/torch/nn/embedding.rb
142
155
  - lib/torch/nn/embedding_bag.rb
143
156
  - lib/torch/nn/feature_alpha_dropout.rb
157
+ - lib/torch/nn/fold.rb
144
158
  - lib/torch/nn/functional.rb
159
+ - lib/torch/nn/group_norm.rb
160
+ - lib/torch/nn/gru.rb
161
+ - lib/torch/nn/hardshrink.rb
145
162
  - lib/torch/nn/hinge_embedding_loss.rb
146
163
  - lib/torch/nn/identity.rb
147
164
  - lib/torch/nn/init.rb
165
+ - lib/torch/nn/instance_norm.rb
166
+ - lib/torch/nn/instance_norm1d.rb
167
+ - lib/torch/nn/instance_norm2d.rb
168
+ - lib/torch/nn/instance_norm3d.rb
148
169
  - lib/torch/nn/kl_div_loss.rb
149
170
  - lib/torch/nn/l1_loss.rb
171
+ - lib/torch/nn/layer_norm.rb
150
172
  - lib/torch/nn/leaky_relu.rb
151
173
  - lib/torch/nn/linear.rb
174
+ - lib/torch/nn/local_response_norm.rb
175
+ - lib/torch/nn/log_sigmoid.rb
152
176
  - lib/torch/nn/log_softmax.rb
153
177
  - lib/torch/nn/loss.rb
178
+ - lib/torch/nn/lp_pool1d.rb
179
+ - lib/torch/nn/lp_pool2d.rb
180
+ - lib/torch/nn/lp_poolnd.rb
181
+ - lib/torch/nn/lstm.rb
154
182
  - lib/torch/nn/margin_ranking_loss.rb
183
+ - lib/torch/nn/max_pool1d.rb
155
184
  - lib/torch/nn/max_pool2d.rb
185
+ - lib/torch/nn/max_pool3d.rb
156
186
  - lib/torch/nn/max_poolnd.rb
187
+ - lib/torch/nn/max_unpool1d.rb
188
+ - lib/torch/nn/max_unpool2d.rb
189
+ - lib/torch/nn/max_unpool3d.rb
190
+ - lib/torch/nn/max_unpoolnd.rb
157
191
  - lib/torch/nn/module.rb
158
192
  - lib/torch/nn/mse_loss.rb
159
193
  - lib/torch/nn/multi_label_margin_loss.rb
@@ -164,7 +198,14 @@ files:
164
198
  - lib/torch/nn/parameter.rb
165
199
  - lib/torch/nn/poisson_nll_loss.rb
166
200
  - lib/torch/nn/prelu.rb
201
+ - lib/torch/nn/reflection_pad1d.rb
202
+ - lib/torch/nn/reflection_pad2d.rb
203
+ - lib/torch/nn/reflection_padnd.rb
167
204
  - lib/torch/nn/relu.rb
205
+ - lib/torch/nn/replication_pad1d.rb
206
+ - lib/torch/nn/replication_pad2d.rb
207
+ - lib/torch/nn/replication_pad3d.rb
208
+ - lib/torch/nn/replication_padnd.rb
168
209
  - lib/torch/nn/rnn.rb
169
210
  - lib/torch/nn/rnn_base.rb
170
211
  - lib/torch/nn/sequential.rb
@@ -175,8 +216,15 @@ files:
175
216
  - lib/torch/nn/softmax2d.rb
176
217
  - lib/torch/nn/softmin.rb
177
218
  - lib/torch/nn/softplus.rb
219
+ - lib/torch/nn/softshrink.rb
220
+ - lib/torch/nn/softsign.rb
221
+ - lib/torch/nn/tanh.rb
222
+ - lib/torch/nn/tanhshrink.rb
178
223
  - lib/torch/nn/triplet_margin_loss.rb
224
+ - lib/torch/nn/unfold.rb
225
+ - lib/torch/nn/utils.rb
179
226
  - lib/torch/nn/weighted_loss.rb
227
+ - lib/torch/nn/zero_pad2d.rb
180
228
  - lib/torch/optim/adadelta.rb
181
229
  - lib/torch/optim/adagrad.rb
182
230
  - lib/torch/optim/adam.rb