torch-rb 0.1.3 → 0.1.8

Files changed (115)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0
data/lib/torch/nn/loss.rb
@@ -0,0 +1,10 @@
+module Torch
+  module NN
+    class Loss < Module
+      def initialize(reduction)
+        super()
+        @reduction = reduction
+      end
+    end
+  end
+end
data/lib/torch/nn/lp_pool1d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class LPPool1d < LPPoolNd
+      def forward(input)
+        F.lp_pool1d(input, @norm_type.to_f, @kernel_size, @stride, @ceil_mode)
+      end
+    end
+  end
+end
data/lib/torch/nn/lp_pool2d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class LPPool2d < LPPoolNd
+      def forward(input)
+        F.lp_pool2d(input, @norm_type.to_f, @kernel_size, @stride, @ceil_mode)
+      end
+    end
+  end
+end
data/lib/torch/nn/lp_poolnd.rb
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class LPPoolNd < Module
+      def initialize(norm_type, kernel_size, stride: nil, ceil_mode: false)
+        super()
+        @norm_type = norm_type
+        @kernel_size = kernel_size
+        @stride = stride
+        @ceil_mode = ceil_mode
+      end
+
+      def extra_inspect
+        format("norm_type: %{norm_type}, kernel_size: %{kernel_size}, stride: %{stride}, ceil_mode: %{ceil_mode}",
+          norm_type: @norm_type,
+          kernel_size: @kernel_size,
+          stride: @stride,
+          ceil_mode: @ceil_mode
+        )
+      end
+    end
+  end
+end
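Taken together, LPPoolNd holds the configuration and the subclasses dispatch to the functional API. A minimal usage sketch follows (hypothetical, not part of the diff; it assumes torch-rb mirrors PyTorch's LPPool2d shape semantics and that Torch.randn is available):

  # norm_type 2 takes the root of the sum of squares in each window
  pool = Torch::NN::LPPool2d.new(2, 3, stride: 2)
  input = Torch.randn(20, 16, 50, 32)   # (batch, channels, height, width)
  output = pool.call(input)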
data/lib/torch/nn/lstm.rb
@@ -0,0 +1,66 @@
+module Torch
+  module NN
+    class LSTM < RNNBase
+      def initialize(*args, **options)
+        super("LSTM", *args, **options)
+      end
+
+      def check_forward_args(input, hidden, batch_sizes)
+        check_input(input, batch_sizes)
+        expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
+
+        # TODO pass message
+        check_hidden_size(hidden[0], expected_hidden_size)
+        check_hidden_size(hidden[1], expected_hidden_size)
+      end
+
+      def permute_hidden(hx, permutation)
+        if permutation.nil?
+          return hx
+        end
+        raise NotImplementedYet
+      end
+
+      def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+        if hx.nil?
+          num_directions = @bidirectional ? 2 : 1
+          zeros = Torch.zeros(@num_layers * num_directions, max_batch_size, @hidden_size, dtype: input.dtype, device: input.device)
+          hx = [zeros, zeros]
+        else
+          # Each batch of the hidden state should match the input sequence that
+          # the user believes he/she is passing in.
+          hx = permute_hidden(hx, sorted_indices)
+        end
+
+        check_forward_args(input, hx, batch_sizes)
+        if batch_sizes.nil?
+          result = Torch.lstm(input, hx, _get_flat_weights, @bias, @num_layers,
+            @dropout, @training, @bidirectional, @batch_first)
+        else
+          result = Torch.lstm(input, batch_sizes, hx, _get_flat_weights, @bias,
+            @num_layers, @dropout, @training, @bidirectional)
+        end
+        output = result[0]
+        hidden = result[1..-1]
+
+        [output, hidden]
+      end
+
+      def forward_tensor(input, hx: nil)
+        batch_sizes = nil
+        max_batch_size = @batch_first ? input.size(0) : input.size(1)
+        sorted_indices = nil
+        unsorted_indices = nil
+
+        output, hidden = forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+
+        [output, permute_hidden(hidden, unsorted_indices)]
+      end
+
+      def forward(input, hx: nil)
+        # TODO PackedSequence
+        forward_tensor(input, hx: hx)
+      end
+    end
+  end
+end
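A hypothetical usage sketch of the new LSTM class (not part of the diff); shapes follow PyTorch's (seq_len, batch, input_size) convention, which forward_impl above mirrors, and hx defaults to zero states:

  lstm = Torch::NN::LSTM.new(10, 20, num_layers: 2)  # input_size, hidden_size
  input = Torch.randn(5, 3, 10)                      # (seq_len, batch, input_size)
  output, (hn, cn) = lstm.call(input)                # hidden is the [h_n, c_n] pair from forward_impl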
data/lib/torch/nn/margin_ranking_loss.rb
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class MarginRankingLoss < Loss
+      def initialize(margin: 1.0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.margin_ranking_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
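A brief sketch of how the loss might be called (hypothetical, not from the diff); target entries are 1 where input1 should rank above input2 and -1 otherwise:

  loss_fn = Torch::NN::MarginRankingLoss.new(margin: 0.5)
  input1 = Torch.randn(3)
  input2 = Torch.randn(3)
  target = Torch.tensor([1.0, -1.0, 1.0])
  loss = loss_fn.call(input1, input2, target)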
data/lib/torch/nn/max_pool1d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class MaxPool1d < MaxPoolNd
+      def forward(input)
+        F.max_pool1d(input, @kernel_size, @stride, @padding, @dilation, @ceil_mode, @return_indices)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_pool2d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class MaxPool2d < MaxPoolNd
+      def forward(input)
+        F.max_pool2d(input, @kernel_size, @stride, @padding, @dilation, @ceil_mode, @return_indices)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_pool3d.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class MaxPool3d < MaxPoolNd
+      def forward(input)
+        F.max_pool3d(input, @kernel_size, @stride, @padding, @dilation, @ceil_mode, @return_indices)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_poolnd.rb
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class MaxPoolNd < Module
+      def initialize(kernel_size, stride: nil, padding: 0, dilation: 1, return_indices: false, ceil_mode: false)
+        super()
+        @kernel_size = kernel_size
+        @stride = stride || kernel_size
+        @padding = padding
+        @dilation = dilation
+        @return_indices = return_indices
+        @ceil_mode = ceil_mode
+      end
+
+      def extra_inspect
+        format("kernel_size: %s", @kernel_size)
+      end
+    end
+  end
+end
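As a sketch (hypothetical usage, assuming PyTorch-style shape semantics), a 2x2 max pool halves each spatial dimension:

  pool = Torch::NN::MaxPool2d.new(2)
  input = Torch.randn(1, 16, 28, 28)
  output = pool.call(input)   # expected shape: [1, 16, 14, 14]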
data/lib/torch/nn/max_unpool1d.rb
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class MaxUnpool1d < MaxUnpoolNd
+      def initialize(kernel_size, stride: nil, padding: 0)
+        super()
+        @kernel_size = _single(kernel_size)
+        @stride = _single(stride || kernel_size)
+        @padding = _single(padding)
+      end
+
+      def forward(input, indices, output_size: nil)
+        F.max_unpool1d(input, indices, @kernel_size, stride: @stride, padding: @padding, output_size: output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_unpool2d.rb
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class MaxUnpool2d < MaxUnpoolNd
+      def initialize(kernel_size, stride: nil, padding: 0)
+        super()
+        @kernel_size = _pair(kernel_size)
+        @stride = _pair(stride || kernel_size)
+        @padding = _pair(padding)
+      end
+
+      def forward(input, indices, output_size: nil)
+        F.max_unpool2d(input, indices, @kernel_size, @stride, @padding, output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_unpool3d.rb
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class MaxUnpool3d < MaxUnpoolNd
+      def initialize(kernel_size, stride: nil, padding: 0)
+        super()
+        @kernel_size = _triple(kernel_size)
+        @stride = _triple(stride || kernel_size)
+        @padding = _triple(padding)
+      end
+
+      def forward(input, indices, output_size: nil)
+        F.max_unpool3d(input, indices, @kernel_size, @stride, @padding, output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_unpoolnd.rb
@@ -0,0 +1,9 @@
+module Torch
+  module NN
+    class MaxUnpoolNd < Module
+      def extra_inspect
+        format("kernel_size: %s, stride: %s, padding: %s", @kernel_size, @stride, @padding)
+      end
+    end
+  end
+end
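The unpool classes invert the pool classes above. A hypothetical round trip (not from the diff), assuming F.max_pool2d returns an [output, indices] pair when return_indices is true:

  pool = Torch::NN::MaxPool2d.new(2, return_indices: true)
  unpool = Torch::NN::MaxUnpool2d.new(2)
  input = Torch.randn(1, 1, 4, 4)
  output, indices = pool.call(input)
  restored = unpool.call(output, indices)   # zeros except at the original maxima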
data/lib/torch/nn/module.rb
@@ -1,56 +1,170 @@
 module Torch
   module NN
     class Module
+      include Utils
+
       def initialize
         @training = true
+        @parameters = {}
+        @buffers = {}
+        @modules = {}
       end
 
-      def inspect
-        str = String.new
-        str << "#{self.class.name}(\n"
-        modules.each do |name, mod|
-          str << "  (#{name}): #{mod.inspect}\n"
-        end
-        str << ")"
+      def forward
+        raise NotImplementedError
       end
 
-      def train(mode = true)
-        @training = mode
+      def register_buffer(name, tensor)
+        # TODO add checks
+        @buffers[name] = tensor
+        instance_variable_set("@#{name}", tensor)
+      end
 
-        modules.each do |_, mod|
-          mod.train(mode)
+      def register_parameter(name, param)
+        # TODO add checks
+        @parameters[name] = param
+      end
+
+      def add_module(name, mod)
+        # TODO add checks
+        @modules[name] = mod
+      end
+
+      def _apply(fn)
+        children.each do |mod|
+          mod._apply(fn)
         end
+        # TODO apply to more objects
+        self
       end
 
-      def eval
-        train(false)
+      def apply(fn)
+        children.each do |mod|
+          mod.apply(fn)
+        end
+        fn.call(self)
+        self
+      end
+
+      def cuda(device: nil)
+        _apply ->(t) { t.cuda(device) }
+      end
+
+      def cpu
+        _apply ->(t) { t.cpu }
+      end
+
+      def type(dst_type)
+        _apply ->(t) { t.type(dst_type) }
       end
 
-      def call(*input)
-        forward(*input)
+      def float
+        _apply ->(t) { t.floating_point? ? t.float : t }
+      end
+
+      def double
+        _apply ->(t) { t.floating_point? ? t.double : t }
+      end
+
+      def half
+        _apply ->(t) { t.floating_point? ? t.half : t }
       end
 
       # modifies in-place
       def to(device)
-        instance_variables.each do |name|
-          param = instance_variable_get(name)
-          if param.is_a?(Parameter)
-            instance_variable_set(name, Parameter.new(param.to(device)))
-          end
+        convert = lambda do |t|
+          t.to(device)
         end
-        modules.each do |_, mod|
-          mod.to(device)
+
+        _apply(convert)
+      end
+
+      def call(*input, **kwargs)
+        forward(*input, **kwargs)
+      end
+
+      def state_dict(destination: nil)
+        destination ||= {}
+        named_parameters.each do |k, v|
+          destination[k] = v
         end
-        self
+        destination
+      end
+
+      def load_state_dict(state_dict)
+        raise NotImplementedYet
       end
 
       def parameters
-        params = []
+        named_parameters.values
+      end
+
+      def named_parameters(prefix: "", recurse: true)
+        params = {}
+        if recurse
+          named_children.each do |name, mod|
+            params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+          end
+        end
         instance_variables.each do |name|
           param = instance_variable_get(name)
-          params << param if param.is_a?(Parameter)
+          params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
+        end
+        @parameters.each do |name, param|
+          params[[prefix, name].join] = param
         end
-        params + modules.flat_map { |_, mod| mod.parameters }
+        params
+      end
+
+      def buffers
+        named_buffers.values
+      end
+
+      def named_buffers
+        @buffers || {}
+      end
+
+      def children
+        named_children.values
+      end
+
+      def named_children
+        modules = {}
+        instance_variables.each do |name|
+          mod = instance_variable_get(name)
+          modules[name[1..-1]] = mod if mod.is_a?(Module)
+        end
+        @modules.each do |name, mod|
+          modules[name] = mod
+        end
+        modules
+      end
+
+      def modules
+        named_modules.values
+      end
+
+      def named_modules
+        {"" => self}.merge(named_children)
+      end
+
+      def train(mode = true)
+        @training = mode
+        children.each do |mod|
+          mod.train(mode)
+        end
+        self
+      end
+
+      def eval
+        train(false)
+      end
+
+      def requires_grad!(requires_grad: true)
+        parameters.each do |p|
+          p.requires_grad!(requires_grad)
+        end
+        self
       end
 
       def zero_grad
@@ -62,23 +176,60 @@ module Torch
         end
       end
 
+      def share_memory
+        _apply ->(t) { t.share_memory! }
+      end
+
+      def inspect
+        name = self.class.name.split("::").last
+        if children.empty?
+          "#{name}(#{extra_inspect})"
+        else
+          str = String.new
+          str << "#{name}(\n"
+          children.each do |name, mod|
+            str << "  (#{name}): #{mod.inspect}\n"
+          end
+          str << ")"
+        end
+      end
+
       def method_missing(method, *args, &block)
-        modules[method.to_s] || super
+        name = method.to_s
+        if named_parameters.key?(name)
+          named_parameters[name]
+        elsif named_buffers.key?(name)
+          named_buffers[name]
+        elsif named_modules.key?(name)
+          named_modules[name]
+        else
+          super
+        end
       end
 
       def respond_to?(method, include_private = false)
-        modules.key?(method.to_s) || super
+        name = method.to_s
+        named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
       end
 
       private
 
-      def modules
-        modules = {}
-        instance_variables.each do |name|
-          mod = instance_variable_get(name)
-          modules[name[1..-1]] = mod if mod.is_a?(Module)
-        end
-        modules
+      def extra_inspect
+        nil
+      end
+
+      def format(str, *vars, **options)
+        vars =
+          if vars.any?
+            vars.map(&:inspect)
+          else
+            options.map { |k, v| [k, v.inspect] }.to_h
+          end
+        str % vars
+      end
+
+      def dict
+        instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
      end
    end
  end
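The reworked Module adds explicit registration (@parameters, @buffers, @modules), recursive traversal (named_parameters, named_children, apply), device and dtype casts via _apply, and a basic state_dict. A hypothetical sketch of the new API on a toy network (not from the diff; Linear and Functional.relu are assumed from the gem's existing layers):

  class TinyNet < Torch::NN::Module
    def initialize
      super()
      @fc1 = Torch::NN::Linear.new(4, 8)
      @fc2 = Torch::NN::Linear.new(8, 2)
    end

    def forward(x)
      @fc2.call(Torch::NN::Functional.relu(@fc1.call(x)))
    end
  end

  net = TinyNet.new
  net.named_parameters.keys          # => ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
  net.state_dict.size                # => 4
  net.eval                           # recursively sets training mode off
  net.apply ->(m) { puts m.inspect } # visits children first, then the module itself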