torch-rb 0.1.5 → 0.1.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +1 -1
- data/ext/torch/ext.cpp +0 -170
- data/ext/torch/nn_functions.cpp +44 -24
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +48 -0
- data/ext/torch/tensor_functions.cpp +76 -16
- data/ext/torch/torch_functions.cpp +165 -65
- data/lib/torch.rb +51 -42
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +1 -1
- data/lib/torch/native/function.rb +36 -5
- data/lib/torch/native/generator.rb +26 -7
- data/lib/torch/native/parser.rb +51 -14
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +7 -2
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +9 -17
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +320 -100
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +1 -1
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +6 -6
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +7 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn_base.rb +48 -4
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/tensor.rb +14 -25
- data/lib/torch/version.rb +1 -1
- metadata +50 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 9667f9d3256f5e2d39937f17ae8eb00449dd14f79bb01cd647800bd7ed972fc6
|
4
|
+
data.tar.gz: 54c23612c79355e09c97da5fcf6b97c183da8316d1c2a53d6f8f0463e98342a2
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: bb2c8e5aae436367aeb871a2d19958e59ed9e9c7601b1b8b4473e33094cadf6d657947582b0ec93a29cb08723f8f7c81178a2d50beb23a125d5a356769d92177
|
7
|
+
data.tar.gz: 62feef39da31a19415e2e6c453aed4972e34db7367161a088944c06a977637a8b25cecc8eb2ad052b3b9deee0707f364e616cc33e7674cf0314899421f18fbee
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -369,7 +369,7 @@ Here are a few full examples:
|
|
369
369
|
|
370
370
|
- [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
|
371
371
|
- [Collaborative filtering with MovieLens](examples/movielens)
|
372
|
-
- [
|
372
|
+
- [Sequence models and word embeddings](examples/nlp)
|
373
373
|
|
374
374
|
## LibTorch Installation
|
375
375
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,13 +16,6 @@
|
|
16
16
|
|
17
17
|
using namespace Rice;
|
18
18
|
|
19
|
-
Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
|
20
|
-
Array a;
|
21
|
-
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
22
|
-
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
23
|
-
return Object(a);
|
24
|
-
}
|
25
|
-
|
26
19
|
extern "C"
|
27
20
|
void Init_ext()
|
28
21
|
{
|
@@ -112,94 +105,11 @@ void Init_ext()
|
|
112
105
|
return torch::zeros(size, options);
|
113
106
|
})
|
114
107
|
// begin operations
|
115
|
-
.define_singleton_method(
|
116
|
-
"_mean",
|
117
|
-
*[](Tensor& input) {
|
118
|
-
return torch::mean(input);
|
119
|
-
})
|
120
|
-
.define_singleton_method(
|
121
|
-
"_mean_dim",
|
122
|
-
*[](Tensor& input, int64_t dim, bool keepdim) {
|
123
|
-
return torch::mean(input, dim, keepdim);
|
124
|
-
})
|
125
|
-
.define_singleton_method(
|
126
|
-
"_sum",
|
127
|
-
*[](Tensor& input) {
|
128
|
-
return torch::sum(input);
|
129
|
-
})
|
130
|
-
.define_singleton_method(
|
131
|
-
"_sum_dim",
|
132
|
-
*[](Tensor& input, int64_t dim, bool keepdim) {
|
133
|
-
return torch::sum(input, dim, keepdim);
|
134
|
-
})
|
135
|
-
.define_singleton_method(
|
136
|
-
"_max_out",
|
137
|
-
*[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
|
138
|
-
return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
|
139
|
-
})
|
140
|
-
.define_singleton_method(
|
141
|
-
"_topk",
|
142
|
-
*[](Tensor& input, int64_t k) {
|
143
|
-
return tensor_array(torch::topk(input, k));
|
144
|
-
})
|
145
|
-
.define_singleton_method(
|
146
|
-
"_softmax",
|
147
|
-
*[](const Tensor &input, int64_t dim) {
|
148
|
-
return torch::softmax(input, dim);
|
149
|
-
})
|
150
|
-
.define_singleton_method(
|
151
|
-
"_log_softmax",
|
152
|
-
*[](Tensor& input, int64_t dim) {
|
153
|
-
return torch::log_softmax(input, dim);
|
154
|
-
})
|
155
|
-
.define_singleton_method(
|
156
|
-
"relu",
|
157
|
-
*[](Tensor& input) {
|
158
|
-
return torch::relu(input);
|
159
|
-
})
|
160
|
-
.define_singleton_method(
|
161
|
-
"prelu",
|
162
|
-
*[](torch::Tensor& input, torch::Tensor& weight) {
|
163
|
-
return torch::prelu(input, weight);
|
164
|
-
})
|
165
|
-
.define_singleton_method(
|
166
|
-
"leaky_relu",
|
167
|
-
*[](torch::Tensor& input, Scalar negative_slope) {
|
168
|
-
return torch::leaky_relu(input, negative_slope);
|
169
|
-
})
|
170
|
-
.define_singleton_method(
|
171
|
-
"conv2d",
|
172
|
-
*[](Tensor& input, Tensor& weight, Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
|
173
|
-
return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
|
174
|
-
})
|
175
|
-
// linear layers
|
176
|
-
.define_singleton_method(
|
177
|
-
"bilinear",
|
178
|
-
*[](const Tensor &input1, const Tensor &input2, const Tensor &weight, const Tensor &bias) {
|
179
|
-
return torch::bilinear(input1, input2, weight, bias);
|
180
|
-
})
|
181
|
-
.define_singleton_method(
|
182
|
-
"linear",
|
183
|
-
*[](Tensor& input, Tensor& weight, Tensor& bias) {
|
184
|
-
return torch::linear(input, weight, bias);
|
185
|
-
})
|
186
|
-
// pooling layers
|
187
|
-
.define_singleton_method(
|
188
|
-
"max_pool2d",
|
189
|
-
*[](Tensor& input, IntArrayRef kernel_size) {
|
190
|
-
return torch::max_pool2d(input, kernel_size);
|
191
|
-
})
|
192
|
-
.define_singleton_method(
|
193
|
-
"avg_pool2d",
|
194
|
-
*[](Tensor& input, IntArrayRef kernel_size) {
|
195
|
-
return torch::avg_pool2d(input, kernel_size);
|
196
|
-
})
|
197
108
|
.define_singleton_method(
|
198
109
|
"_binary_cross_entropy_with_logits",
|
199
110
|
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
200
111
|
return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
|
201
112
|
})
|
202
|
-
.define_singleton_method("numel", &torch::numel)
|
203
113
|
.define_singleton_method(
|
204
114
|
"_from_blob",
|
205
115
|
*[](String s, IntArrayRef size, const torch::TensorOptions &options) {
|
@@ -226,16 +136,11 @@ void Init_ext()
|
|
226
136
|
|
227
137
|
rb_cTensor
|
228
138
|
.define_method("cuda?", &torch::Tensor::is_cuda)
|
229
|
-
.define_method("distributed?", &torch::Tensor::is_distributed)
|
230
|
-
.define_method("complex?", &torch::Tensor::is_complex)
|
231
|
-
.define_method("floating_point?", &torch::Tensor::is_floating_point)
|
232
|
-
.define_method("signed?", &torch::Tensor::is_signed)
|
233
139
|
.define_method("sparse?", &torch::Tensor::is_sparse)
|
234
140
|
.define_method("quantized?", &torch::Tensor::is_quantized)
|
235
141
|
.define_method("dim", &torch::Tensor::dim)
|
236
142
|
.define_method("element_size", &torch::Tensor::element_size)
|
237
143
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
238
|
-
.define_method("view_as", &torch::Tensor::view_as)
|
239
144
|
.define_method(
|
240
145
|
"addcmul!",
|
241
146
|
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
@@ -246,21 +151,6 @@ void Init_ext()
|
|
246
151
|
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
247
152
|
return self.addcdiv_(tensor1, tensor2, value);
|
248
153
|
})
|
249
|
-
.define_method(
|
250
|
-
"zero!",
|
251
|
-
*[](Tensor& self) {
|
252
|
-
return self.zero_();
|
253
|
-
})
|
254
|
-
.define_method(
|
255
|
-
"detach",
|
256
|
-
*[](Tensor& self) {
|
257
|
-
return self.detach();
|
258
|
-
})
|
259
|
-
.define_method(
|
260
|
-
"detach!",
|
261
|
-
*[](Tensor& self) {
|
262
|
-
return self.detach_();
|
263
|
-
})
|
264
154
|
.define_method(
|
265
155
|
"_requires_grad!",
|
266
156
|
*[](Tensor& self, bool requires_grad) {
|
@@ -300,66 +190,6 @@ void Init_ext()
|
|
300
190
|
s << self.device();
|
301
191
|
return s.str();
|
302
192
|
})
|
303
|
-
.define_method(
|
304
|
-
"resize_as!",
|
305
|
-
*[](Tensor& self, Tensor& other) {
|
306
|
-
return self.resize_as_(other);
|
307
|
-
})
|
308
|
-
.define_method(
|
309
|
-
"fill!",
|
310
|
-
*[](Tensor& self, Scalar value) {
|
311
|
-
return self.fill_(value);
|
312
|
-
})
|
313
|
-
.define_method(
|
314
|
-
"relu!",
|
315
|
-
*[](Tensor& self) {
|
316
|
-
return self.relu_();
|
317
|
-
})
|
318
|
-
.define_method(
|
319
|
-
"normal!",
|
320
|
-
*[](Tensor& self, double mean, double std) {
|
321
|
-
return self.normal_(mean, std);
|
322
|
-
})
|
323
|
-
.define_method(
|
324
|
-
"random!",
|
325
|
-
*[](Tensor& self, int64_t to) {
|
326
|
-
return self.random_(to);
|
327
|
-
})
|
328
|
-
.define_method(
|
329
|
-
"sub!",
|
330
|
-
*[](Tensor& self, Tensor& other) {
|
331
|
-
return self.sub_(other);
|
332
|
-
})
|
333
|
-
.define_method(
|
334
|
-
"div!",
|
335
|
-
*[](Tensor& self, Tensor& other) {
|
336
|
-
return self.div_(other);
|
337
|
-
})
|
338
|
-
.define_method(
|
339
|
-
"sqrt!",
|
340
|
-
*[](Tensor& self) {
|
341
|
-
return self.sqrt_();
|
342
|
-
})
|
343
|
-
.define_method(
|
344
|
-
"unsqueeze!",
|
345
|
-
*[](Tensor& self, int64_t dim) {
|
346
|
-
return self.unsqueeze_(dim);
|
347
|
-
})
|
348
|
-
.define_method(
|
349
|
-
"copy!",
|
350
|
-
*[](Tensor& self, Tensor& src) {
|
351
|
-
return self.copy_(src);
|
352
|
-
})
|
353
|
-
.define_method(
|
354
|
-
"clone",
|
355
|
-
*[](Tensor& self) {
|
356
|
-
return self.clone();
|
357
|
-
})
|
358
|
-
.define_method(
|
359
|
-
"data",
|
360
|
-
*[](Tensor& self) {
|
361
|
-
return self.data();
|
362
|
-
})
|
363
193
|
.define_method(
|
364
194
|
"_flat_data",
|
365
195
|
*[](Tensor& self) {
|
data/ext/torch/nn_functions.cpp
CHANGED
@@ -30,22 +30,42 @@ void add_nn_functions(Module m) {
|
|
30
30
|
.define_singleton_method(
|
31
31
|
"_adaptive_max_pool2d",
|
32
32
|
*[](const Tensor &self, IntArrayRef output_size) {
|
33
|
-
return torch::adaptive_max_pool2d(self, output_size);
|
33
|
+
return wrap(torch::adaptive_max_pool2d(self, output_size));
|
34
34
|
})
|
35
35
|
.define_singleton_method(
|
36
36
|
"_adaptive_max_pool2d_out",
|
37
37
|
*[](const Tensor &self, IntArrayRef output_size, Tensor &out, Tensor &indices) {
|
38
|
-
return torch::adaptive_max_pool2d_out(out, indices, self, output_size);
|
38
|
+
return wrap(torch::adaptive_max_pool2d_out(out, indices, self, output_size));
|
39
39
|
})
|
40
40
|
.define_singleton_method(
|
41
41
|
"_adaptive_max_pool3d",
|
42
42
|
*[](const Tensor &self, IntArrayRef output_size) {
|
43
|
-
return torch::adaptive_max_pool3d(self, output_size);
|
43
|
+
return wrap(torch::adaptive_max_pool3d(self, output_size));
|
44
44
|
})
|
45
45
|
.define_singleton_method(
|
46
46
|
"_adaptive_max_pool3d_out",
|
47
47
|
*[](const Tensor &self, IntArrayRef output_size, Tensor &out, Tensor &indices) {
|
48
|
-
return torch::adaptive_max_pool3d_out(out, indices, self, output_size);
|
48
|
+
return wrap(torch::adaptive_max_pool3d_out(out, indices, self, output_size));
|
49
|
+
})
|
50
|
+
.define_singleton_method(
|
51
|
+
"_avg_pool2d",
|
52
|
+
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad) {
|
53
|
+
return torch::avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad);
|
54
|
+
})
|
55
|
+
.define_singleton_method(
|
56
|
+
"_avg_pool2d_divisor_override",
|
57
|
+
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, int64_t divisor_override) {
|
58
|
+
return torch::avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override);
|
59
|
+
})
|
60
|
+
.define_singleton_method(
|
61
|
+
"_avg_pool3d",
|
62
|
+
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad) {
|
63
|
+
return torch::avg_pool3d(self, kernel_size, stride, padding, ceil_mode, count_include_pad);
|
64
|
+
})
|
65
|
+
.define_singleton_method(
|
66
|
+
"_avg_pool3d_divisor_override",
|
67
|
+
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, bool ceil_mode, bool count_include_pad, int64_t divisor_override) {
|
68
|
+
return torch::avg_pool3d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override);
|
49
69
|
})
|
50
70
|
.define_singleton_method(
|
51
71
|
"_binary_cross_entropy",
|
@@ -85,22 +105,22 @@ void add_nn_functions(Module m) {
|
|
85
105
|
.define_singleton_method(
|
86
106
|
"_fractional_max_pool2d",
|
87
107
|
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor &random_samples) {
|
88
|
-
return torch::fractional_max_pool2d(self, kernel_size, output_size, random_samples);
|
108
|
+
return wrap(torch::fractional_max_pool2d(self, kernel_size, output_size, random_samples));
|
89
109
|
})
|
90
110
|
.define_singleton_method(
|
91
111
|
"_fractional_max_pool2d_output",
|
92
112
|
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor &random_samples, Tensor &output, Tensor &indices) {
|
93
|
-
return torch::fractional_max_pool2d_out(output, indices, self, kernel_size, output_size, random_samples);
|
113
|
+
return wrap(torch::fractional_max_pool2d_out(output, indices, self, kernel_size, output_size, random_samples));
|
94
114
|
})
|
95
115
|
.define_singleton_method(
|
96
116
|
"_fractional_max_pool3d",
|
97
117
|
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor &random_samples) {
|
98
|
-
return torch::fractional_max_pool3d(self, kernel_size, output_size, random_samples);
|
118
|
+
return wrap(torch::fractional_max_pool3d(self, kernel_size, output_size, random_samples));
|
99
119
|
})
|
100
120
|
.define_singleton_method(
|
101
121
|
"_fractional_max_pool3d_output",
|
102
122
|
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef output_size, const Tensor &random_samples, Tensor &output, Tensor &indices) {
|
103
|
-
return torch::fractional_max_pool3d_out(output, indices, self, kernel_size, output_size, random_samples);
|
123
|
+
return wrap(torch::fractional_max_pool3d_out(output, indices, self, kernel_size, output_size, random_samples));
|
104
124
|
})
|
105
125
|
.define_singleton_method(
|
106
126
|
"_gelu",
|
@@ -180,12 +200,12 @@ void add_nn_functions(Module m) {
|
|
180
200
|
.define_singleton_method(
|
181
201
|
"_log_sigmoid_forward",
|
182
202
|
*[](const Tensor &self) {
|
183
|
-
return torch::log_sigmoid_forward(self);
|
203
|
+
return wrap(torch::log_sigmoid_forward(self));
|
184
204
|
})
|
185
205
|
.define_singleton_method(
|
186
206
|
"_log_sigmoid_forward_output",
|
187
207
|
*[](const Tensor &self, Tensor &output, Tensor &buffer) {
|
188
|
-
return torch::log_sigmoid_forward_out(output, buffer, self);
|
208
|
+
return wrap(torch::log_sigmoid_forward_out(output, buffer, self));
|
189
209
|
})
|
190
210
|
.define_singleton_method(
|
191
211
|
"_log_sigmoid_out",
|
@@ -195,22 +215,22 @@ void add_nn_functions(Module m) {
|
|
195
215
|
.define_singleton_method(
|
196
216
|
"_max_pool2d_with_indices",
|
197
217
|
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
|
198
|
-
return torch::max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode);
|
218
|
+
return wrap(torch::max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode));
|
199
219
|
})
|
200
220
|
.define_singleton_method(
|
201
221
|
"_max_pool2d_with_indices_out",
|
202
222
|
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor &out, Tensor &indices) {
|
203
|
-
return torch::max_pool2d_with_indices_out(out, indices, self, kernel_size, stride, padding, dilation, ceil_mode);
|
223
|
+
return wrap(torch::max_pool2d_with_indices_out(out, indices, self, kernel_size, stride, padding, dilation, ceil_mode));
|
204
224
|
})
|
205
225
|
.define_singleton_method(
|
206
226
|
"_max_pool3d_with_indices",
|
207
227
|
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
|
208
|
-
return torch::max_pool3d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode);
|
228
|
+
return wrap(torch::max_pool3d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode));
|
209
229
|
})
|
210
230
|
.define_singleton_method(
|
211
231
|
"_max_pool3d_with_indices_out",
|
212
232
|
*[](const Tensor &self, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, Tensor &out, Tensor &indices) {
|
213
|
-
return torch::max_pool3d_with_indices_out(out, indices, self, kernel_size, stride, padding, dilation, ceil_mode);
|
233
|
+
return wrap(torch::max_pool3d_with_indices_out(out, indices, self, kernel_size, stride, padding, dilation, ceil_mode));
|
214
234
|
})
|
215
235
|
.define_singleton_method(
|
216
236
|
"_max_unpool2d",
|
@@ -270,12 +290,12 @@ void add_nn_functions(Module m) {
|
|
270
290
|
.define_singleton_method(
|
271
291
|
"_multilabel_margin_loss_forward",
|
272
292
|
*[](const Tensor &self, const Tensor &target, MyReduction reduction) {
|
273
|
-
return torch::multilabel_margin_loss_forward(self, target, reduction);
|
293
|
+
return wrap(torch::multilabel_margin_loss_forward(self, target, reduction));
|
274
294
|
})
|
275
295
|
.define_singleton_method(
|
276
296
|
"_multilabel_margin_loss_forward_output",
|
277
297
|
*[](const Tensor &self, const Tensor &target, MyReduction reduction, Tensor &output, Tensor &is_target) {
|
278
|
-
return torch::multilabel_margin_loss_forward_out(output, is_target, self, target, reduction);
|
298
|
+
return wrap(torch::multilabel_margin_loss_forward_out(output, is_target, self, target, reduction));
|
279
299
|
})
|
280
300
|
.define_singleton_method(
|
281
301
|
"_multilabel_margin_loss_out",
|
@@ -295,12 +315,12 @@ void add_nn_functions(Module m) {
|
|
295
315
|
.define_singleton_method(
|
296
316
|
"_nll_loss2d_forward",
|
297
317
|
*[](const Tensor &self, const Tensor &target, OptionalTensor weight, MyReduction reduction, int64_t ignore_index) {
|
298
|
-
return torch::nll_loss2d_forward(self, target, weight, reduction, ignore_index);
|
318
|
+
return wrap(torch::nll_loss2d_forward(self, target, weight, reduction, ignore_index));
|
299
319
|
})
|
300
320
|
.define_singleton_method(
|
301
321
|
"_nll_loss2d_forward_output",
|
302
322
|
*[](const Tensor &self, const Tensor &target, OptionalTensor weight, MyReduction reduction, int64_t ignore_index, Tensor &output, Tensor &total_weight) {
|
303
|
-
return torch::nll_loss2d_forward_out(output, total_weight, self, target, weight, reduction, ignore_index);
|
323
|
+
return wrap(torch::nll_loss2d_forward_out(output, total_weight, self, target, weight, reduction, ignore_index));
|
304
324
|
})
|
305
325
|
.define_singleton_method(
|
306
326
|
"_nll_loss2d_out",
|
@@ -310,12 +330,12 @@ void add_nn_functions(Module m) {
|
|
310
330
|
.define_singleton_method(
|
311
331
|
"_nll_loss_forward",
|
312
332
|
*[](const Tensor &self, const Tensor &target, OptionalTensor weight, MyReduction reduction, int64_t ignore_index) {
|
313
|
-
return torch::nll_loss_forward(self, target, weight, reduction, ignore_index);
|
333
|
+
return wrap(torch::nll_loss_forward(self, target, weight, reduction, ignore_index));
|
314
334
|
})
|
315
335
|
.define_singleton_method(
|
316
336
|
"_nll_loss_forward_output",
|
317
337
|
*[](const Tensor &self, const Tensor &target, OptionalTensor weight, MyReduction reduction, int64_t ignore_index, Tensor &output, Tensor &total_weight) {
|
318
|
-
return torch::nll_loss_forward_out(output, total_weight, self, target, weight, reduction, ignore_index);
|
338
|
+
return wrap(torch::nll_loss_forward_out(output, total_weight, self, target, weight, reduction, ignore_index));
|
319
339
|
})
|
320
340
|
.define_singleton_method(
|
321
341
|
"_nll_loss_out",
|
@@ -470,12 +490,12 @@ void add_nn_functions(Module m) {
|
|
470
490
|
.define_singleton_method(
|
471
491
|
"_thnn_conv2d_forward",
|
472
492
|
*[](const Tensor &self, const Tensor &weight, IntArrayRef kernel_size, OptionalTensor bias, IntArrayRef stride, IntArrayRef padding) {
|
473
|
-
return torch::thnn_conv2d_forward(self, weight, kernel_size, bias, stride, padding);
|
493
|
+
return wrap(torch::thnn_conv2d_forward(self, weight, kernel_size, bias, stride, padding));
|
474
494
|
})
|
475
495
|
.define_singleton_method(
|
476
496
|
"_thnn_conv2d_forward_output",
|
477
497
|
*[](const Tensor &self, const Tensor &weight, IntArrayRef kernel_size, OptionalTensor bias, IntArrayRef stride, IntArrayRef padding, Tensor &output, Tensor &finput, Tensor &fgrad_input) {
|
478
|
-
return torch::thnn_conv2d_forward_out(output, finput, fgrad_input, self, weight, kernel_size, bias, stride, padding);
|
498
|
+
return wrap(torch::thnn_conv2d_forward_out(output, finput, fgrad_input, self, weight, kernel_size, bias, stride, padding));
|
479
499
|
})
|
480
500
|
.define_singleton_method(
|
481
501
|
"_thnn_conv2d_out",
|
@@ -490,12 +510,12 @@ void add_nn_functions(Module m) {
|
|
490
510
|
.define_singleton_method(
|
491
511
|
"_thnn_conv3d_forward",
|
492
512
|
*[](const Tensor &self, const Tensor &weight, IntArrayRef kernel_size, OptionalTensor bias, IntArrayRef stride, IntArrayRef padding) {
|
493
|
-
return torch::thnn_conv3d_forward(self, weight, kernel_size, bias, stride, padding);
|
513
|
+
return wrap(torch::thnn_conv3d_forward(self, weight, kernel_size, bias, stride, padding));
|
494
514
|
})
|
495
515
|
.define_singleton_method(
|
496
516
|
"_thnn_conv3d_forward_output",
|
497
517
|
*[](const Tensor &self, const Tensor &weight, IntArrayRef kernel_size, OptionalTensor bias, IntArrayRef stride, IntArrayRef padding, Tensor &output, Tensor &finput, Tensor &fgrad_input) {
|
498
|
-
return torch::thnn_conv3d_forward_out(output, finput, fgrad_input, self, weight, kernel_size, bias, stride, padding);
|
518
|
+
return wrap(torch::thnn_conv3d_forward_out(output, finput, fgrad_input, self, weight, kernel_size, bias, stride, padding));
|
499
519
|
})
|
500
520
|
.define_singleton_method(
|
501
521
|
"_thnn_conv3d_out",
|
@@ -0,0 +1,55 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
#include <rice/Object.hpp>
|
3
|
+
#include "templates.hpp"
|
4
|
+
|
5
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
6
|
+
Array a;
|
7
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
8
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
9
|
+
return Object(a);
|
10
|
+
}
|
11
|
+
|
12
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
13
|
+
Array a;
|
14
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
15
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
16
|
+
a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
|
17
|
+
return Object(a);
|
18
|
+
}
|
19
|
+
|
20
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
21
|
+
Array a;
|
22
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
23
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
24
|
+
a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
|
25
|
+
a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
|
26
|
+
return Object(a);
|
27
|
+
}
|
28
|
+
|
29
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
30
|
+
Array a;
|
31
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
32
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
33
|
+
a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
|
34
|
+
a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
|
35
|
+
a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
|
36
|
+
return Object(a);
|
37
|
+
}
|
38
|
+
|
39
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
40
|
+
Array a;
|
41
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
42
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
43
|
+
a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
|
44
|
+
a.push(to_ruby<int64_t>(std::get<3>(x)));
|
45
|
+
return Object(a);
|
46
|
+
}
|
47
|
+
|
48
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
49
|
+
Array a;
|
50
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
51
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
52
|
+
a.push(to_ruby<double>(std::get<2>(x)));
|
53
|
+
a.push(to_ruby<int64_t>(std::get<3>(x)));
|
54
|
+
return Object(a);
|
55
|
+
}
|