torch-rb 0.1.5 → 0.1.6

Files changed (73)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +1 -1
  4. data/ext/torch/ext.cpp +0 -170
  5. data/ext/torch/nn_functions.cpp +44 -24
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +48 -0
  8. data/ext/torch/tensor_functions.cpp +76 -16
  9. data/ext/torch/torch_functions.cpp +165 -65
  10. data/lib/torch.rb +51 -42
  11. data/lib/torch/ext.bundle +0 -0
  12. data/lib/torch/native/dispatcher.rb +1 -1
  13. data/lib/torch/native/function.rb +36 -5
  14. data/lib/torch/native/generator.rb +26 -7
  15. data/lib/torch/native/parser.rb +51 -14
  16. data/lib/torch/nn/avg_pool1d.rb +18 -0
  17. data/lib/torch/nn/avg_pool2d.rb +7 -2
  18. data/lib/torch/nn/avg_pool3d.rb +19 -0
  19. data/lib/torch/nn/avg_poolnd.rb +1 -1
  20. data/lib/torch/nn/batch_norm.rb +75 -0
  21. data/lib/torch/nn/batch_norm1d.rb +11 -0
  22. data/lib/torch/nn/batch_norm2d.rb +11 -0
  23. data/lib/torch/nn/batch_norm3d.rb +11 -0
  24. data/lib/torch/nn/constant_pad1d.rb +10 -0
  25. data/lib/torch/nn/constant_pad2d.rb +10 -0
  26. data/lib/torch/nn/constant_pad3d.rb +10 -0
  27. data/lib/torch/nn/constant_padnd.rb +18 -0
  28. data/lib/torch/nn/conv1d.rb +22 -0
  29. data/lib/torch/nn/conv2d.rb +9 -17
  30. data/lib/torch/nn/conv3d.rb +22 -0
  31. data/lib/torch/nn/fold.rb +20 -0
  32. data/lib/torch/nn/functional.rb +320 -100
  33. data/lib/torch/nn/group_norm.rb +36 -0
  34. data/lib/torch/nn/gru.rb +49 -0
  35. data/lib/torch/nn/hardshrink.rb +18 -0
  36. data/lib/torch/nn/instance_norm.rb +20 -0
  37. data/lib/torch/nn/instance_norm1d.rb +18 -0
  38. data/lib/torch/nn/instance_norm2d.rb +11 -0
  39. data/lib/torch/nn/instance_norm3d.rb +11 -0
  40. data/lib/torch/nn/layer_norm.rb +35 -0
  41. data/lib/torch/nn/local_response_norm.rb +21 -0
  42. data/lib/torch/nn/log_sigmoid.rb +9 -0
  43. data/lib/torch/nn/lp_pool1d.rb +9 -0
  44. data/lib/torch/nn/lp_pool2d.rb +9 -0
  45. data/lib/torch/nn/lp_poolnd.rb +22 -0
  46. data/lib/torch/nn/lstm.rb +66 -0
  47. data/lib/torch/nn/max_pool1d.rb +9 -0
  48. data/lib/torch/nn/max_pool2d.rb +1 -1
  49. data/lib/torch/nn/max_pool3d.rb +9 -0
  50. data/lib/torch/nn/max_poolnd.rb +6 -6
  51. data/lib/torch/nn/max_unpool1d.rb +16 -0
  52. data/lib/torch/nn/max_unpool2d.rb +16 -0
  53. data/lib/torch/nn/max_unpool3d.rb +16 -0
  54. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  55. data/lib/torch/nn/module.rb +7 -0
  56. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  57. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  58. data/lib/torch/nn/reflection_padnd.rb +13 -0
  59. data/lib/torch/nn/replication_pad1d.rb +10 -0
  60. data/lib/torch/nn/replication_pad2d.rb +10 -0
  61. data/lib/torch/nn/replication_pad3d.rb +10 -0
  62. data/lib/torch/nn/replication_padnd.rb +13 -0
  63. data/lib/torch/nn/rnn_base.rb +48 -4
  64. data/lib/torch/nn/softshrink.rb +18 -0
  65. data/lib/torch/nn/softsign.rb +9 -0
  66. data/lib/torch/nn/tanh.rb +9 -0
  67. data/lib/torch/nn/tanhshrink.rb +9 -0
  68. data/lib/torch/nn/unfold.rb +19 -0
  69. data/lib/torch/nn/utils.rb +25 -0
  70. data/lib/torch/nn/zero_pad2d.rb +9 -0
  71. data/lib/torch/tensor.rb +14 -25
  72. data/lib/torch/version.rb +1 -1
  73. metadata +50 -2
data/lib/torch/nn/max_unpool3d.rb ADDED
@@ -0,0 +1,16 @@
+ module Torch
+   module NN
+     class MaxUnpool3d < MaxUnpoolNd
+       def initialize(kernel_size, stride: nil, padding: 0)
+         super()
+         @kernel_size = _triple(kernel_size)
+         @stride = _triple(stride || kernel_size)
+         @padding = _triple(padding)
+       end
+
+       def forward(input, indices, output_size: nil)
+         F.max_unpool3d(input, indices, @kernel_size, @stride, @padding, output_size)
+       end
+     end
+   end
+ end
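
A minimal usage sketch (hypothetical, not part of the diff): a scalar kernel_size is expanded by _triple, and stride defaults to the kernel size, mirroring PyTorch's MaxUnpool3d.

# Hypothetical example: partially reverses a matching 3-d max pool.
# Assumes `indices` came from the corresponding pooling call.
unpool = Torch::NN::MaxUnpool3d.new(2)  # kernel_size 2 expands to [2, 2, 2]
# output = unpool.call(pooled, indices)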
data/lib/torch/nn/max_unpoolnd.rb ADDED
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class MaxUnpoolNd < Module
+       def extra_inspect
+         format("kernel_size: %s, stride: %s, padding: %s", @kernel_size, @stride, @padding)
+       end
+     end
+   end
+ end
data/lib/torch/nn/module.rb CHANGED
@@ -1,6 +1,8 @@
  module Torch
    module NN
      class Module
+       include Utils
+
        def initialize
          @training = true
          @parameters = {}
@@ -15,6 +17,7 @@ module Torch
        def register_buffer(name, tensor)
          # TODO add checks
          @buffers[name] = tensor
+         instance_variable_set("@#{name}", tensor)
        end

        def register_parameter(name, param)
@@ -216,6 +219,10 @@ module Torch
          end
          str % vars
        end
+
+       def dict
+         instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
+       end
      end
    end
  end
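
The new dict method exposes every instance variable as a symbol-keyed hash entry. A hypothetical illustration (exact keys depend on the module; Linear's internals are not shown in this diff):

layer = Torch::NN::Linear.new(2, 3)
# Module#initialize sets @training, @parameters, @buffers, so at minimum:
layer.dict.keys # => includes :training, :parameters, :buffers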
data/lib/torch/nn/reflection_pad1d.rb ADDED
@@ -0,0 +1,10 @@
+ module Torch
+   module NN
+     class ReflectionPad1d < ReflectionPadNd
+       def initialize(padding)
+         super()
+         @padding = _pair(padding)
+       end
+     end
+   end
+ end
data/lib/torch/nn/reflection_pad2d.rb ADDED
@@ -0,0 +1,10 @@
+ module Torch
+   module NN
+     class ReflectionPad2d < ReflectionPadNd
+       def initialize(padding)
+         super()
+         @padding = _quadrupal(padding)
+       end
+     end
+   end
+ end
data/lib/torch/nn/reflection_padnd.rb ADDED
@@ -0,0 +1,13 @@
+ module Torch
+   module NN
+     class ReflectionPadNd < Module
+       def forward(input)
+         F.pad(input, @padding, mode: "reflect")
+       end
+
+       def extra_inspect
+         @padding.inspect
+       end
+     end
+   end
+ end
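
A hedged sketch of how the subclasses compose with F.pad (class names from this diff, values hypothetical):

pad = Torch::NN::ReflectionPad2d.new(1)  # _quadrupal(1) => [1, 1, 1, 1]
# pad.call(x) runs F.pad(x, [1, 1, 1, 1], mode: "reflect"),
# growing an [N, C, H, W] input to [N, C, H + 2, W + 2]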
data/lib/torch/nn/replication_pad1d.rb ADDED
@@ -0,0 +1,10 @@
+ module Torch
+   module NN
+     class ReplicationPad1d < ReplicationPadNd
+       def initialize(padding)
+         super()
+         @padding = _pair(padding)
+       end
+     end
+   end
+ end
data/lib/torch/nn/replication_pad2d.rb ADDED
@@ -0,0 +1,10 @@
+ module Torch
+   module NN
+     class ReplicationPad2d < ReplicationPadNd
+       def initialize(padding)
+         super()
+         @padding = _quadrupal(padding)
+       end
+     end
+   end
+ end
data/lib/torch/nn/replication_pad3d.rb ADDED
@@ -0,0 +1,10 @@
+ module Torch
+   module NN
+     class ReplicationPad3d < ReplicationPadNd
+       def initialize(padding)
+         super()
+         @padding = _ntuple(6, padding)
+       end
+     end
+   end
+ end
data/lib/torch/nn/replication_padnd.rb ADDED
@@ -0,0 +1,13 @@
+ module Torch
+   module NN
+     class ReplicationPadNd < Module
+       def forward(input)
+         F.pad(input, @padding, mode: "replicate")
+       end
+
+       def extra_inspect
+         @padding.inspect
+       end
+     end
+   end
+ end
data/lib/torch/nn/rnn_base.rb CHANGED
@@ -90,12 +90,13 @@ module Torch
        end

        def permute_hidden(hx, permutation)
+         if permutation.nil?
+           return hx
+         end
          raise NotImplementedYet
        end

        def forward(input, hx: nil)
-         raise NotImplementedYet
-
          is_packed = false # TODO isinstance(input, PackedSequence)
          if is_packed
            input, batch_sizes, sorted_indices, unsorted_indices = input
@@ -120,8 +121,8 @@ module Torch

          check_forward_args(input, hx, batch_sizes)
          _rnn_impls = {
-           "RNN_TANH" => Torch.method(:_rnn_tanh),
-           "RNN_RELU" => Torch.method(:_rnn_relu)
+           "RNN_TANH" => Torch.method(:rnn_tanh),
+           "RNN_RELU" => Torch.method(:rnn_relu)
          }
          _impl = _rnn_impls[@mode]
          if batch_sizes.nil?
@@ -149,6 +150,49 @@ module Torch
          end
          format(s, input_size: @input_size, hidden_size: @hidden_size, num_layers: @num_layers)
        end
+
+       private
+
+       def _flat_weights
+         @all_weights.flatten.map { |v| instance_variable_get("@#{v}") }.compact
+       end
+
+       def _get_flat_weights
+         _flat_weights
+       end
+
+       def check_input(input, batch_sizes)
+         expected_input_dim = !batch_sizes.nil? ? 2 : 3
+         if input.dim != expected_input_dim
+           raise ArgumentError, "input must have #{expected_input_dim} dimensions, got #{input.dim}"
+         end
+         if @input_size != input.size(-1)
+           raise ArgumentError, "input.size(-1) must be equal to input_size. Expected #{@input_size}, got #{input.size(-1)}"
+         end
+       end
+
+       def get_expected_hidden_size(input, batch_sizes)
+         if !batch_sizes.nil?
+           mini_batch = batch_sizes[0]
+           mini_batch = mini_batch.to_i
+         else
+           mini_batch = @batch_first ? input.size(0) : input.size(1)
+         end
+         num_directions = @bidirectional ? 2 : 1
+         [@num_layers * num_directions, mini_batch, @hidden_size]
+       end
+
+       def check_hidden_size(hx, expected_hidden_size)
+         if hx.size != expected_hidden_size
+           raise ArgumentError, "Expected hidden size #{expected_hidden_size.inspect}, got #{hx.size.inspect}"
+         end
+       end
+
+       def check_forward_args(input, hidden, batch_sizes)
+         check_input(input, batch_sizes)
+         expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
+         check_hidden_size(hidden, expected_hidden_size)
+       end
      end
    end
  end
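
Tracing get_expected_hidden_size with hypothetical sizes shows exactly what check_hidden_size will demand:

# input shape [seq_len, batch, input_size] = [10, 4, 8], batch_first: false
# num_layers: 2, bidirectional: true, hidden_size: 16
#   mini_batch     = input.size(1)   # => 4
#   num_directions = 2
#   expected       = [2 * 2, 4, 16]  # => [4, 4, 16]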
data/lib/torch/nn/softshrink.rb ADDED
@@ -0,0 +1,18 @@
+ module Torch
+   module NN
+     class Softshrink < Module
+       def initialize(lambd: 0.5)
+         super()
+         @lambd = lambd
+       end
+
+       def forward(input)
+         F.softshrink(input, @lambd)
+       end
+
+       def extra_inspect
+         @lambd.to_s
+       end
+     end
+   end
+ end
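
For reference, F.softshrink applies the standard soft-shrinkage function (stated here for context, not taken from the diff):

# f(x) = x - lambd  if x >  lambd
#        x + lambd  if x < -lambd
#        0          otherwise
m = Torch::NN::Softshrink.new(lambd: 0.5)
# m.call(input) delegates to F.softshrink(input, 0.5)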
data/lib/torch/nn/softsign.rb ADDED
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class Softsign < Module
+       def forward(input)
+         F.softsign(input)
+       end
+     end
+   end
+ end
data/lib/torch/nn/tanh.rb ADDED
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class Tanh < Module
+       def forward(input)
+         Torch.tanh(input)
+       end
+     end
+   end
+ end
data/lib/torch/nn/tanhshrink.rb ADDED
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class Tanhshrink < Module
+       def forward(input)
+         F.tanhshrink(input)
+       end
+     end
+   end
+ end
data/lib/torch/nn/unfold.rb ADDED
@@ -0,0 +1,19 @@
+ module Torch
+   module NN
+     class Unfold < Module
+       def initialize(kernel_size, dilation: 1, padding: 0, stride: 1)
+         super()
+         @kernel_size = kernel_size
+         @dilation = dilation
+         @padding = padding
+         @stride = stride
+       end
+
+       def forward(input)
+         F.unfold(input, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
+       end
+
+       # TODO add extra_inspect
+     end
+   end
+ end
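
A brief hedged sketch of what Unfold computes (shapes follow the standard im2col definition; values hypothetical):

unfold = Torch::NN::Unfold.new([2, 2])
# On an [N, C, H, W] input, unfold.call(x) yields [N, C * 2 * 2, L],
# where L is the number of 2x2 patches slid over each feature map.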
data/lib/torch/nn/utils.rb ADDED
@@ -0,0 +1,25 @@
+ module Torch
+   module NN
+     module Utils
+       def _single(value)
+         _ntuple(1, value)
+       end
+
+       def _pair(value)
+         _ntuple(2, value)
+       end
+
+       def _triple(value)
+         _ntuple(3, value)
+       end
+
+       def _quadrupal(value)
+         _ntuple(4, value)
+       end
+
+       def _ntuple(n, value)
+         value.is_a?(Array) ? value : [value] * n
+       end
+     end
+   end
+ end
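
All of the helpers reduce to _ntuple, which repeats a scalar n times but passes arrays through untouched. A hypothetical console check:

include Torch::NN::Utils  # illustration only
_pair(3)      # => [3, 3]
_pair([1, 2]) # => [1, 2]  (already an Array, returned as-is)
_ntuple(6, 0) # => [0, 0, 0, 0, 0, 0]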
data/lib/torch/nn/zero_pad2d.rb ADDED
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class ZeroPad2d < ConstantPad2d
+       def initialize(padding)
+         super(padding, 0.0)
+       end
+     end
+   end
+ end
data/lib/torch/tensor.rb CHANGED
@@ -45,8 +45,9 @@ module Torch
      dim.times.map { |i| size(i) }
    end

-   def view(*size)
-     _view(size)
+   # mirror Python len()
+   def length
+     size(0)
    end

    def item
@@ -86,38 +87,26 @@ module Torch
      _type(enum)
    end

-   # start temp operations
+   def reshape(*size)
+     # Python doesn't check if size == 1, just ignores later arguments
+     size = size.first if size.size == 1 && size.first.is_a?(Array)
+     _reshape(size)
+   end

+   def view(*size)
+     size = size.first if size.size == 1 && size.first.is_a?(Array)
+     _view(size)
+   end
+
+   # value and other are swapped for some methods
    def add!(value = 1, other)
      if other.is_a?(Numeric)
        _add__scalar(other, value)
      else
-       # need to use alpha for sparse tensors instead of multiplying
        _add__tensor(other, value)
      end
    end

-   def mul!(other)
-     if other.is_a?(Numeric)
-       _mul__scalar(other)
-     else
-       _mul__tensor(other)
-     end
-   end
-
-   # operations
-   %w(log_softmax mean softmax sum topk).each do |op|
-     define_method(op) do |*args, **options, &block|
-       if options.any?
-         Torch.send(op, self, *args, **options, &block)
-       else
-         Torch.send(op, self, *args, &block)
-       end
-     end
-   end
-
-   # end temp operations
-
    def +(other)
      add(other)
    end
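
The net effect of the reworked view and reshape is that both splat and array forms are accepted. A hedged example (Torch.zeros assumed available in this release):

x = Torch.zeros(2, 3)
x.view(6)       # splat form
x.view([3, 2])  # array form, unwrapped by the size.first check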
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
  module Torch
-   VERSION = "0.1.5"
+   VERSION = "0.1.6"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
-   version: 0.1.5
+   version: 0.1.6
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2019-12-07 00:00:00.000000000 Z
+ date: 2019-12-10 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: rice
@@ -108,6 +108,7 @@ files:
  - ext/torch/extconf.rb
  - ext/torch/nn_functions.cpp
  - ext/torch/nn_functions.hpp
+ - ext/torch/templates.cpp
  - ext/torch/templates.hpp
  - ext/torch/tensor_functions.cpp
  - ext/torch/tensor_functions.hpp
@@ -123,12 +124,24 @@ files:
  - lib/torch/native/native_functions.yaml
  - lib/torch/native/parser.rb
  - lib/torch/nn/alpha_dropout.rb
+ - lib/torch/nn/avg_pool1d.rb
  - lib/torch/nn/avg_pool2d.rb
+ - lib/torch/nn/avg_pool3d.rb
  - lib/torch/nn/avg_poolnd.rb
+ - lib/torch/nn/batch_norm.rb
+ - lib/torch/nn/batch_norm1d.rb
+ - lib/torch/nn/batch_norm2d.rb
+ - lib/torch/nn/batch_norm3d.rb
  - lib/torch/nn/bce_loss.rb
  - lib/torch/nn/bce_with_logits_loss.rb
  - lib/torch/nn/bilinear.rb
+ - lib/torch/nn/constant_pad1d.rb
+ - lib/torch/nn/constant_pad2d.rb
+ - lib/torch/nn/constant_pad3d.rb
+ - lib/torch/nn/constant_padnd.rb
+ - lib/torch/nn/conv1d.rb
  - lib/torch/nn/conv2d.rb
+ - lib/torch/nn/conv3d.rb
  - lib/torch/nn/convnd.rb
  - lib/torch/nn/cosine_embedding_loss.rb
  - lib/torch/nn/cosine_similarity.rb
@@ -141,19 +154,40 @@ files:
  - lib/torch/nn/embedding.rb
  - lib/torch/nn/embedding_bag.rb
  - lib/torch/nn/feature_alpha_dropout.rb
+ - lib/torch/nn/fold.rb
  - lib/torch/nn/functional.rb
+ - lib/torch/nn/group_norm.rb
+ - lib/torch/nn/gru.rb
+ - lib/torch/nn/hardshrink.rb
  - lib/torch/nn/hinge_embedding_loss.rb
  - lib/torch/nn/identity.rb
  - lib/torch/nn/init.rb
+ - lib/torch/nn/instance_norm.rb
+ - lib/torch/nn/instance_norm1d.rb
+ - lib/torch/nn/instance_norm2d.rb
+ - lib/torch/nn/instance_norm3d.rb
  - lib/torch/nn/kl_div_loss.rb
  - lib/torch/nn/l1_loss.rb
+ - lib/torch/nn/layer_norm.rb
  - lib/torch/nn/leaky_relu.rb
  - lib/torch/nn/linear.rb
+ - lib/torch/nn/local_response_norm.rb
+ - lib/torch/nn/log_sigmoid.rb
  - lib/torch/nn/log_softmax.rb
  - lib/torch/nn/loss.rb
+ - lib/torch/nn/lp_pool1d.rb
+ - lib/torch/nn/lp_pool2d.rb
+ - lib/torch/nn/lp_poolnd.rb
+ - lib/torch/nn/lstm.rb
  - lib/torch/nn/margin_ranking_loss.rb
+ - lib/torch/nn/max_pool1d.rb
  - lib/torch/nn/max_pool2d.rb
+ - lib/torch/nn/max_pool3d.rb
  - lib/torch/nn/max_poolnd.rb
+ - lib/torch/nn/max_unpool1d.rb
+ - lib/torch/nn/max_unpool2d.rb
+ - lib/torch/nn/max_unpool3d.rb
+ - lib/torch/nn/max_unpoolnd.rb
  - lib/torch/nn/module.rb
  - lib/torch/nn/mse_loss.rb
  - lib/torch/nn/multi_label_margin_loss.rb
@@ -164,7 +198,14 @@ files:
  - lib/torch/nn/parameter.rb
  - lib/torch/nn/poisson_nll_loss.rb
  - lib/torch/nn/prelu.rb
+ - lib/torch/nn/reflection_pad1d.rb
+ - lib/torch/nn/reflection_pad2d.rb
+ - lib/torch/nn/reflection_padnd.rb
  - lib/torch/nn/relu.rb
+ - lib/torch/nn/replication_pad1d.rb
+ - lib/torch/nn/replication_pad2d.rb
+ - lib/torch/nn/replication_pad3d.rb
+ - lib/torch/nn/replication_padnd.rb
  - lib/torch/nn/rnn.rb
  - lib/torch/nn/rnn_base.rb
  - lib/torch/nn/sequential.rb
@@ -175,8 +216,15 @@ files:
  - lib/torch/nn/softmax2d.rb
  - lib/torch/nn/softmin.rb
  - lib/torch/nn/softplus.rb
+ - lib/torch/nn/softshrink.rb
+ - lib/torch/nn/softsign.rb
+ - lib/torch/nn/tanh.rb
+ - lib/torch/nn/tanhshrink.rb
  - lib/torch/nn/triplet_margin_loss.rb
+ - lib/torch/nn/unfold.rb
+ - lib/torch/nn/utils.rb
  - lib/torch/nn/weighted_loss.rb
+ - lib/torch/nn/zero_pad2d.rb
  - lib/torch/optim/adadelta.rb
  - lib/torch/optim/adagrad.rb
  - lib/torch/optim/adam.rb