torch-rb 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +28 -0
- data/LICENSE.txt +46 -0
- data/README.md +426 -0
- data/ext/torch/ext.cpp +839 -0
- data/ext/torch/extconf.rb +25 -0
- data/lib/torch-rb.rb +1 -0
- data/lib/torch.rb +422 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +85 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/conv2d.rb +37 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +100 -0
- data/lib/torch/nn/init.rb +30 -0
- data/lib/torch/nn/linear.rb +36 -0
- data/lib/torch/nn/module.rb +85 -0
- data/lib/torch/nn/mse_loss.rb +13 -0
- data/lib/torch/nn/parameter.rb +14 -0
- data/lib/torch/nn/relu.rb +13 -0
- data/lib/torch/nn/sequential.rb +29 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/tensor.rb +196 -0
- data/lib/torch/utils/data/data_loader.rb +27 -0
- data/lib/torch/utils/data/tensor_dataset.rb +22 -0
- data/lib/torch/version.rb +3 -0
- metadata +169 -0
data/ext/torch/ext.cpp
ADDED
@@ -0,0 +1,839 @@
|
|
1
|
+
#include <sstream>
|
2
|
+
|
3
|
+
#include <torch/torch.h>
|
4
|
+
|
5
|
+
#include <rice/Array.hpp>
|
6
|
+
#include <rice/Class.hpp>
|
7
|
+
#include <rice/Constructor.hpp>
|
8
|
+
|
9
|
+
using namespace Rice;
|
10
|
+
|
11
|
+
template<>
|
12
|
+
inline
|
13
|
+
long long from_ruby<long long>(Object x)
|
14
|
+
{
|
15
|
+
return NUM2LL(x);
|
16
|
+
}
|
17
|
+
|
18
|
+
template<>
|
19
|
+
inline
|
20
|
+
Object to_ruby<long long>(long long const & x)
|
21
|
+
{
|
22
|
+
return LL2NUM(x);
|
23
|
+
}
|
24
|
+
|
25
|
+
template<>
|
26
|
+
inline
|
27
|
+
unsigned long long from_ruby<unsigned long long>(Object x)
|
28
|
+
{
|
29
|
+
return NUM2ULL(x);
|
30
|
+
}
|
31
|
+
|
32
|
+
template<>
|
33
|
+
inline
|
34
|
+
Object to_ruby<unsigned long long>(unsigned long long const & x)
|
35
|
+
{
|
36
|
+
return ULL2NUM(x);
|
37
|
+
}
|
38
|
+
|
39
|
+
template<>
|
40
|
+
inline
|
41
|
+
short from_ruby<short>(Object x)
|
42
|
+
{
|
43
|
+
return NUM2SHORT(x);
|
44
|
+
}
|
45
|
+
|
46
|
+
template<>
|
47
|
+
inline
|
48
|
+
Object to_ruby<short>(short const & x)
|
49
|
+
{
|
50
|
+
return INT2NUM(x);
|
51
|
+
}
|
52
|
+
|
53
|
+
template<>
|
54
|
+
inline
|
55
|
+
unsigned short from_ruby<unsigned short>(Object x)
|
56
|
+
{
|
57
|
+
return NUM2USHORT(x);
|
58
|
+
}
|
59
|
+
|
60
|
+
template<>
|
61
|
+
inline
|
62
|
+
Object to_ruby<unsigned short>(unsigned short const & x)
|
63
|
+
{
|
64
|
+
return UINT2NUM(x);
|
65
|
+
}
|
66
|
+
|
67
|
+
// need to wrap torch::IntArrayRef() since
|
68
|
+
// it doesn't own underlying data
|
69
|
+
class IntArrayRef {
|
70
|
+
std::vector<int64_t> vec;
|
71
|
+
public:
|
72
|
+
IntArrayRef(Object o) {
|
73
|
+
Array a = Array(o);
|
74
|
+
for (size_t i = 0; i < a.size(); i++) {
|
75
|
+
vec.push_back(from_ruby<int64_t>(a[i]));
|
76
|
+
}
|
77
|
+
}
|
78
|
+
operator torch::IntArrayRef() {
|
79
|
+
return torch::IntArrayRef(vec);
|
80
|
+
}
|
81
|
+
};
|
82
|
+
|
83
|
+
template<>
|
84
|
+
inline
|
85
|
+
IntArrayRef from_ruby<IntArrayRef>(Object x)
|
86
|
+
{
|
87
|
+
return IntArrayRef(x);
|
88
|
+
}
|
89
|
+
|
90
|
+
// for now
|
91
|
+
class Scalar {
|
92
|
+
torch::Scalar value;
|
93
|
+
public:
|
94
|
+
Scalar(Object o) {
|
95
|
+
// TODO cast based on Ruby type
|
96
|
+
if (o.rb_type() == T_FIXNUM) {
|
97
|
+
value = torch::Scalar(from_ruby<int64_t>(o));
|
98
|
+
} else {
|
99
|
+
value = torch::Scalar(from_ruby<float>(o));
|
100
|
+
}
|
101
|
+
}
|
102
|
+
operator torch::Scalar() {
|
103
|
+
return value;
|
104
|
+
}
|
105
|
+
};
|
106
|
+
|
107
|
+
template<>
|
108
|
+
inline
|
109
|
+
Scalar from_ruby<Scalar>(Object x)
|
110
|
+
{
|
111
|
+
return Scalar(x);
|
112
|
+
}
|
113
|
+
|
114
|
+
class TensorList {
|
115
|
+
std::vector<torch::Tensor> vec;
|
116
|
+
public:
|
117
|
+
TensorList(Object o) {
|
118
|
+
Array a = Array(o);
|
119
|
+
for (size_t i = 0; i < a.size(); i++) {
|
120
|
+
vec.push_back(from_ruby<torch::Tensor>(a[i]));
|
121
|
+
}
|
122
|
+
}
|
123
|
+
operator torch::TensorList() {
|
124
|
+
return torch::TensorList(vec);
|
125
|
+
}
|
126
|
+
};
|
127
|
+
|
128
|
+
template<>
|
129
|
+
inline
|
130
|
+
TensorList from_ruby<TensorList>(Object x)
|
131
|
+
{
|
132
|
+
return TensorList(x);
|
133
|
+
}
|
134
|
+
|
135
|
+
extern "C"
|
136
|
+
void Init_ext()
|
137
|
+
{
|
138
|
+
Module rb_mTorch = define_module("Torch")
|
139
|
+
.define_singleton_method(
|
140
|
+
"grad_enabled?",
|
141
|
+
*[]() {
|
142
|
+
return torch::GradMode::is_enabled();
|
143
|
+
})
|
144
|
+
.define_singleton_method(
|
145
|
+
"_set_grad_enabled",
|
146
|
+
*[](bool enabled) {
|
147
|
+
torch::GradMode::set_enabled(enabled);
|
148
|
+
})
|
149
|
+
.define_singleton_method(
|
150
|
+
"floating_point?",
|
151
|
+
*[](torch::Tensor& input) {
|
152
|
+
return torch::is_floating_point(input);
|
153
|
+
})
|
154
|
+
.define_singleton_method(
|
155
|
+
"manual_seed",
|
156
|
+
*[](uint64_t seed) {
|
157
|
+
return torch::manual_seed(seed);
|
158
|
+
})
|
159
|
+
// begin tensor creation
|
160
|
+
.define_singleton_method(
|
161
|
+
"_arange",
|
162
|
+
*[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
|
163
|
+
return torch::arange(start, end, step, options);
|
164
|
+
})
|
165
|
+
.define_singleton_method(
|
166
|
+
"_empty",
|
167
|
+
*[](IntArrayRef size, const torch::TensorOptions &options) {
|
168
|
+
return torch::empty(size, options);
|
169
|
+
})
|
170
|
+
.define_singleton_method(
|
171
|
+
"_eye",
|
172
|
+
*[](int64_t m, int64_t n, const torch::TensorOptions &options) {
|
173
|
+
return torch::eye(m, n, options);
|
174
|
+
})
|
175
|
+
.define_singleton_method(
|
176
|
+
"_full",
|
177
|
+
*[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
|
178
|
+
return torch::full(size, fill_value, options);
|
179
|
+
})
|
180
|
+
.define_singleton_method(
|
181
|
+
"_linspace",
|
182
|
+
*[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
|
183
|
+
return torch::linspace(start, end, steps, options);
|
184
|
+
})
|
185
|
+
.define_singleton_method(
|
186
|
+
"_logspace",
|
187
|
+
*[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
|
188
|
+
return torch::logspace(start, end, steps, base, options);
|
189
|
+
})
|
190
|
+
.define_singleton_method(
|
191
|
+
"_ones",
|
192
|
+
*[](IntArrayRef size, const torch::TensorOptions &options) {
|
193
|
+
return torch::ones(size, options);
|
194
|
+
})
|
195
|
+
.define_singleton_method(
|
196
|
+
"_rand",
|
197
|
+
*[](IntArrayRef size, const torch::TensorOptions &options) {
|
198
|
+
return torch::rand(size, options);
|
199
|
+
})
|
200
|
+
.define_singleton_method(
|
201
|
+
"_randint",
|
202
|
+
*[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
|
203
|
+
return torch::randint(low, high, size, options);
|
204
|
+
})
|
205
|
+
.define_singleton_method(
|
206
|
+
"_randn",
|
207
|
+
*[](IntArrayRef size, const torch::TensorOptions &options) {
|
208
|
+
return torch::randn(size, options);
|
209
|
+
})
|
210
|
+
.define_singleton_method(
|
211
|
+
"_randperm",
|
212
|
+
*[](int64_t n, const torch::TensorOptions &options) {
|
213
|
+
return torch::randperm(n, options);
|
214
|
+
})
|
215
|
+
.define_singleton_method(
|
216
|
+
"_zeros",
|
217
|
+
*[](IntArrayRef size, const torch::TensorOptions &options) {
|
218
|
+
return torch::zeros(size, options);
|
219
|
+
})
|
220
|
+
// begin operations
|
221
|
+
.define_singleton_method(
|
222
|
+
"_mean",
|
223
|
+
*[](torch::Tensor& input) {
|
224
|
+
return torch::mean(input);
|
225
|
+
})
|
226
|
+
.define_singleton_method(
|
227
|
+
"_mean_dim",
|
228
|
+
*[](torch::Tensor& input, int64_t dim, bool keepdim) {
|
229
|
+
return torch::mean(input, dim, keepdim);
|
230
|
+
})
|
231
|
+
.define_singleton_method(
|
232
|
+
"_sum",
|
233
|
+
*[](torch::Tensor& input) {
|
234
|
+
return torch::sum(input);
|
235
|
+
})
|
236
|
+
.define_singleton_method(
|
237
|
+
"_sum_dim",
|
238
|
+
*[](torch::Tensor& input, int64_t dim, bool keepdim) {
|
239
|
+
return torch::sum(input, dim, keepdim);
|
240
|
+
})
|
241
|
+
.define_singleton_method(
|
242
|
+
"_argmax",
|
243
|
+
*[](torch::Tensor& input) {
|
244
|
+
return torch::argmax(input);
|
245
|
+
})
|
246
|
+
.define_singleton_method(
|
247
|
+
"_argmax_dim",
|
248
|
+
*[](torch::Tensor& input, int64_t dim, bool keepdim) {
|
249
|
+
return torch::argmax(input, dim, keepdim);
|
250
|
+
})
|
251
|
+
.define_singleton_method(
|
252
|
+
"_cat",
|
253
|
+
*[](TensorList tensors, int64_t dim) {
|
254
|
+
return torch::cat(tensors, dim);
|
255
|
+
})
|
256
|
+
.define_singleton_method(
|
257
|
+
"_norm",
|
258
|
+
*[](torch::Tensor& input) {
|
259
|
+
return torch::norm(input);
|
260
|
+
})
|
261
|
+
.define_singleton_method(
|
262
|
+
"_min",
|
263
|
+
*[](torch::Tensor& input) {
|
264
|
+
return torch::min(input);
|
265
|
+
})
|
266
|
+
.define_singleton_method(
|
267
|
+
"_max",
|
268
|
+
*[](torch::Tensor& input) {
|
269
|
+
return torch::max(input);
|
270
|
+
})
|
271
|
+
.define_singleton_method(
|
272
|
+
"_max_out",
|
273
|
+
*[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
|
274
|
+
// TODO add return value
|
275
|
+
torch::_max_out(max, max_indices, input, dim, keepdim);
|
276
|
+
})
|
277
|
+
.define_singleton_method(
|
278
|
+
"_sqrt",
|
279
|
+
*[](torch::Tensor& input) {
|
280
|
+
return torch::sqrt(input);
|
281
|
+
})
|
282
|
+
.define_singleton_method(
|
283
|
+
"_exp",
|
284
|
+
*[](torch::Tensor& input) {
|
285
|
+
return torch::exp(input);
|
286
|
+
})
|
287
|
+
.define_singleton_method(
|
288
|
+
"_log",
|
289
|
+
*[](torch::Tensor& input) {
|
290
|
+
return torch::log(input);
|
291
|
+
})
|
292
|
+
.define_singleton_method(
|
293
|
+
"_sign",
|
294
|
+
*[](torch::Tensor& input) {
|
295
|
+
return torch::sign(input);
|
296
|
+
})
|
297
|
+
.define_singleton_method(
|
298
|
+
"_unsqueeze",
|
299
|
+
*[](torch::Tensor& input, int64_t dim) {
|
300
|
+
return torch::unsqueeze(input, dim);
|
301
|
+
})
|
302
|
+
.define_singleton_method(
|
303
|
+
"_dot",
|
304
|
+
*[](torch::Tensor& input, torch::Tensor& tensor) {
|
305
|
+
return torch::dot(input, tensor);
|
306
|
+
})
|
307
|
+
.define_singleton_method(
|
308
|
+
"_matmul",
|
309
|
+
*[](torch::Tensor& input, torch::Tensor& other) {
|
310
|
+
return torch::matmul(input, other);
|
311
|
+
})
|
312
|
+
.define_singleton_method(
|
313
|
+
"_eq",
|
314
|
+
*[](torch::Tensor& input, torch::Tensor& other) {
|
315
|
+
return torch::eq(input, other);
|
316
|
+
})
|
317
|
+
.define_singleton_method(
|
318
|
+
"_gt",
|
319
|
+
// TODO support tensors
|
320
|
+
*[](torch::Tensor& input, Scalar other) {
|
321
|
+
return torch::gt(input, other);
|
322
|
+
})
|
323
|
+
.define_singleton_method(
|
324
|
+
"_lt",
|
325
|
+
// TODO support tensors
|
326
|
+
*[](torch::Tensor& input, Scalar other) {
|
327
|
+
return torch::lt(input, other);
|
328
|
+
})
|
329
|
+
.define_singleton_method(
|
330
|
+
"_add",
|
331
|
+
*[](torch::Tensor& input, torch::Tensor& other) {
|
332
|
+
return torch::add(input, other);
|
333
|
+
})
|
334
|
+
.define_singleton_method(
|
335
|
+
"_add_scalar",
|
336
|
+
*[](torch::Tensor& input, Scalar other) {
|
337
|
+
return torch::add(input, other);
|
338
|
+
})
|
339
|
+
.define_singleton_method(
|
340
|
+
"_add_out",
|
341
|
+
*[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
|
342
|
+
return torch::add_out(out, input, other);
|
343
|
+
})
|
344
|
+
.define_singleton_method(
|
345
|
+
"_sub",
|
346
|
+
*[](torch::Tensor& input, torch::Tensor& other) {
|
347
|
+
return torch::sub(input, other);
|
348
|
+
})
|
349
|
+
.define_singleton_method(
|
350
|
+
"_sub_scalar",
|
351
|
+
*[](torch::Tensor& input, Scalar other) {
|
352
|
+
return torch::sub(input, other);
|
353
|
+
})
|
354
|
+
.define_singleton_method(
|
355
|
+
"_mul",
|
356
|
+
*[](torch::Tensor& input, torch::Tensor& other) {
|
357
|
+
return torch::mul(input, other);
|
358
|
+
})
|
359
|
+
.define_singleton_method(
|
360
|
+
"_mul_scalar",
|
361
|
+
*[](torch::Tensor& input, Scalar other) {
|
362
|
+
return torch::mul(input, other);
|
363
|
+
})
|
364
|
+
.define_singleton_method(
|
365
|
+
"_div",
|
366
|
+
*[](torch::Tensor& input, torch::Tensor& other) {
|
367
|
+
return torch::div(input, other);
|
368
|
+
})
|
369
|
+
.define_singleton_method(
|
370
|
+
"_div_scalar",
|
371
|
+
*[](torch::Tensor& input, Scalar other) {
|
372
|
+
return torch::div(input, other);
|
373
|
+
})
|
374
|
+
.define_singleton_method(
|
375
|
+
"_remainder",
|
376
|
+
*[](torch::Tensor& input, torch::Tensor& other) {
|
377
|
+
return torch::remainder(input, other);
|
378
|
+
})
|
379
|
+
.define_singleton_method(
|
380
|
+
"_remainder_scalar",
|
381
|
+
*[](torch::Tensor& input, Scalar other) {
|
382
|
+
return torch::remainder(input, other);
|
383
|
+
})
|
384
|
+
.define_singleton_method(
|
385
|
+
"_pow",
|
386
|
+
*[](torch::Tensor& input, Scalar exponent) {
|
387
|
+
return torch::pow(input, exponent);
|
388
|
+
})
|
389
|
+
.define_singleton_method(
|
390
|
+
"_abs",
|
391
|
+
*[](torch::Tensor& input) {
|
392
|
+
return torch::abs(input);
|
393
|
+
})
|
394
|
+
.define_singleton_method(
|
395
|
+
"_neg",
|
396
|
+
*[](torch::Tensor& input) {
|
397
|
+
return torch::neg(input);
|
398
|
+
})
|
399
|
+
.define_singleton_method(
|
400
|
+
"_reshape",
|
401
|
+
*[](torch::Tensor& input, IntArrayRef shape) {
|
402
|
+
return torch::reshape(input, shape);
|
403
|
+
})
|
404
|
+
.define_singleton_method(
|
405
|
+
"_flatten",
|
406
|
+
*[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
|
407
|
+
return torch::flatten(input, start_dim, end_dim);
|
408
|
+
})
|
409
|
+
.define_singleton_method(
|
410
|
+
"relu",
|
411
|
+
*[](torch::Tensor& input) {
|
412
|
+
return torch::relu(input);
|
413
|
+
})
|
414
|
+
.define_singleton_method(
|
415
|
+
"conv2d",
|
416
|
+
*[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
|
417
|
+
return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
|
418
|
+
})
|
419
|
+
.define_singleton_method(
|
420
|
+
"linear",
|
421
|
+
*[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
|
422
|
+
return torch::linear(input, weight, bias);
|
423
|
+
})
|
424
|
+
.define_singleton_method(
|
425
|
+
"max_pool2d",
|
426
|
+
*[](torch::Tensor& input, IntArrayRef kernel_size) {
|
427
|
+
return torch::max_pool2d(input, kernel_size);
|
428
|
+
})
|
429
|
+
.define_singleton_method(
|
430
|
+
"avg_pool2d",
|
431
|
+
*[](torch::Tensor& input, IntArrayRef kernel_size) {
|
432
|
+
return torch::avg_pool2d(input, kernel_size);
|
433
|
+
})
|
434
|
+
.define_singleton_method(
|
435
|
+
"_dropout",
|
436
|
+
*[](torch::Tensor& input, float p, bool train) {
|
437
|
+
return torch::dropout(input, p, train);
|
438
|
+
})
|
439
|
+
.define_singleton_method(
|
440
|
+
"_dropout!",
|
441
|
+
*[](torch::Tensor& input, float p, bool train) {
|
442
|
+
return torch::dropout_(input, p, train);
|
443
|
+
})
|
444
|
+
.define_singleton_method(
|
445
|
+
"_feature_dropout",
|
446
|
+
*[](torch::Tensor& input, float p, bool train) {
|
447
|
+
return torch::feature_dropout(input, p, train);
|
448
|
+
})
|
449
|
+
.define_singleton_method(
|
450
|
+
"_feature_dropout!",
|
451
|
+
*[](torch::Tensor& input, float p, bool train) {
|
452
|
+
return torch::feature_dropout_(input, p, train);
|
453
|
+
})
|
454
|
+
.define_singleton_method(
|
455
|
+
"_alpha_dropout",
|
456
|
+
*[](torch::Tensor& input, float p, bool train) {
|
457
|
+
return torch::alpha_dropout(input, p, train);
|
458
|
+
})
|
459
|
+
.define_singleton_method(
|
460
|
+
"_alpha_dropout!",
|
461
|
+
*[](torch::Tensor& input, float p, bool train) {
|
462
|
+
return torch::alpha_dropout_(input, p, train);
|
463
|
+
})
|
464
|
+
.define_singleton_method(
|
465
|
+
"_feature_alpha_dropout",
|
466
|
+
*[](torch::Tensor& input, float p, bool train) {
|
467
|
+
return torch::feature_alpha_dropout(input, p, train);
|
468
|
+
})
|
469
|
+
.define_singleton_method(
|
470
|
+
"_feature_alpha_dropout!",
|
471
|
+
*[](torch::Tensor& input, float p, bool train) {
|
472
|
+
return torch::feature_alpha_dropout_(input, p, train);
|
473
|
+
})
|
474
|
+
.define_singleton_method(
|
475
|
+
"_embedding",
|
476
|
+
// weight and indices are swapped from Python interface
|
477
|
+
*[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
|
478
|
+
return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
|
479
|
+
})
|
480
|
+
.define_singleton_method(
|
481
|
+
"mse_loss",
|
482
|
+
*[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
|
483
|
+
auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
|
484
|
+
return torch::mse_loss(input, target, red);
|
485
|
+
})
|
486
|
+
.define_singleton_method(
|
487
|
+
"nll_loss",
|
488
|
+
*[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
|
489
|
+
auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
|
490
|
+
return torch::nll_loss(input, target, {}, red);
|
491
|
+
})
|
492
|
+
.define_singleton_method("numel", &torch::numel)
|
493
|
+
.define_singleton_method(
|
494
|
+
"_from_blob",
|
495
|
+
*[](String s, IntArrayRef size, const torch::TensorOptions &options) {
|
496
|
+
void *data = const_cast<char *>(s.c_str());
|
497
|
+
return torch::from_blob(data, size, options);
|
498
|
+
})
|
499
|
+
.define_singleton_method(
|
500
|
+
"_tensor",
|
501
|
+
*[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
|
502
|
+
Array a = Array(o);
|
503
|
+
std::vector<float> vec;
|
504
|
+
for (size_t i = 0; i < a.size(); i++) {
|
505
|
+
vec.push_back(from_ruby<float>(a[i]));
|
506
|
+
}
|
507
|
+
return torch::tensor(vec, options).reshape(size);
|
508
|
+
});
|
509
|
+
|
510
|
+
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
|
511
|
+
.define_method("cuda?", &torch::Tensor::is_cuda)
|
512
|
+
.define_method("distributed?", &torch::Tensor::is_distributed)
|
513
|
+
.define_method("complex?", &torch::Tensor::is_complex)
|
514
|
+
.define_method("floating_point?", &torch::Tensor::is_floating_point)
|
515
|
+
.define_method("signed?", &torch::Tensor::is_signed)
|
516
|
+
.define_method("sparse?", &torch::Tensor::is_sparse)
|
517
|
+
.define_method("quantized?", &torch::Tensor::is_quantized)
|
518
|
+
.define_method("dim", &torch::Tensor::dim)
|
519
|
+
.define_method("element_size", &torch::Tensor::element_size)
|
520
|
+
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
521
|
+
.define_method("view_as", &torch::Tensor::view_as)
|
522
|
+
.define_method(
|
523
|
+
"addcmul!",
|
524
|
+
*[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
|
525
|
+
return self.addcmul_(tensor1, tensor2, value);
|
526
|
+
})
|
527
|
+
.define_method(
|
528
|
+
"addcdiv!",
|
529
|
+
*[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
|
530
|
+
return self.addcdiv_(tensor1, tensor2, value);
|
531
|
+
})
|
532
|
+
.define_method(
|
533
|
+
"zero!",
|
534
|
+
*[](torch::Tensor& self) {
|
535
|
+
return self.zero_();
|
536
|
+
})
|
537
|
+
.define_method(
|
538
|
+
"detach!",
|
539
|
+
*[](torch::Tensor& self) {
|
540
|
+
return self.detach_();
|
541
|
+
})
|
542
|
+
.define_method(
|
543
|
+
"_select",
|
544
|
+
*[](torch::Tensor& self, int64_t dim, int64_t index) {
|
545
|
+
return self.select(dim, index);
|
546
|
+
})
|
547
|
+
.define_method(
|
548
|
+
"_slice",
|
549
|
+
*[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
|
550
|
+
return self.slice(dim, start, end, step);
|
551
|
+
})
|
552
|
+
.define_method(
|
553
|
+
"_requires_grad!",
|
554
|
+
*[](torch::Tensor& self, bool requires_grad) {
|
555
|
+
return self.set_requires_grad(requires_grad);
|
556
|
+
})
|
557
|
+
.define_method(
|
558
|
+
"_backward",
|
559
|
+
*[](torch::Tensor& self) {
|
560
|
+
return self.backward();
|
561
|
+
})
|
562
|
+
.define_method(
|
563
|
+
"_backward_gradient",
|
564
|
+
*[](torch::Tensor& self, const torch::Tensor& gradient) {
|
565
|
+
return self.backward(gradient);
|
566
|
+
})
|
567
|
+
.define_method(
|
568
|
+
"grad",
|
569
|
+
*[](torch::Tensor& self) {
|
570
|
+
return self.grad();
|
571
|
+
})
|
572
|
+
.define_method(
|
573
|
+
"_dtype",
|
574
|
+
*[](torch::Tensor& self) {
|
575
|
+
return (int) at::typeMetaToScalarType(self.dtype());
|
576
|
+
})
|
577
|
+
.define_method(
|
578
|
+
"_type",
|
579
|
+
*[](torch::Tensor& self, int dtype) {
|
580
|
+
return self.toType((torch::ScalarType) dtype);
|
581
|
+
})
|
582
|
+
.define_method(
|
583
|
+
"_layout",
|
584
|
+
*[](torch::Tensor& self) {
|
585
|
+
std::stringstream s;
|
586
|
+
s << self.layout();
|
587
|
+
return s.str();
|
588
|
+
})
|
589
|
+
.define_method(
|
590
|
+
"device",
|
591
|
+
*[](torch::Tensor& self) {
|
592
|
+
std::stringstream s;
|
593
|
+
s << self.device();
|
594
|
+
return s.str();
|
595
|
+
})
|
596
|
+
.define_method(
|
597
|
+
"_view",
|
598
|
+
*[](torch::Tensor& self, IntArrayRef size) {
|
599
|
+
return self.view(size);
|
600
|
+
})
|
601
|
+
.define_method(
|
602
|
+
"resize_as!",
|
603
|
+
*[](torch::Tensor& self, torch::Tensor& other) {
|
604
|
+
return self.resize_as_(other);
|
605
|
+
})
|
606
|
+
.define_method(
|
607
|
+
"fill!",
|
608
|
+
*[](torch::Tensor& self, Scalar value) {
|
609
|
+
return self.fill_(value);
|
610
|
+
})
|
611
|
+
.define_method(
|
612
|
+
"_add!",
|
613
|
+
*[](torch::Tensor& self, torch::Tensor& other) {
|
614
|
+
return self.add_(other);
|
615
|
+
})
|
616
|
+
.define_method(
|
617
|
+
"_add_alpha!",
|
618
|
+
*[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
|
619
|
+
return self.add_(other, alpha);
|
620
|
+
})
|
621
|
+
.define_method(
|
622
|
+
"_add_scalar!",
|
623
|
+
*[](torch::Tensor& self, Scalar other) {
|
624
|
+
return self.add_(other);
|
625
|
+
})
|
626
|
+
.define_method(
|
627
|
+
"normal!",
|
628
|
+
*[](torch::Tensor& self, double mean, double std) {
|
629
|
+
return self.normal_(mean, std);
|
630
|
+
})
|
631
|
+
.define_method(
|
632
|
+
"sub!",
|
633
|
+
*[](torch::Tensor& self, torch::Tensor& other) {
|
634
|
+
return self.sub_(other);
|
635
|
+
})
|
636
|
+
.define_method(
|
637
|
+
"_mul!",
|
638
|
+
*[](torch::Tensor& self, torch::Tensor& other) {
|
639
|
+
return self.mul_(other);
|
640
|
+
})
|
641
|
+
.define_method(
|
642
|
+
"_mul_scalar!",
|
643
|
+
*[](torch::Tensor& self, Scalar other) {
|
644
|
+
return self.mul_(other);
|
645
|
+
})
|
646
|
+
.define_method(
|
647
|
+
"div!",
|
648
|
+
*[](torch::Tensor& self, torch::Tensor& other) {
|
649
|
+
return self.div_(other);
|
650
|
+
})
|
651
|
+
.define_method(
|
652
|
+
"sqrt!",
|
653
|
+
*[](torch::Tensor& self) {
|
654
|
+
return self.sqrt_();
|
655
|
+
})
|
656
|
+
.define_method(
|
657
|
+
"unsqueeze!",
|
658
|
+
*[](torch::Tensor& self, int64_t dim) {
|
659
|
+
return self.unsqueeze_(dim);
|
660
|
+
})
|
661
|
+
.define_method(
|
662
|
+
"copy!",
|
663
|
+
*[](torch::Tensor& self, torch::Tensor& src) {
|
664
|
+
return self.copy_(src);
|
665
|
+
})
|
666
|
+
.define_method(
|
667
|
+
"clone",
|
668
|
+
*[](torch::Tensor& self) {
|
669
|
+
return self.clone();
|
670
|
+
})
|
671
|
+
.define_method(
|
672
|
+
"log_softmax",
|
673
|
+
*[](torch::Tensor& self, int64_t dim) {
|
674
|
+
return self.log_softmax(dim);
|
675
|
+
})
|
676
|
+
.define_method(
|
677
|
+
"data",
|
678
|
+
*[](torch::Tensor& self) {
|
679
|
+
return self.data();
|
680
|
+
})
|
681
|
+
.define_method(
|
682
|
+
"_data",
|
683
|
+
*[](torch::Tensor& self) {
|
684
|
+
Array a;
|
685
|
+
auto dtype = self.dtype();
|
686
|
+
|
687
|
+
// TODO DRY if someone knows C++
|
688
|
+
if (dtype == torch::kByte) {
|
689
|
+
uint8_t* data = self.data_ptr<uint8_t>();
|
690
|
+
for (int i = 0; i < self.numel(); i++) {
|
691
|
+
a.push(data[i]);
|
692
|
+
}
|
693
|
+
} else if (dtype == torch::kChar) {
|
694
|
+
int8_t* data = self.data_ptr<int8_t>();
|
695
|
+
for (int i = 0; i < self.numel(); i++) {
|
696
|
+
a.push(to_ruby<int>(data[i]));
|
697
|
+
}
|
698
|
+
} else if (dtype == torch::kShort) {
|
699
|
+
int16_t* data = self.data_ptr<int16_t>();
|
700
|
+
for (int i = 0; i < self.numel(); i++) {
|
701
|
+
a.push(data[i]);
|
702
|
+
}
|
703
|
+
} else if (dtype == torch::kInt) {
|
704
|
+
int32_t* data = self.data_ptr<int32_t>();
|
705
|
+
for (int i = 0; i < self.numel(); i++) {
|
706
|
+
a.push(data[i]);
|
707
|
+
}
|
708
|
+
} else if (dtype == torch::kLong) {
|
709
|
+
int64_t* data = self.data_ptr<int64_t>();
|
710
|
+
for (int i = 0; i < self.numel(); i++) {
|
711
|
+
a.push(data[i]);
|
712
|
+
}
|
713
|
+
} else if (dtype == torch::kFloat) {
|
714
|
+
float* data = self.data_ptr<float>();
|
715
|
+
for (int i = 0; i < self.numel(); i++) {
|
716
|
+
a.push(data[i]);
|
717
|
+
}
|
718
|
+
} else if (dtype == torch::kDouble) {
|
719
|
+
double* data = self.data_ptr<double>();
|
720
|
+
for (int i = 0; i < self.numel(); i++) {
|
721
|
+
a.push(data[i]);
|
722
|
+
}
|
723
|
+
} else if (dtype == torch::kBool) {
|
724
|
+
bool* data = self.data_ptr<bool>();
|
725
|
+
for (int i = 0; i < self.numel(); i++) {
|
726
|
+
a.push(data[i] ? True : False);
|
727
|
+
}
|
728
|
+
} else {
|
729
|
+
throw std::runtime_error("Unsupported type");
|
730
|
+
}
|
731
|
+
return a;
|
732
|
+
})
|
733
|
+
.define_method(
|
734
|
+
"_size",
|
735
|
+
*[](torch::Tensor& self, int i) {
|
736
|
+
return self.size(i);
|
737
|
+
})
|
738
|
+
.define_method(
|
739
|
+
"_to",
|
740
|
+
*[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
741
|
+
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
742
|
+
})
|
743
|
+
.define_singleton_method(
|
744
|
+
"_make_subclass",
|
745
|
+
*[](torch::Tensor& rd, bool requires_grad) {
|
746
|
+
auto data = torch::autograd::as_variable_ref(rd).detach();
|
747
|
+
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
748
|
+
auto var = data.set_requires_grad(requires_grad);
|
749
|
+
return torch::autograd::Variable(std::move(var));
|
750
|
+
});
|
751
|
+
|
752
|
+
Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
|
753
|
+
.define_constructor(Constructor<torch::TensorOptions>())
|
754
|
+
.define_method(
|
755
|
+
"dtype",
|
756
|
+
*[](torch::TensorOptions& self, int dtype) {
|
757
|
+
return self.dtype((torch::ScalarType) dtype);
|
758
|
+
})
|
759
|
+
.define_method(
|
760
|
+
"layout",
|
761
|
+
*[](torch::TensorOptions& self, std::string layout) {
|
762
|
+
torch::Layout l;
|
763
|
+
if (layout == "strided") {
|
764
|
+
l = torch::kStrided;
|
765
|
+
} else if (layout == "sparse") {
|
766
|
+
l = torch::kSparse;
|
767
|
+
throw std::runtime_error("Sparse layout not supported yet");
|
768
|
+
} else {
|
769
|
+
throw std::runtime_error("Unsupported layout: " + layout);
|
770
|
+
}
|
771
|
+
return self.layout(l);
|
772
|
+
})
|
773
|
+
.define_method(
|
774
|
+
"device",
|
775
|
+
*[](torch::TensorOptions& self, std::string device) {
|
776
|
+
torch::DeviceType d;
|
777
|
+
if (device == "cpu") {
|
778
|
+
d = torch::kCPU;
|
779
|
+
} else if (device == "cuda") {
|
780
|
+
d = torch::kCUDA;
|
781
|
+
} else {
|
782
|
+
throw std::runtime_error("Unsupported device: " + device);
|
783
|
+
}
|
784
|
+
return self.device(d);
|
785
|
+
})
|
786
|
+
.define_method(
|
787
|
+
"requires_grad",
|
788
|
+
*[](torch::TensorOptions& self, bool requires_grad) {
|
789
|
+
return self.requires_grad(requires_grad);
|
790
|
+
});
|
791
|
+
|
792
|
+
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
793
|
+
|
794
|
+
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
795
|
+
.define_singleton_method(
|
796
|
+
"kaiming_uniform!",
|
797
|
+
*[](torch::Tensor& input, double a) {
|
798
|
+
return torch::nn::init::kaiming_uniform_(input, a);
|
799
|
+
})
|
800
|
+
.define_singleton_method(
|
801
|
+
"normal!",
|
802
|
+
*[](torch::Tensor& input) {
|
803
|
+
return torch::nn::init::normal_(input);
|
804
|
+
})
|
805
|
+
.define_singleton_method(
|
806
|
+
"uniform!",
|
807
|
+
*[](torch::Tensor& input, double to, double from) {
|
808
|
+
return torch::nn::init::uniform_(input, to, from);
|
809
|
+
});
|
810
|
+
|
811
|
+
Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
|
812
|
+
// TODO return grad or nil to remove need for 2nd function
|
813
|
+
.define_method(
|
814
|
+
"_grad",
|
815
|
+
*[](torch::autograd::Variable& self) {
|
816
|
+
return self.grad();
|
817
|
+
})
|
818
|
+
.define_method(
|
819
|
+
"_grad_defined",
|
820
|
+
*[](torch::autograd::Variable& self) {
|
821
|
+
return self.grad().defined();
|
822
|
+
});
|
823
|
+
|
824
|
+
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
825
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
826
|
+
.define_method("index", &torch::Device::index)
|
827
|
+
.define_method("index?", &torch::Device::has_index)
|
828
|
+
.define_method(
|
829
|
+
"type",
|
830
|
+
*[](torch::Device& self) {
|
831
|
+
std::stringstream s;
|
832
|
+
s << self.type();
|
833
|
+
return s.str();
|
834
|
+
});
|
835
|
+
|
836
|
+
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
|
837
|
+
.define_singleton_method("available?", &torch::cuda::is_available)
|
838
|
+
.define_singleton_method("device_count", &torch::cuda::device_count);
|
839
|
+
}
|