torch-rb 0.1.1 → 0.1.6

Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +73 -9
  5. data/ext/torch/ext.cpp +148 -315
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +298 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +236 -112
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +52 -25
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -39
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +419 -16
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +191 -19
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +4 -0
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +62 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +60 -0
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +90 -30
  139. data/lib/torch/utils/data/data_loader.rb +15 -0
  140. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  141. data/lib/torch/version.rb +1 -1
  142. metadata +122 -3
data/lib/torch/nn/loss.rb
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class Loss < Module
+      def initialize(reduction)
+        super()
+        @reduction = reduction
+      end
+    end
+  end
+end
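
The new `Loss` base class just stores the reduction mode; concrete losses subclass it and implement `forward`, as `MarginRankingLoss` does later in this diff. A minimal sketch of a custom loss built the same way (hypothetical class; it assumes `F.l1_loss` accepts a `reduction:` keyword, consistent with the functional signatures used elsewhere in this release):

    # Hypothetical custom loss on top of the new Loss base class.
    module Torch
      module NN
        class MeanAbsoluteLoss < Loss
          def initialize(reduction: "mean")
            super(reduction) # stores @reduction
          end

          def forward(input, target)
            F.l1_loss(input, target, reduction: @reduction)
          end
        end
      end
    end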
data/lib/torch/nn/lp_pool1d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class LPPool1d < LPPoolNd
+      def forward(input)
+        F.lp_pool1d(input, @norm_type.to_f, @kernel_size, @stride, @ceil_mode)
+      end
+    end
+  end
+end
data/lib/torch/nn/lp_pool2d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class LPPool2d < LPPoolNd
+      def forward(input)
+        F.lp_pool2d(input, @norm_type.to_f, @kernel_size, @stride, @ceil_mode)
+      end
+    end
+  end
+end
data/lib/torch/nn/lp_poolnd.rb
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class LPPoolNd < Module
+      def initialize(norm_type, kernel_size, stride: nil, ceil_mode: false)
+        super()
+        @norm_type = norm_type
+        @kernel_size = kernel_size
+        @stride = stride
+        @ceil_mode = ceil_mode
+      end
+
+      def extra_inspect
+        format("norm_type: %{norm_type}, kernel_size: %{kernel_size}, stride: %{stride}, ceil_mode: %{ceil_mode}",
+          norm_type: @norm_type,
+          kernel_size: @kernel_size,
+          stride: @stride,
+          ceil_mode: @ceil_mode
+        )
+      end
+    end
+  end
+end
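
`LPPoolNd` holds the configuration; `LPPool1d`/`LPPool2d` above apply it through the functional interface. A usage sketch (argument order per the `initialize` signature above; tensor shapes are illustrative):

    # Power-average pooling: norm_type 2, 3x3 window, stride 2.
    pool = Torch::NN::LPPool2d.new(2, 3, stride: 2)
    input = Torch.randn(20, 16, 50, 32) # (batch, channels, height, width)
    output = pool.call(input)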
data/lib/torch/nn/lstm.rb
@@ -0,0 +1,66 @@
+module Torch
+  module NN
+    class LSTM < RNNBase
+      def initialize(*args, **options)
+        super("LSTM", *args, **options)
+      end
+
+      def check_forward_args(input, hidden, batch_sizes)
+        check_input(input, batch_sizes)
+        expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
+
+        # TODO pass message
+        check_hidden_size(hidden[0], expected_hidden_size)
+        check_hidden_size(hidden[1], expected_hidden_size)
+      end
+
+      def permute_hidden(hx, permutation)
+        if permutation.nil?
+          return hx
+        end
+        raise NotImplementedYet
+      end
+
+      def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+        if hx.nil?
+          num_directions = @bidirectional ? 2 : 1
+          zeros = Torch.zeros(@num_layers * num_directions, max_batch_size, @hidden_size, dtype: input.dtype, device: input.device)
+          hx = [zeros, zeros]
+        else
+          # Each batch of the hidden state should match the input sequence that
+          # the user believes he/she is passing in.
+          hx = permute_hidden(hx, sorted_indices)
+        end
+
+        check_forward_args(input, hx, batch_sizes)
+        if batch_sizes.nil?
+          result = Torch.lstm(input, hx, _get_flat_weights, @bias, @num_layers,
+            @dropout, @training, @bidirectional, @batch_first)
+        else
+          result = Torch.lstm(input, batch_sizes, hx, _get_flat_weights, @bias,
+            @num_layers, @dropout, @training, @bidirectional)
+        end
+        output = result[0]
+        hidden = result[1..-1]
+
+        [output, hidden]
+      end
+
+      def forward_tensor(input, hx: nil)
+        batch_sizes = nil
+        max_batch_size = @batch_first ? input.size(0) : input.size(1)
+        sorted_indices = nil
+        unsorted_indices = nil
+
+        output, hidden = forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+
+        [output, permute_hidden(hidden, unsorted_indices)]
+      end
+
+      def forward(input, hx: nil)
+        # TODO PackedSequence
+        forward_tensor(input, hx: hx)
+      end
+    end
+  end
+end
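
`forward` currently handles only plain tensors (PackedSequence support is still a TODO), and a nil `hx` is seeded with zero tensors. A usage sketch; the constructor arguments are assumed to mirror PyTorch's `nn.LSTM`, since `RNNBase#initialize` is not shown in this diff:

    # Assumed signature: input_size, hidden_size, plus RNNBase options.
    lstm = Torch::NN::LSTM.new(10, 20, num_layers: 2)
    input = Torch.randn(5, 3, 10)       # (seq_len, batch, input_size)
    output, (hn, cn) = lstm.call(input) # hx: nil => zero-initialized states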
data/lib/torch/nn/margin_ranking_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class MarginRankingLoss < Loss
+      def initialize(margin: 1.0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.margin_ranking_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
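
A usage sketch (shapes are illustrative; the target convention of 1 / -1, meaning `input1` should rank higher / lower than `input2`, is assumed from the PyTorch loss of the same name):

    loss_fn = Torch::NN::MarginRankingLoss.new(margin: 0.5)
    input1 = Torch.randn(3)
    input2 = Torch.randn(3)
    target = Torch.tensor([1.0, 1.0, -1.0])
    loss = loss_fn.call(input1, input2, target)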
data/lib/torch/nn/max_pool1d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class MaxPool1d < MaxPoolNd
+      def forward(input)
+        F.max_pool1d(input, @kernel_size, @stride, @padding, @dilation, @ceil_mode, @return_indices)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_pool2d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class MaxPool2d < MaxPoolNd
+      def forward(input)
+        F.max_pool2d(input, @kernel_size, @stride, @padding, @dilation, @ceil_mode, @return_indices)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_pool3d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class MaxPool3d < MaxPoolNd
+      def forward(input)
+        F.max_pool3d(input, @kernel_size, @stride, @padding, @dilation, @ceil_mode, @return_indices)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_poolnd.rb
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class MaxPoolNd < Module
+      def initialize(kernel_size, stride: nil, padding: 0, dilation: 1, return_indices: false, ceil_mode: false)
+        super()
+        @kernel_size = kernel_size
+        @stride = stride || kernel_size
+        @padding = padding
+        @dilation = dilation
+        @return_indices = return_indices
+        @ceil_mode = ceil_mode
+      end
+
+      def extra_inspect
+        format("kernel_size: %s", @kernel_size)
+      end
+    end
+  end
+end
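
Note that `stride` defaults to `kernel_size`, matching PyTorch. A usage sketch (shapes are illustrative):

    pool = Torch::NN::MaxPool2d.new(2) # 2x2 window, stride defaults to 2
    input = Torch.randn(1, 16, 8, 8)
    output = pool.call(input)          # spatial dims halved to 4x4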
data/lib/torch/nn/max_unpool1d.rb
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class MaxUnpool1d < MaxUnpoolNd
+      def initialize(kernel_size, stride: nil, padding: 0)
+        super()
+        @kernel_size = _single(kernel_size)
+        @stride = _single(stride || kernel_size)
+        @padding = _single(padding)
+      end
+
+      def forward(input, indices, output_size: nil)
+        F.max_unpool1d(input, indices, @kernel_size, stride: @stride, padding: @padding, output_size: output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_unpool2d.rb
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class MaxUnpool2d < MaxUnpoolNd
+      def initialize(kernel_size, stride: nil, padding: 0)
+        super()
+        @kernel_size = _pair(kernel_size)
+        @stride = _pair(stride || kernel_size)
+        @padding = _pair(padding)
+      end
+
+      def forward(input, indices, output_size: nil)
+        F.max_unpool2d(input, indices, @kernel_size, @stride, @padding, output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_unpool3d.rb
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class MaxUnpool3d < MaxUnpoolNd
+      def initialize(kernel_size, stride: nil, padding: 0)
+        super()
+        @kernel_size = _triple(kernel_size)
+        @stride = _triple(stride || kernel_size)
+        @padding = _triple(padding)
+      end
+
+      def forward(input, indices, output_size: nil)
+        F.max_unpool3d(input, indices, @kernel_size, @stride, @padding, output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_unpoolnd.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class MaxUnpoolNd < Module
+      def extra_inspect
+        format("kernel_size: %s, stride: %s, padding: %s", @kernel_size, @stride, @padding)
+      end
+    end
+  end
+end
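
The unpool layers invert their pooling counterparts given the indices those produce. A round-trip sketch, assuming `return_indices: true` makes the pool layer return an `[output, indices]` pair as it does in PyTorch:

    pool = Torch::NN::MaxPool2d.new(2, return_indices: true)
    unpool = Torch::NN::MaxUnpool2d.new(2)
    input = Torch.randn(1, 1, 4, 4)
    output, indices = pool.call(input)
    restored = unpool.call(output, indices) # zeros except at the maxima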
data/lib/torch/nn/module.rb
@@ -1,55 +1,227 @@
 module Torch
   module NN
     class Module
-      def inspect
-        str = String.new
-        str << "#{self.class.name}(\n"
-        modules.each do |name, mod|
-          str << "  (#{name}): #{mod.inspect}\n"
+      include Utils
+
+      def initialize
+        @training = true
+        @parameters = {}
+        @buffers = {}
+        @modules = {}
+      end
+
+      def forward
+        raise NotImplementedError
+      end
+
+      def register_buffer(name, tensor)
+        # TODO add checks
+        @buffers[name] = tensor
+        instance_variable_set("@#{name}", tensor)
+      end
+
+      def register_parameter(name, param)
+        # TODO add checks
+        @parameters[name] = param
+      end
+
+      def add_module(name, mod)
+        # TODO add checks
+        @modules[name] = mod
+      end
+
+      def _apply(fn)
+        children.each do |mod|
+          mod._apply(fn)
+        end
+        # TODO apply to more objects
+        self
+      end
+
+      def apply(fn)
+        children.each do |mod|
+          mod.apply(fn)
         end
-        str << ")"
+        fn.call(self)
+        self
+      end
+
+      def cuda(device: nil)
+        _apply ->(t) { t.cuda(device) }
+      end
+
+      def cpu
+        _apply ->(t) { t.cpu }
+      end
+
+      def type(dst_type)
+        _apply ->(t) { t.type(dst_type) }
+      end
+
+      def float
+        _apply ->(t) { t.floating_point? ? t.float : t }
+      end
+
+      def double
+        _apply ->(t) { t.floating_point? ? t.double : t }
+      end
+
+      def half
+        _apply ->(t) { t.floating_point? ? t.half : t }
+      end
+
+      # modifies in-place
+      def to(device)
+        convert = lambda do |t|
+          t.to(device)
+        end
+
+        _apply(convert)
       end
 
       def call(*input)
         forward(*input)
       end
 
+      def state_dict
+        raise NotImplementedYet
+      end
+
       def parameters
-        params = []
+        named_parameters.values
+      end
+
+      def named_parameters(prefix: "", recurse: true)
+        params = {}
+        if recurse
+          named_children.each do |name, mod|
+            params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+          end
+        end
         instance_variables.each do |name|
           param = instance_variable_get(name)
-          params << param if param.is_a?(Parameter)
+          params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
+        end
+        @parameters.each do |name, param|
+          params[[prefix, name].join] = param
+        end
+        params
+      end
+
+      def buffers
+        named_buffers.values
+      end
+
+      def named_buffers
+        @buffers || {}
+      end
+
+      def children
+        named_children.values
+      end
+
+      def named_children
+        modules = {}
+        instance_variables.each do |name|
+          mod = instance_variable_get(name)
+          modules[name[1..-1]] = mod if mod.is_a?(Module)
+        end
+        @modules.each do |name, mod|
+          modules[name] = mod
+        end
+        modules
+      end
+
+      def modules
+        named_modules.values
+      end
+
+      def named_modules
+        {"" => self}.merge(named_children)
+      end
+
+      def train(mode = true)
+        @training = mode
+        children.each do |mod|
+          mod.train(mode)
         end
-        params + modules.flat_map { |_, mod| mod.parameters }
+        self
+      end
+
+      def eval
+        train(false)
+      end
+
+      def requires_grad!(requires_grad: true)
+        parameters.each do |p|
+          p.requires_grad!(requires_grad)
+        end
+        self
       end
 
       def zero_grad
         parameters.each do |param|
           if param.grad
-            raise Error, "Not supported yet"
             param.grad.detach!
             param.grad.zero!
           end
         end
       end
 
+      def share_memory
+        _apply ->(t) { t.share_memory! }
+      end
+
+      def inspect
+        name = self.class.name.split("::").last
+        if children.empty?
+          "#{name}(#{extra_inspect})"
+        else
+          str = String.new
+          str << "#{name}(\n"
+          children.each do |name, mod|
+            str << "  (#{name}): #{mod.inspect}\n"
+          end
+          str << ")"
+        end
+      end
+
       def method_missing(method, *args, &block)
-        modules[method.to_s] || super
+        name = method.to_s
+        if named_parameters.key?(name)
+          named_parameters[name]
+        elsif named_buffers.key?(name)
+          named_buffers[name]
+        elsif named_modules.key?(name)
+          named_modules[name]
+        else
+          super
+        end
       end
 
       def respond_to?(method, include_private = false)
-        modules.key?(method.to_s) || super
+        name = method.to_s
+        named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
       end
 
       private
 
-      def modules
-        modules = {}
-        instance_variables.each do |name|
-          mod = instance_variable_get(name)
-          modules[name[1..-1]] = mod if mod.is_a?(Module)
-        end
-        modules
+      def extra_inspect
+        nil
+      end
+
+      def format(str, *vars, **options)
+        vars =
+          if vars.any?
+            vars.map(&:inspect)
+          else
+            options.map { |k, v| [k, v.inspect] }.to_h
+          end
+        str % vars
+      end
+
+      def dict
+        instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
       end
     end
   end
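
This rewrite turns `Module` into a real container: parameters, buffers, and submodules are tracked (both registered explicitly and discovered via instance variables), collected recursively with dotted prefixes, and exposed through `method_missing`. A sketch of what that enables (the class and shapes are illustrative, not from the gem):

    class TinyNet < Torch::NN::Module
      def initialize
        super()
        @fc1 = Torch::NN::Linear.new(4, 8)
        @fc2 = Torch::NN::Linear.new(8, 1)
      end

      def forward(x)
        @fc2.call(@fc1.call(x).relu)
      end
    end

    net = TinyNet.new
    net.named_parameters.keys # => ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
    net.eval                  # train(false) recurses into children
    net.fc1                   # resolved by method_missing via named_modules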