torch-rb 0.1.3 → 0.1.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +1 -0
- data/ext/torch/ext.cpp +375 -124
- data/lib/torch.rb +101 -20
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +23 -19
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +2 -2
- data/lib/torch/nn/convnd.rb +3 -3
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropoutnd.rb +2 -2
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/functional.rb +101 -13
- data/lib/torch/nn/identity.rb +13 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +120 -31
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +0 -4
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +28 -10
- data/lib/torch/version.rb +1 -1
- metadata +29 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 4faccffe2d2fd29519ad9dcce0560978a07c734831b5f64bb4624a0037f2b08c
|
4
|
+
data.tar.gz: 4a8f873a9bb99c2311c856c59e5c43a5dfadd3f4f2460da1370ca1db888b79ad
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 199b3b47325b72b38786f50c39f0cfb9b11709f02edf4b77e1c4e9198baf5fa2b3924d639d6c0f7e1715193528d4cfab4f53fbba7e2b16e06ac462d37862cf3a
|
7
|
+
data.tar.gz: 4514f7aab60d9beabee47db175c9df6e3e1e93080b47eaeacff0f9dd4e8e737b476c8400f84bda4ccdf61aca4f7c81e010bcfbc2fa67f8c629c33cd6dcdcb54c
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,12 @@
|
|
1
|
+
## 0.1.4 (2019-12-01)
|
2
|
+
|
3
|
+
- Added distance functions
|
4
|
+
- Added more activations
|
5
|
+
- Added more linear layers
|
6
|
+
- Added more loss functions
|
7
|
+
- Added more init methods
|
8
|
+
- Added support for tensor assignment
|
9
|
+
|
1
10
|
## 0.1.3 (2019-11-30)
|
2
11
|
|
3
12
|
- Changed to BSD 3-Clause license to match PyTorch
|
data/README.md
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -132,6 +132,112 @@ TensorList from_ruby<TensorList>(Object x)
|
|
132
132
|
return TensorList(x);
|
133
133
|
}
|
134
134
|
|
135
|
+
class FanModeType {
|
136
|
+
std::string s;
|
137
|
+
public:
|
138
|
+
FanModeType(Object o) {
|
139
|
+
s = String(o).str();
|
140
|
+
}
|
141
|
+
// TODO switch NonlinearityType after LibTorch 1.4 release
|
142
|
+
operator torch::nn::init::FanMode() {
|
143
|
+
if (s == "fan_in") {
|
144
|
+
return torch::nn::init::FanMode::FanIn;
|
145
|
+
} else if (s == "fan_out") {
|
146
|
+
return torch::nn::init::FanMode::FanOut;
|
147
|
+
} else {
|
148
|
+
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
149
|
+
}
|
150
|
+
}
|
151
|
+
};
|
152
|
+
|
153
|
+
template<>
|
154
|
+
inline
|
155
|
+
FanModeType from_ruby<FanModeType>(Object x)
|
156
|
+
{
|
157
|
+
return FanModeType(x);
|
158
|
+
}
|
159
|
+
|
160
|
+
class NonlinearityType {
|
161
|
+
std::string s;
|
162
|
+
public:
|
163
|
+
NonlinearityType(Object o) {
|
164
|
+
s = String(o).str();
|
165
|
+
}
|
166
|
+
// TODO switch NonlinearityType after LibTorch 1.4 release
|
167
|
+
operator torch::nn::init::Nonlinearity() {
|
168
|
+
if (s == "linear") {
|
169
|
+
return torch::nn::init::Nonlinearity::Linear;
|
170
|
+
} else if (s == "conv1d") {
|
171
|
+
return torch::nn::init::Nonlinearity::Conv1D;
|
172
|
+
} else if (s == "conv2d") {
|
173
|
+
return torch::nn::init::Nonlinearity::Conv2D;
|
174
|
+
} else if (s == "conv3d") {
|
175
|
+
return torch::nn::init::Nonlinearity::Conv3D;
|
176
|
+
} else if (s == "conv_transpose1d") {
|
177
|
+
return torch::nn::init::Nonlinearity::ConvTranspose1D;
|
178
|
+
} else if (s == "conv_transpose2d") {
|
179
|
+
return torch::nn::init::Nonlinearity::ConvTranspose2D;
|
180
|
+
} else if (s == "conv_transpose3d") {
|
181
|
+
return torch::nn::init::Nonlinearity::ConvTranspose3D;
|
182
|
+
} else if (s == "sigmoid") {
|
183
|
+
return torch::nn::init::Nonlinearity::Sigmoid;
|
184
|
+
} else if (s == "tanh") {
|
185
|
+
return torch::nn::init::Nonlinearity::Tanh;
|
186
|
+
} else if (s == "relu") {
|
187
|
+
return torch::nn::init::Nonlinearity::ReLU;
|
188
|
+
} else if (s == "leaky_relu") {
|
189
|
+
return torch::nn::init::Nonlinearity::LeakyReLU;
|
190
|
+
} else {
|
191
|
+
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
192
|
+
}
|
193
|
+
}
|
194
|
+
};
|
195
|
+
|
196
|
+
template<>
|
197
|
+
inline
|
198
|
+
NonlinearityType from_ruby<NonlinearityType>(Object x)
|
199
|
+
{
|
200
|
+
return NonlinearityType(x);
|
201
|
+
}
|
202
|
+
|
203
|
+
class MyReduction {
|
204
|
+
Object value;
|
205
|
+
public:
|
206
|
+
MyReduction(Object o) {
|
207
|
+
value = o;
|
208
|
+
}
|
209
|
+
operator int64_t() {
|
210
|
+
if (value.is_nil()) {
|
211
|
+
return Reduction::None;
|
212
|
+
}
|
213
|
+
|
214
|
+
std::string s = String(value).str();
|
215
|
+
if (s == "mean") {
|
216
|
+
return Reduction::Mean;
|
217
|
+
} else if (s == "sum") {
|
218
|
+
return Reduction::Sum;
|
219
|
+
} else {
|
220
|
+
throw std::runtime_error("Unsupported reduction: " + s);
|
221
|
+
}
|
222
|
+
}
|
223
|
+
};
|
224
|
+
|
225
|
+
template<>
|
226
|
+
inline
|
227
|
+
MyReduction from_ruby<MyReduction>(Object x)
|
228
|
+
{
|
229
|
+
return MyReduction(x);
|
230
|
+
}
|
231
|
+
|
232
|
+
// Short alias for torch::Tensor used in the binding lambdas.
using Tensor = torch::Tensor;
|
233
|
+
|
234
|
+
// Packs a pair of tensors into a two-element Ruby Array, in tuple order.
// Used for LibTorch functions that return two tensors, such as topk and
// _max_out.
// Taken by const reference so the two tensor handles are not copied on
// entry. Callers passing a std::tuple<Tensor&, Tensor&> still bind through
// the implicit tuple conversion.
Object tensor_array(const std::tuple<torch::Tensor, torch::Tensor>& x) {
  Array a;
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
  return Object(a);
}
|
240
|
+
|
135
241
|
extern "C"
|
136
242
|
void Init_ext()
|
137
243
|
{
|
@@ -148,7 +254,7 @@ void Init_ext()
|
|
148
254
|
})
|
149
255
|
.define_singleton_method(
|
150
256
|
"floating_point?",
|
151
|
-
*[](
|
257
|
+
*[](Tensor& input) {
|
152
258
|
return torch::is_floating_point(input);
|
153
259
|
})
|
154
260
|
.define_singleton_method(
|
@@ -220,32 +326,32 @@ void Init_ext()
|
|
220
326
|
// begin operations
|
221
327
|
.define_singleton_method(
|
222
328
|
"_mean",
|
223
|
-
*[](
|
329
|
+
*[](Tensor& input) {
|
224
330
|
return torch::mean(input);
|
225
331
|
})
|
226
332
|
.define_singleton_method(
|
227
333
|
"_mean_dim",
|
228
|
-
*[](
|
334
|
+
*[](Tensor& input, int64_t dim, bool keepdim) {
|
229
335
|
return torch::mean(input, dim, keepdim);
|
230
336
|
})
|
231
337
|
.define_singleton_method(
|
232
338
|
"_sum",
|
233
|
-
*[](
|
339
|
+
*[](Tensor& input) {
|
234
340
|
return torch::sum(input);
|
235
341
|
})
|
236
342
|
.define_singleton_method(
|
237
343
|
"_sum_dim",
|
238
|
-
*[](
|
344
|
+
*[](Tensor& input, int64_t dim, bool keepdim) {
|
239
345
|
return torch::sum(input, dim, keepdim);
|
240
346
|
})
|
241
347
|
.define_singleton_method(
|
242
348
|
"_argmax",
|
243
|
-
*[](
|
349
|
+
*[](Tensor& input) {
|
244
350
|
return torch::argmax(input);
|
245
351
|
})
|
246
352
|
.define_singleton_method(
|
247
353
|
"_argmax_dim",
|
248
|
-
*[](
|
354
|
+
*[](Tensor& input, int64_t dim, bool keepdim) {
|
249
355
|
return torch::argmax(input, dim, keepdim);
|
250
356
|
})
|
251
357
|
.define_singleton_method(
|
@@ -255,239 +361,322 @@ void Init_ext()
|
|
255
361
|
})
|
256
362
|
.define_singleton_method(
|
257
363
|
"_norm",
|
258
|
-
*[](
|
364
|
+
*[](Tensor& input) {
|
259
365
|
return torch::norm(input);
|
260
366
|
})
|
261
367
|
.define_singleton_method(
|
262
368
|
"_min",
|
263
|
-
*[](
|
369
|
+
*[](Tensor& input) {
|
264
370
|
return torch::min(input);
|
265
371
|
})
|
266
372
|
.define_singleton_method(
|
267
373
|
"_max",
|
268
|
-
*[](
|
374
|
+
*[](Tensor& input) {
|
269
375
|
return torch::max(input);
|
270
376
|
})
|
271
377
|
.define_singleton_method(
|
272
378
|
"_max_out",
|
273
|
-
*[](
|
274
|
-
|
275
|
-
torch::_max_out(max, max_indices, input, dim, keepdim);
|
379
|
+
*[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
|
380
|
+
return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
|
276
381
|
})
|
277
382
|
.define_singleton_method(
|
278
383
|
"_sqrt",
|
279
|
-
*[](
|
384
|
+
*[](Tensor& input) {
|
280
385
|
return torch::sqrt(input);
|
281
386
|
})
|
282
387
|
.define_singleton_method(
|
283
388
|
"_exp",
|
284
|
-
*[](
|
389
|
+
*[](Tensor& input) {
|
285
390
|
return torch::exp(input);
|
286
391
|
})
|
287
392
|
.define_singleton_method(
|
288
393
|
"_log",
|
289
|
-
*[](
|
394
|
+
*[](Tensor& input) {
|
290
395
|
return torch::log(input);
|
291
396
|
})
|
292
397
|
.define_singleton_method(
|
293
398
|
"_sign",
|
294
|
-
*[](
|
399
|
+
*[](Tensor& input) {
|
295
400
|
return torch::sign(input);
|
296
401
|
})
|
297
402
|
.define_singleton_method(
|
298
403
|
"_unsqueeze",
|
299
|
-
*[](
|
404
|
+
*[](Tensor& input, int64_t dim) {
|
300
405
|
return torch::unsqueeze(input, dim);
|
301
406
|
})
|
302
407
|
.define_singleton_method(
|
303
408
|
"_dot",
|
304
|
-
*[](
|
409
|
+
*[](Tensor& input, Tensor& tensor) {
|
305
410
|
return torch::dot(input, tensor);
|
306
411
|
})
|
307
412
|
.define_singleton_method(
|
308
413
|
"_matmul",
|
309
|
-
*[](
|
414
|
+
*[](Tensor& input, Tensor& other) {
|
310
415
|
return torch::matmul(input, other);
|
311
416
|
})
|
312
417
|
.define_singleton_method(
|
313
418
|
"_eq",
|
314
|
-
*[](
|
419
|
+
*[](Tensor& input, Tensor& other) {
|
315
420
|
return torch::eq(input, other);
|
316
421
|
})
|
317
422
|
.define_singleton_method(
|
318
423
|
"_gt",
|
319
424
|
// TODO support tensors
|
320
|
-
*[](
|
425
|
+
*[](Tensor& input, Scalar other) {
|
321
426
|
return torch::gt(input, other);
|
322
427
|
})
|
323
428
|
.define_singleton_method(
|
324
429
|
"_lt",
|
325
430
|
// TODO support tensors
|
326
|
-
*[](
|
431
|
+
*[](Tensor& input, Scalar other) {
|
327
432
|
return torch::lt(input, other);
|
328
433
|
})
|
329
434
|
.define_singleton_method(
|
330
435
|
"_add",
|
331
|
-
*[](
|
436
|
+
*[](Tensor& input, Tensor& other) {
|
332
437
|
return torch::add(input, other);
|
333
438
|
})
|
334
439
|
.define_singleton_method(
|
335
440
|
"_add_scalar",
|
336
|
-
*[](
|
441
|
+
*[](Tensor& input, Scalar other) {
|
337
442
|
return torch::add(input, other);
|
338
443
|
})
|
339
444
|
.define_singleton_method(
|
340
445
|
"_add_out",
|
341
|
-
*[](
|
446
|
+
*[](Tensor& out, Tensor& input, Tensor& other) {
|
342
447
|
return torch::add_out(out, input, other);
|
343
448
|
})
|
344
449
|
.define_singleton_method(
|
345
450
|
"_sub",
|
346
|
-
*[](
|
451
|
+
*[](Tensor& input, Tensor& other) {
|
347
452
|
return torch::sub(input, other);
|
348
453
|
})
|
349
454
|
.define_singleton_method(
|
350
455
|
"_sub_scalar",
|
351
|
-
*[](
|
456
|
+
*[](Tensor& input, Scalar other) {
|
352
457
|
return torch::sub(input, other);
|
353
458
|
})
|
354
459
|
.define_singleton_method(
|
355
460
|
"_mul",
|
356
|
-
*[](
|
461
|
+
*[](Tensor& input, Tensor& other) {
|
357
462
|
return torch::mul(input, other);
|
358
463
|
})
|
359
464
|
.define_singleton_method(
|
360
465
|
"_mul_scalar",
|
361
|
-
*[](
|
466
|
+
*[](Tensor& input, Scalar other) {
|
362
467
|
return torch::mul(input, other);
|
363
468
|
})
|
364
469
|
.define_singleton_method(
|
365
470
|
"_div",
|
366
|
-
*[](
|
471
|
+
*[](Tensor& input, Tensor& other) {
|
367
472
|
return torch::div(input, other);
|
368
473
|
})
|
369
474
|
.define_singleton_method(
|
370
475
|
"_div_scalar",
|
371
|
-
*[](
|
476
|
+
*[](Tensor& input, Scalar other) {
|
372
477
|
return torch::div(input, other);
|
373
478
|
})
|
374
479
|
.define_singleton_method(
|
375
480
|
"_remainder",
|
376
|
-
*[](
|
481
|
+
*[](Tensor& input, Tensor& other) {
|
377
482
|
return torch::remainder(input, other);
|
378
483
|
})
|
379
484
|
.define_singleton_method(
|
380
485
|
"_remainder_scalar",
|
381
|
-
*[](
|
486
|
+
*[](Tensor& input, Scalar other) {
|
382
487
|
return torch::remainder(input, other);
|
383
488
|
})
|
384
489
|
.define_singleton_method(
|
385
490
|
"_pow",
|
386
|
-
*[](
|
491
|
+
*[](Tensor& input, Scalar exponent) {
|
387
492
|
return torch::pow(input, exponent);
|
388
493
|
})
|
494
|
+
.define_singleton_method(
|
495
|
+
"_topk",
|
496
|
+
*[](Tensor& input, int64_t k) {
|
497
|
+
return tensor_array(torch::topk(input, k));
|
498
|
+
})
|
499
|
+
.define_singleton_method(
|
500
|
+
"_sigmoid",
|
501
|
+
*[](Tensor& input) {
|
502
|
+
return torch::sigmoid(input);
|
503
|
+
})
|
504
|
+
.define_singleton_method(
|
505
|
+
"_softplus",
|
506
|
+
*[](const Tensor &input, Scalar beta, Scalar threshold) {
|
507
|
+
return torch::softplus(input, beta, threshold);
|
508
|
+
})
|
509
|
+
.define_singleton_method(
|
510
|
+
"_softmax",
|
511
|
+
*[](const Tensor &input, int64_t dim) {
|
512
|
+
return torch::softmax(input, dim);
|
513
|
+
})
|
514
|
+
.define_singleton_method(
|
515
|
+
"_log_softmax",
|
516
|
+
*[](Tensor& input, int64_t dim) {
|
517
|
+
return torch::log_softmax(input, dim);
|
518
|
+
})
|
389
519
|
.define_singleton_method(
|
390
520
|
"_abs",
|
391
|
-
*[](
|
521
|
+
*[](Tensor& input) {
|
392
522
|
return torch::abs(input);
|
393
523
|
})
|
394
524
|
.define_singleton_method(
|
395
525
|
"_neg",
|
396
|
-
*[](
|
526
|
+
*[](Tensor& input) {
|
397
527
|
return torch::neg(input);
|
398
528
|
})
|
399
529
|
.define_singleton_method(
|
400
530
|
"_reshape",
|
401
|
-
*[](
|
531
|
+
*[](Tensor& input, IntArrayRef shape) {
|
402
532
|
return torch::reshape(input, shape);
|
403
533
|
})
|
404
534
|
.define_singleton_method(
|
405
535
|
"_flatten",
|
406
|
-
*[](
|
536
|
+
*[](Tensor& input, int64_t start_dim, int64_t end_dim) {
|
407
537
|
return torch::flatten(input, start_dim, end_dim);
|
408
538
|
})
|
409
539
|
.define_singleton_method(
|
410
540
|
"relu",
|
411
|
-
*[](
|
541
|
+
*[](Tensor& input) {
|
412
542
|
return torch::relu(input);
|
413
543
|
})
|
544
|
+
.define_singleton_method(
|
545
|
+
"prelu",
|
546
|
+
*[](torch::Tensor& input, torch::Tensor& weight) {
|
547
|
+
return torch::prelu(input, weight);
|
548
|
+
})
|
549
|
+
.define_singleton_method(
|
550
|
+
"leaky_relu",
|
551
|
+
*[](torch::Tensor& input, Scalar negative_slope) {
|
552
|
+
return torch::leaky_relu(input, negative_slope);
|
553
|
+
})
|
414
554
|
.define_singleton_method(
|
415
555
|
"conv2d",
|
416
|
-
*[](
|
556
|
+
*[](Tensor& input, Tensor& weight, Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
|
417
557
|
return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
|
418
558
|
})
|
559
|
+
// linear layers
|
560
|
+
.define_singleton_method(
|
561
|
+
"bilinear",
|
562
|
+
*[](const Tensor &input1, const Tensor &input2, const Tensor &weight, const Tensor &bias) {
|
563
|
+
return torch::bilinear(input1, input2, weight, bias);
|
564
|
+
})
|
419
565
|
.define_singleton_method(
|
420
566
|
"linear",
|
421
|
-
*[](
|
567
|
+
*[](Tensor& input, Tensor& weight, Tensor& bias) {
|
422
568
|
return torch::linear(input, weight, bias);
|
423
569
|
})
|
570
|
+
// pooling layers
|
424
571
|
.define_singleton_method(
|
425
572
|
"max_pool2d",
|
426
|
-
*[](
|
573
|
+
*[](Tensor& input, IntArrayRef kernel_size) {
|
427
574
|
return torch::max_pool2d(input, kernel_size);
|
428
575
|
})
|
429
576
|
.define_singleton_method(
|
430
577
|
"avg_pool2d",
|
431
|
-
*[](
|
578
|
+
*[](Tensor& input, IntArrayRef kernel_size) {
|
432
579
|
return torch::avg_pool2d(input, kernel_size);
|
433
580
|
})
|
434
581
|
.define_singleton_method(
|
435
582
|
"_dropout",
|
436
|
-
*[](
|
583
|
+
*[](Tensor& input, float p, bool train) {
|
437
584
|
return torch::dropout(input, p, train);
|
438
585
|
})
|
439
586
|
.define_singleton_method(
|
440
587
|
"_dropout!",
|
441
|
-
*[](
|
588
|
+
*[](Tensor& input, float p, bool train) {
|
442
589
|
return torch::dropout_(input, p, train);
|
443
590
|
})
|
444
591
|
.define_singleton_method(
|
445
592
|
"_feature_dropout",
|
446
|
-
*[](
|
593
|
+
*[](Tensor& input, float p, bool train) {
|
447
594
|
return torch::feature_dropout(input, p, train);
|
448
595
|
})
|
449
596
|
.define_singleton_method(
|
450
597
|
"_feature_dropout!",
|
451
|
-
*[](
|
598
|
+
*[](Tensor& input, float p, bool train) {
|
452
599
|
return torch::feature_dropout_(input, p, train);
|
453
600
|
})
|
454
601
|
.define_singleton_method(
|
455
602
|
"_alpha_dropout",
|
456
|
-
*[](
|
603
|
+
*[](Tensor& input, float p, bool train) {
|
457
604
|
return torch::alpha_dropout(input, p, train);
|
458
605
|
})
|
459
606
|
.define_singleton_method(
|
460
607
|
"_alpha_dropout!",
|
461
|
-
*[](
|
608
|
+
*[](Tensor& input, float p, bool train) {
|
462
609
|
return torch::alpha_dropout_(input, p, train);
|
463
610
|
})
|
464
611
|
.define_singleton_method(
|
465
612
|
"_feature_alpha_dropout",
|
466
|
-
*[](
|
613
|
+
*[](Tensor& input, float p, bool train) {
|
467
614
|
return torch::feature_alpha_dropout(input, p, train);
|
468
615
|
})
|
469
616
|
.define_singleton_method(
|
470
617
|
"_feature_alpha_dropout!",
|
471
|
-
*[](
|
618
|
+
*[](Tensor& input, float p, bool train) {
|
472
619
|
return torch::feature_alpha_dropout_(input, p, train);
|
473
620
|
})
|
621
|
+
// sparse layers
|
474
622
|
.define_singleton_method(
|
475
623
|
"_embedding",
|
476
624
|
// weight and indices are swapped from Python interface
|
477
|
-
*[](const
|
625
|
+
*[](const Tensor &indices, const Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
|
478
626
|
return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
|
479
627
|
})
|
628
|
+
.define_singleton_method(
|
629
|
+
"_embedding_bag",
|
630
|
+
// weight and indices are swapped from Python interface
|
631
|
+
*[](const Tensor &weight, const Tensor &indices, const Tensor &offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const Tensor &per_sample_weights) {
|
632
|
+
return torch::embedding_bag(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights);
|
633
|
+
})
|
634
|
+
// distance functions
|
635
|
+
.define_singleton_method(
|
636
|
+
"_cosine_similarity",
|
637
|
+
*[](const Tensor &x1, const Tensor &x2, int64_t dim, double eps) {
|
638
|
+
return torch::cosine_similarity(x1, x2, dim, eps);
|
639
|
+
})
|
640
|
+
.define_singleton_method(
|
641
|
+
"_pairwise_distance",
|
642
|
+
*[](const Tensor &x1, const Tensor &x2, double p, double eps, bool keepdim) {
|
643
|
+
return torch::pairwise_distance(x1, x2, p, eps, keepdim);
|
644
|
+
})
|
645
|
+
// loss functions
|
646
|
+
.define_singleton_method(
|
647
|
+
"binary_cross_entropy",
|
648
|
+
*[](Tensor& input, Tensor& target, MyReduction reduction) {
|
649
|
+
return torch::binary_cross_entropy(input, target, {}, reduction);
|
650
|
+
})
|
651
|
+
.define_singleton_method(
|
652
|
+
"ctc_loss",
|
653
|
+
*[](const Tensor &log_probs, const Tensor &targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, MyReduction reduction, bool zero_infinity) {
|
654
|
+
return torch::ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity);
|
655
|
+
})
|
656
|
+
.define_singleton_method(
|
657
|
+
"kl_div",
|
658
|
+
*[](Tensor& input, Tensor& target, MyReduction reduction) {
|
659
|
+
return torch::kl_div(input, target, reduction);
|
660
|
+
})
|
661
|
+
.define_singleton_method(
|
662
|
+
"l1_loss",
|
663
|
+
*[](Tensor& input, Tensor& target, MyReduction reduction) {
|
664
|
+
return torch::l1_loss(input, target, reduction);
|
665
|
+
})
|
480
666
|
.define_singleton_method(
|
481
667
|
"mse_loss",
|
482
|
-
*[](
|
483
|
-
|
484
|
-
return torch::mse_loss(input, target, red);
|
668
|
+
*[](Tensor& input, Tensor& target, MyReduction reduction) {
|
669
|
+
return torch::mse_loss(input, target, reduction);
|
485
670
|
})
|
486
671
|
.define_singleton_method(
|
487
672
|
"nll_loss",
|
488
|
-
*[](
|
489
|
-
|
490
|
-
|
673
|
+
*[](Tensor& input, Tensor& target, MyReduction reduction, int64_t ignore_index) {
|
674
|
+
return torch::nll_loss(input, target, {}, reduction, ignore_index);
|
675
|
+
})
|
676
|
+
.define_singleton_method(
|
677
|
+
"poisson_nll_loss",
|
678
|
+
*[](const Tensor &input, const Tensor &target, bool log_input, bool full, double eps, MyReduction reduction) {
|
679
|
+
return torch::poisson_nll_loss(input, target, log_input, full, eps, reduction);
|
491
680
|
})
|
492
681
|
.define_singleton_method("numel", &torch::numel)
|
493
682
|
.define_singleton_method(
|
@@ -500,11 +689,18 @@ void Init_ext()
|
|
500
689
|
"_tensor",
|
501
690
|
*[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
|
502
691
|
Array a = Array(o);
|
503
|
-
|
504
|
-
|
505
|
-
|
692
|
+
auto dtype = options.dtype();
|
693
|
+
torch::Tensor t;
|
694
|
+
if (dtype == torch::kBool) {
|
695
|
+
throw std::runtime_error("Cannot create bool from tensor method yet");
|
696
|
+
} else {
|
697
|
+
std::vector<float> vec;
|
698
|
+
for (size_t i = 0; i < a.size(); i++) {
|
699
|
+
vec.push_back(from_ruby<float>(a[i]));
|
700
|
+
}
|
701
|
+
t = torch::tensor(vec, options);
|
506
702
|
}
|
507
|
-
return
|
703
|
+
return t.reshape(size);
|
508
704
|
});
|
509
705
|
|
510
706
|
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
|
@@ -521,166 +717,171 @@ void Init_ext()
|
|
521
717
|
.define_method("view_as", &torch::Tensor::view_as)
|
522
718
|
.define_method(
|
523
719
|
"addcmul!",
|
524
|
-
*[](
|
720
|
+
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
525
721
|
return self.addcmul_(tensor1, tensor2, value);
|
526
722
|
})
|
527
723
|
.define_method(
|
528
724
|
"addcdiv!",
|
529
|
-
*[](
|
725
|
+
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
530
726
|
return self.addcdiv_(tensor1, tensor2, value);
|
531
727
|
})
|
532
728
|
.define_method(
|
533
729
|
"zero!",
|
534
|
-
*[](
|
730
|
+
*[](Tensor& self) {
|
535
731
|
return self.zero_();
|
536
732
|
})
|
733
|
+
.define_method(
|
734
|
+
"detach",
|
735
|
+
*[](Tensor& self) {
|
736
|
+
return self.detach();
|
737
|
+
})
|
537
738
|
.define_method(
|
538
739
|
"detach!",
|
539
|
-
*[](
|
740
|
+
*[](Tensor& self) {
|
540
741
|
return self.detach_();
|
541
742
|
})
|
542
743
|
.define_method(
|
543
744
|
"_select",
|
544
|
-
*[](
|
745
|
+
*[](Tensor& self, int64_t dim, int64_t index) {
|
545
746
|
return self.select(dim, index);
|
546
747
|
})
|
547
748
|
.define_method(
|
548
749
|
"_slice",
|
549
|
-
*[](
|
750
|
+
*[](Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
|
550
751
|
return self.slice(dim, start, end, step);
|
551
752
|
})
|
552
753
|
.define_method(
|
553
754
|
"_requires_grad!",
|
554
|
-
*[](
|
755
|
+
*[](Tensor& self, bool requires_grad) {
|
555
756
|
return self.set_requires_grad(requires_grad);
|
556
757
|
})
|
557
758
|
.define_method(
|
558
759
|
"_backward",
|
559
|
-
*[](
|
560
|
-
return self.backward();
|
561
|
-
})
|
562
|
-
.define_method(
|
563
|
-
"_backward_gradient",
|
564
|
-
*[](torch::Tensor& self, const torch::Tensor& gradient) {
|
565
|
-
return self.backward(gradient);
|
760
|
+
*[](Tensor& self, Object gradient) {
|
761
|
+
return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
|
566
762
|
})
|
567
763
|
.define_method(
|
568
764
|
"grad",
|
569
|
-
*[](
|
765
|
+
*[](Tensor& self) {
|
570
766
|
return self.grad();
|
571
767
|
})
|
572
768
|
.define_method(
|
573
769
|
"_dtype",
|
574
|
-
*[](
|
770
|
+
*[](Tensor& self) {
|
575
771
|
return (int) at::typeMetaToScalarType(self.dtype());
|
576
772
|
})
|
577
773
|
.define_method(
|
578
774
|
"_type",
|
579
|
-
*[](
|
775
|
+
*[](Tensor& self, int dtype) {
|
580
776
|
return self.toType((torch::ScalarType) dtype);
|
581
777
|
})
|
582
778
|
.define_method(
|
583
779
|
"_layout",
|
584
|
-
*[](
|
780
|
+
*[](Tensor& self) {
|
585
781
|
std::stringstream s;
|
586
782
|
s << self.layout();
|
587
783
|
return s.str();
|
588
784
|
})
|
589
785
|
.define_method(
|
590
786
|
"device",
|
591
|
-
*[](
|
787
|
+
*[](Tensor& self) {
|
592
788
|
std::stringstream s;
|
593
789
|
s << self.device();
|
594
790
|
return s.str();
|
595
791
|
})
|
596
792
|
.define_method(
|
597
793
|
"_view",
|
598
|
-
*[](
|
794
|
+
*[](Tensor& self, IntArrayRef size) {
|
599
795
|
return self.view(size);
|
600
796
|
})
|
601
797
|
.define_method(
|
602
798
|
"resize_as!",
|
603
|
-
*[](
|
799
|
+
*[](Tensor& self, Tensor& other) {
|
604
800
|
return self.resize_as_(other);
|
605
801
|
})
|
606
802
|
.define_method(
|
607
803
|
"fill!",
|
608
|
-
*[](
|
804
|
+
*[](Tensor& self, Scalar value) {
|
609
805
|
return self.fill_(value);
|
610
806
|
})
|
807
|
+
.define_method(
|
808
|
+
"relu!",
|
809
|
+
*[](Tensor& self) {
|
810
|
+
return self.relu_();
|
811
|
+
})
|
611
812
|
.define_method(
|
612
813
|
"_add!",
|
613
|
-
*[](
|
814
|
+
*[](Tensor& self, Tensor& other) {
|
614
815
|
return self.add_(other);
|
615
816
|
})
|
616
817
|
.define_method(
|
617
818
|
"_add_alpha!",
|
618
|
-
*[](
|
819
|
+
*[](Tensor& self, Tensor& other, Scalar alpha) {
|
619
820
|
return self.add_(other, alpha);
|
620
821
|
})
|
621
822
|
.define_method(
|
622
823
|
"_add_scalar!",
|
623
|
-
*[](
|
824
|
+
*[](Tensor& self, Scalar other) {
|
624
825
|
return self.add_(other);
|
625
826
|
})
|
626
827
|
.define_method(
|
627
828
|
"normal!",
|
628
|
-
*[](
|
829
|
+
*[](Tensor& self, double mean, double std) {
|
629
830
|
return self.normal_(mean, std);
|
630
831
|
})
|
832
|
+
.define_method(
|
833
|
+
"random!",
|
834
|
+
*[](Tensor& self, int64_t to) {
|
835
|
+
return self.random_(to);
|
836
|
+
})
|
631
837
|
.define_method(
|
632
838
|
"sub!",
|
633
|
-
*[](
|
839
|
+
*[](Tensor& self, Tensor& other) {
|
634
840
|
return self.sub_(other);
|
635
841
|
})
|
636
842
|
.define_method(
|
637
843
|
"_mul!",
|
638
|
-
*[](
|
844
|
+
*[](Tensor& self, Tensor& other) {
|
639
845
|
return self.mul_(other);
|
640
846
|
})
|
641
847
|
.define_method(
|
642
848
|
"_mul_scalar!",
|
643
|
-
*[](
|
849
|
+
*[](Tensor& self, Scalar other) {
|
644
850
|
return self.mul_(other);
|
645
851
|
})
|
646
852
|
.define_method(
|
647
853
|
"div!",
|
648
|
-
*[](
|
854
|
+
*[](Tensor& self, Tensor& other) {
|
649
855
|
return self.div_(other);
|
650
856
|
})
|
651
857
|
.define_method(
|
652
858
|
"sqrt!",
|
653
|
-
*[](
|
859
|
+
*[](Tensor& self) {
|
654
860
|
return self.sqrt_();
|
655
861
|
})
|
656
862
|
.define_method(
|
657
863
|
"unsqueeze!",
|
658
|
-
*[](
|
864
|
+
*[](Tensor& self, int64_t dim) {
|
659
865
|
return self.unsqueeze_(dim);
|
660
866
|
})
|
661
867
|
.define_method(
|
662
868
|
"copy!",
|
663
|
-
*[](
|
869
|
+
*[](Tensor& self, Tensor& src) {
|
664
870
|
return self.copy_(src);
|
665
871
|
})
|
666
872
|
.define_method(
|
667
873
|
"clone",
|
668
|
-
*[](
|
874
|
+
*[](Tensor& self) {
|
669
875
|
return self.clone();
|
670
876
|
})
|
671
|
-
.define_method(
|
672
|
-
"log_softmax",
|
673
|
-
*[](torch::Tensor& self, int64_t dim) {
|
674
|
-
return self.log_softmax(dim);
|
675
|
-
})
|
676
877
|
.define_method(
|
677
878
|
"data",
|
678
|
-
*[](
|
879
|
+
*[](Tensor& self) {
|
679
880
|
return self.data();
|
680
881
|
})
|
681
882
|
.define_method(
|
682
883
|
"_data",
|
683
|
-
*[](
|
884
|
+
*[](Tensor& self) {
|
684
885
|
Array a;
|
685
886
|
auto dtype = self.dtype();
|
686
887
|
|
@@ -732,17 +933,17 @@ void Init_ext()
|
|
732
933
|
})
|
733
934
|
.define_method(
|
734
935
|
"_size",
|
735
|
-
*[](
|
936
|
+
*[](Tensor& self, int i) {
|
736
937
|
return self.size(i);
|
737
938
|
})
|
738
939
|
.define_method(
|
739
940
|
"_to",
|
740
|
-
*[](
|
941
|
+
*[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
741
942
|
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
742
943
|
})
|
743
944
|
.define_singleton_method(
|
744
945
|
"_make_subclass",
|
745
|
-
*[](
|
946
|
+
*[](Tensor& rd, bool requires_grad) {
|
746
947
|
auto data = torch::autograd::as_variable_ref(rd).detach();
|
747
948
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
748
949
|
auto var = data.set_requires_grad(requires_grad);
|
@@ -793,32 +994,82 @@ void Init_ext()
|
|
793
994
|
|
794
995
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
795
996
|
.define_singleton_method(
|
796
|
-
"
|
797
|
-
*[](
|
798
|
-
return torch::nn::init::
|
997
|
+
"_calculate_gain",
|
998
|
+
*[](NonlinearityType nonlinearity, double param) {
|
999
|
+
return torch::nn::init::calculate_gain(nonlinearity, param);
|
799
1000
|
})
|
800
1001
|
.define_singleton_method(
|
801
|
-
"
|
802
|
-
*[](
|
803
|
-
return torch::nn::init::
|
1002
|
+
"_uniform!",
|
1003
|
+
*[](Tensor tensor, double low, double high) {
|
1004
|
+
return torch::nn::init::uniform_(tensor, low, high);
|
1005
|
+
})
|
1006
|
+
.define_singleton_method(
|
1007
|
+
"_normal!",
|
1008
|
+
*[](Tensor tensor, double mean, double std) {
|
1009
|
+
return torch::nn::init::normal_(tensor, mean, std);
|
1010
|
+
})
|
1011
|
+
.define_singleton_method(
|
1012
|
+
"_constant!",
|
1013
|
+
*[](Tensor tensor, Scalar value) {
|
1014
|
+
return torch::nn::init::constant_(tensor, value);
|
1015
|
+
})
|
1016
|
+
.define_singleton_method(
|
1017
|
+
"_ones!",
|
1018
|
+
*[](Tensor tensor) {
|
1019
|
+
return torch::nn::init::ones_(tensor);
|
1020
|
+
})
|
1021
|
+
.define_singleton_method(
|
1022
|
+
"_zeros!",
|
1023
|
+
*[](Tensor tensor) {
|
1024
|
+
return torch::nn::init::zeros_(tensor);
|
1025
|
+
})
|
1026
|
+
.define_singleton_method(
|
1027
|
+
"_eye!",
|
1028
|
+
*[](Tensor tensor) {
|
1029
|
+
return torch::nn::init::eye_(tensor);
|
1030
|
+
})
|
1031
|
+
.define_singleton_method(
|
1032
|
+
"_dirac!",
|
1033
|
+
*[](Tensor tensor) {
|
1034
|
+
return torch::nn::init::dirac_(tensor);
|
1035
|
+
})
|
1036
|
+
.define_singleton_method(
|
1037
|
+
"_xavier_uniform!",
|
1038
|
+
*[](Tensor tensor, double gain) {
|
1039
|
+
return torch::nn::init::xavier_uniform_(tensor, gain);
|
1040
|
+
})
|
1041
|
+
.define_singleton_method(
|
1042
|
+
"_xavier_normal!",
|
1043
|
+
*[](Tensor tensor, double gain) {
|
1044
|
+
return torch::nn::init::xavier_normal_(tensor, gain);
|
1045
|
+
})
|
1046
|
+
.define_singleton_method(
|
1047
|
+
"_kaiming_uniform!",
|
1048
|
+
*[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
1049
|
+
return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
|
1050
|
+
})
|
1051
|
+
.define_singleton_method(
|
1052
|
+
"_kaiming_normal!",
|
1053
|
+
*[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
1054
|
+
return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
|
804
1055
|
})
|
805
1056
|
.define_singleton_method(
|
806
|
-
"
|
807
|
-
*[](
|
808
|
-
return torch::nn::init::
|
1057
|
+
"_orthogonal!",
|
1058
|
+
*[](Tensor tensor, double gain) {
|
1059
|
+
return torch::nn::init::orthogonal_(tensor, gain);
|
1060
|
+
})
|
1061
|
+
.define_singleton_method(
|
1062
|
+
"_sparse!",
|
1063
|
+
*[](Tensor tensor, double sparsity, double std) {
|
1064
|
+
return torch::nn::init::sparse_(tensor, sparsity, std);
|
809
1065
|
});
|
810
1066
|
|
811
1067
|
Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
|
812
|
-
// TODO return grad or nil to remove need for 2nd function
|
813
1068
|
.define_method(
|
814
|
-
"
|
815
|
-
*[](torch::autograd::Variable& self) {
|
816
|
-
return self.grad();
|
817
|
-
})
|
818
|
-
.define_method(
|
819
|
-
"_grad_defined",
|
1069
|
+
"grad",
|
820
1070
|
*[](torch::autograd::Variable& self) {
|
821
|
-
|
1071
|
+
auto grad = self.grad();
|
1072
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
822
1073
|
});
|
823
1074
|
|
824
1075
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|