torch-rb 0.1.3 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0
data/lib/torch/nn/loss.rb
@@ -0,0 +1,10 @@
+ module Torch
+   module NN
+     class Loss < Module
+       def initialize(reduction)
+         super()
+         @reduction = reduction
+       end
+     end
+   end
+ end
data/lib/torch/nn/lp_pool1d.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class LPPool1d < LPPoolNd
+       def forward(input)
+         F.lp_pool1d(input, @norm_type.to_f, @kernel_size, @stride, @ceil_mode)
+       end
+     end
+   end
+ end
data/lib/torch/nn/lp_pool2d.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class LPPool2d < LPPoolNd
+       def forward(input)
+         F.lp_pool2d(input, @norm_type.to_f, @kernel_size, @stride, @ceil_mode)
+       end
+     end
+   end
+ end
data/lib/torch/nn/lp_poolnd.rb
@@ -0,0 +1,22 @@
+ module Torch
+   module NN
+     class LPPoolNd < Module
+       def initialize(norm_type, kernel_size, stride: nil, ceil_mode: false)
+         super()
+         @norm_type = norm_type
+         @kernel_size = kernel_size
+         @stride = stride
+         @ceil_mode = ceil_mode
+       end
+
+       def extra_inspect
+         format("norm_type: %{norm_type}, kernel_size: %{kernel_size}, stride: %{stride}, ceil_mode: %{ceil_mode}",
+           norm_type: @norm_type,
+           kernel_size: @kernel_size,
+           stride: @stride,
+           ceil_mode: @ceil_mode
+         )
+       end
+     end
+   end
+ end
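Taken together, the three LP-pooling files above follow one pattern: the Nd base class stores the options and the dimension-specific subclasses forward to F.lp_pool1d / F.lp_pool2d. A minimal usage sketch, illustrative only; it assumes Torch.randn and call behave as elsewhere in this release, and passes stride: explicitly since the base class leaves it nil by default:

    require "torch"

    # power-average pooling with norm 2 over 2x2 windows
    pool = Torch::NN::LPPool2d.new(2, 2, stride: 2)
    x = Torch.randn(1, 3, 8, 8)
    y = pool.call(x)   # expected shape: [1, 3, 4, 4]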
data/lib/torch/nn/lstm.rb
@@ -0,0 +1,66 @@
+ module Torch
+   module NN
+     class LSTM < RNNBase
+       def initialize(*args, **options)
+         super("LSTM", *args, **options)
+       end
+
+       def check_forward_args(input, hidden, batch_sizes)
+         check_input(input, batch_sizes)
+         expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
+
+         # TODO pass message
+         check_hidden_size(hidden[0], expected_hidden_size)
+         check_hidden_size(hidden[1], expected_hidden_size)
+       end
+
+       def permute_hidden(hx, permutation)
+         if permutation.nil?
+           return hx
+         end
+         raise NotImplementedYet
+       end
+
+       def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+         if hx.nil?
+           num_directions = @bidirectional ? 2 : 1
+           zeros = Torch.zeros(@num_layers * num_directions, max_batch_size, @hidden_size, dtype: input.dtype, device: input.device)
+           hx = [zeros, zeros]
+         else
+           # Each batch of the hidden state should match the input sequence that
+           # the user believes he/she is passing in.
+           hx = permute_hidden(hx, sorted_indices)
+         end
+
+         check_forward_args(input, hx, batch_sizes)
+         if batch_sizes.nil?
+           result = Torch.lstm(input, hx, _get_flat_weights, @bias, @num_layers,
+             @dropout, @training, @bidirectional, @batch_first)
+         else
+           result = Torch.lstm(input, batch_sizes, hx, _get_flat_weights, @bias,
+             @num_layers, @dropout, @training, @bidirectional)
+         end
+         output = result[0]
+         hidden = result[1..-1]
+
+         [output, hidden]
+       end
+
+       def forward_tensor(input, hx: nil)
+         batch_sizes = nil
+         max_batch_size = @batch_first ? input.size(0) : input.size(1)
+         sorted_indices = nil
+         unsorted_indices = nil
+
+         output, hidden = forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+
+         [output, permute_hidden(hidden, unsorted_indices)]
+       end
+
+       def forward(input, hx: nil)
+         # TODO PackedSequence
+         forward_tensor(input, hx: hx)
+       end
+     end
+   end
+ end
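The LSTM class above defers most argument handling to RNNBase (rnn_base.rb in this diff) and returns [output, [h_n, c_n]] from forward. A rough usage sketch, under the assumption that RNNBase accepts input_size, hidden_size, and a num_layers: option the way its PyTorch counterpart does:

    require "torch"

    lstm = Torch::NN::LSTM.new(10, 20, num_layers: 2)

    input = Torch.randn(5, 3, 10)         # [seq_len, batch, input_size]
    output, (hn, cn) = lstm.call(input)   # hx defaults to zero tensors
    # output: [5, 3, 20]; hn and cn: [2, 3, 20]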
data/lib/torch/nn/margin_ranking_loss.rb
@@ -0,0 +1,14 @@
+ module Torch
+   module NN
+     class MarginRankingLoss < Loss
+       def initialize(margin: 1.0, reduction: "mean")
+         super(reduction)
+         @margin = margin
+       end
+
+       def forward(input1, input2, target)
+         F.margin_ranking_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+       end
+     end
+   end
+ end
data/lib/torch/nn/max_pool1d.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class MaxPool1d < MaxPoolNd
+       def forward(input)
+         F.max_pool1d(input, @kernel_size, @stride, @padding, @dilation, @ceil_mode, @return_indices)
+       end
+     end
+   end
+ end
data/lib/torch/nn/max_pool2d.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class MaxPool2d < MaxPoolNd
+       def forward(input)
+         F.max_pool2d(input, @kernel_size, @stride, @padding, @dilation, @ceil_mode, @return_indices)
+       end
+     end
+   end
+ end
data/lib/torch/nn/max_pool3d.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class MaxPool3d < MaxPoolNd
+       def forward(input)
+         F.max_pool3d(input, @kernel_size, @stride, @padding, @dilation, @ceil_mode, @return_indices)
+       end
+     end
+   end
+ end
data/lib/torch/nn/max_poolnd.rb
@@ -0,0 +1,19 @@
+ module Torch
+   module NN
+     class MaxPoolNd < Module
+       def initialize(kernel_size, stride: nil, padding: 0, dilation: 1, return_indices: false, ceil_mode: false)
+         super()
+         @kernel_size = kernel_size
+         @stride = stride || kernel_size
+         @padding = padding
+         @dilation = dilation
+         @return_indices = return_indices
+         @ceil_mode = ceil_mode
+       end
+
+       def extra_inspect
+         format("kernel_size: %s", @kernel_size)
+       end
+     end
+   end
+ end
data/lib/torch/nn/max_unpool1d.rb
@@ -0,0 +1,16 @@
+ module Torch
+   module NN
+     class MaxUnpool1d < MaxUnpoolNd
+       def initialize(kernel_size, stride: nil, padding: 0)
+         super()
+         @kernel_size = _single(kernel_size)
+         @stride = _single(stride || kernel_size)
+         @padding = _single(padding)
+       end
+
+       def forward(input, indices, output_size: nil)
+         F.max_unpool1d(input, indices, @kernel_size, stride: @stride, padding: @padding, output_size: output_size)
+       end
+     end
+   end
+ end
data/lib/torch/nn/max_unpool2d.rb
@@ -0,0 +1,16 @@
+ module Torch
+   module NN
+     class MaxUnpool2d < MaxUnpoolNd
+       def initialize(kernel_size, stride: nil, padding: 0)
+         super()
+         @kernel_size = _pair(kernel_size)
+         @stride = _pair(stride || kernel_size)
+         @padding = _pair(padding)
+       end
+
+       def forward(input, indices, output_size: nil)
+         F.max_unpool2d(input, indices, @kernel_size, @stride, @padding, output_size)
+       end
+     end
+   end
+ end
data/lib/torch/nn/max_unpool3d.rb
@@ -0,0 +1,16 @@
+ module Torch
+   module NN
+     class MaxUnpool3d < MaxUnpoolNd
+       def initialize(kernel_size, stride: nil, padding: 0)
+         super()
+         @kernel_size = _triple(kernel_size)
+         @stride = _triple(stride || kernel_size)
+         @padding = _triple(padding)
+       end
+
+       def forward(input, indices, output_size: nil)
+         F.max_unpool3d(input, indices, @kernel_size, @stride, @padding, output_size)
+       end
+     end
+   end
+ end
data/lib/torch/nn/max_unpoolnd.rb
@@ -0,0 +1,9 @@
+ module Torch
+   module NN
+     class MaxUnpoolNd < Module
+       def extra_inspect
+         format("kernel_size: %s, stride: %s, padding: %s", @kernel_size, @stride, @padding)
+       end
+     end
+   end
+ end
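The max-pool and max-unpool layers are meant to be paired: constructing the pool with return_indices: true yields the argmax indices that the matching unpool layer consumes. A hedged sketch; it assumes F.max_pool2d returns [output, indices] when return_indices is set, which this diff does not show:

    require "torch"

    pool   = Torch::NN::MaxPool2d.new(2, stride: 2, return_indices: true)
    unpool = Torch::NN::MaxUnpool2d.new(2, stride: 2)

    x = Torch.randn(1, 1, 4, 4)
    output, indices = pool.call(x)            # each [1, 1, 2, 2]
    restored = unpool.call(output, indices)   # [1, 1, 4, 4], non-max positions zeroed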
data/lib/torch/nn/module.rb
@@ -1,56 +1,170 @@
  module Torch
    module NN
      class Module
+       include Utils
+
        def initialize
          @training = true
+         @parameters = {}
+         @buffers = {}
+         @modules = {}
        end

-       def inspect
-         str = String.new
-         str << "#{self.class.name}(\n"
-         modules.each do |name, mod|
-           str << "  (#{name}): #{mod.inspect}\n"
-         end
-         str << ")"
+       def forward
+         raise NotImplementedError
        end

-       def train(mode = true)
-         @training = mode
+       def register_buffer(name, tensor)
+         # TODO add checks
+         @buffers[name] = tensor
+         instance_variable_set("@#{name}", tensor)
+       end

-         modules.each do |_, mod|
-           mod.train(mode)
+       def register_parameter(name, param)
+         # TODO add checks
+         @parameters[name] = param
+       end
+
+       def add_module(name, mod)
+         # TODO add checks
+         @modules[name] = mod
+       end
+
+       def _apply(fn)
+         children.each do |mod|
+           mod._apply(fn)
          end
+         # TODO apply to more objects
+         self
        end

-       def eval
-         train(false)
+       def apply(fn)
+         children.each do |mod|
+           mod.apply(fn)
+         end
+         fn.call(self)
+         self
+       end
+
+       def cuda(device: nil)
+         _apply ->(t) { t.cuda(device) }
+       end
+
+       def cpu
+         _apply ->(t) { t.cpu }
+       end
+
+       def type(dst_type)
+         _apply ->(t) { t.type(dst_type) }
        end

-       def call(*input)
-         forward(*input)
+       def float
+         _apply ->(t) { t.floating_point? ? t.float : t }
+       end
+
+       def double
+         _apply ->(t) { t.floating_point? ? t.double : t }
+       end
+
+       def half
+         _apply ->(t) { t.floating_point? ? t.half : t }
        end

        # modifies in-place
        def to(device)
-         instance_variables.each do |name|
-           param = instance_variable_get(name)
-           if param.is_a?(Parameter)
-             instance_variable_set(name, Parameter.new(param.to(device)))
-           end
+         convert = lambda do |t|
+           t.to(device)
          end
-         modules.each do |_, mod|
-           mod.to(device)
+
+         _apply(convert)
+       end
+
+       def call(*input, **kwargs)
+         forward(*input, **kwargs)
+       end
+
+       def state_dict(destination: nil)
+         destination ||= {}
+         named_parameters.each do |k, v|
+           destination[k] = v
          end
-         self
+         destination
+       end
+
+       def load_state_dict(state_dict)
+         raise NotImplementedYet
        end

        def parameters
-         params = []
+         named_parameters.values
+       end
+
+       def named_parameters(prefix: "", recurse: true)
+         params = {}
+         if recurse
+           named_children.each do |name, mod|
+             params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+           end
+         end
          instance_variables.each do |name|
            param = instance_variable_get(name)
-           params << param if param.is_a?(Parameter)
+           params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
+         end
+         @parameters.each do |name, param|
+           params[[prefix, name].join] = param
          end
-         params + modules.flat_map { |_, mod| mod.parameters }
+         params
+       end
+
+       def buffers
+         named_buffers.values
+       end
+
+       def named_buffers
+         @buffers || {}
+       end
+
+       def children
+         named_children.values
+       end
+
+       def named_children
+         modules = {}
+         instance_variables.each do |name|
+           mod = instance_variable_get(name)
+           modules[name[1..-1]] = mod if mod.is_a?(Module)
+         end
+         @modules.each do |name, mod|
+           modules[name] = mod
+         end
+         modules
+       end
+
+       def modules
+         named_modules.values
+       end
+
+       def named_modules
+         {"" => self}.merge(named_children)
+       end
+
+       def train(mode = true)
+         @training = mode
+         children.each do |mod|
+           mod.train(mode)
+         end
+         self
+       end
+
+       def eval
+         train(false)
+       end
+
+       def requires_grad!(requires_grad: true)
+         parameters.each do |p|
+           p.requires_grad!(requires_grad)
+         end
+         self
        end

        def zero_grad
@@ -62,23 +176,60 @@ module Torch
          end
        end

+       def share_memory
+         _apply ->(t) { t.share_memory! }
+       end
+
+       def inspect
+         name = self.class.name.split("::").last
+         if children.empty?
+           "#{name}(#{extra_inspect})"
+         else
+           str = String.new
+           str << "#{name}(\n"
+           children.each do |name, mod|
+             str << "  (#{name}): #{mod.inspect}\n"
+           end
+           str << ")"
+         end
+       end
+
        def method_missing(method, *args, &block)
-         modules[method.to_s] || super
+         name = method.to_s
+         if named_parameters.key?(name)
+           named_parameters[name]
+         elsif named_buffers.key?(name)
+           named_buffers[name]
+         elsif named_modules.key?(name)
+           named_modules[name]
+         else
+           super
+         end
        end

        def respond_to?(method, include_private = false)
-         modules.key?(method.to_s) || super
+         name = method.to_s
+         named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
        end

        private

-       def modules
-         modules = {}
-         instance_variables.each do |name|
-           mod = instance_variable_get(name)
-           modules[name[1..-1]] = mod if mod.is_a?(Module)
-         end
-         modules
+       def extra_inspect
+         nil
+       end
+
+       def format(str, *vars, **options)
+         vars =
+           if vars.any?
+             vars.map(&:inspect)
+           else
+             options.map { |k, v| [k, v.inspect] }.to_h
+           end
+         str % vars
+       end
+
+       def dict
+         instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
        end
      end
    end
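The rewritten Module replaces the old instance-variable scan with explicit @parameters/@buffers/@modules registries plus named_* accessors, which is what state_dict, train/eval, and method_missing now build on. A small illustrative subclass, assuming the Linear layer in this release registers weight and bias as Parameters:

    class TinyNet < Torch::NN::Module
      def initialize
        super()
        @fc1 = Torch::NN::Linear.new(4, 8)
        @fc2 = Torch::NN::Linear.new(8, 2)
      end

      def forward(x)
        @fc2.call(@fc1.call(x))
      end
    end

    net = TinyNet.new
    net.named_parameters.keys   # => ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
    net.eval                    # recursively sets @training = false
    net.state_dict.size         # => 4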