torch-rb 0.3.1 → 0.3.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +6 -2
- data/ext/torch/ext.cpp +43 -12
- data/ext/torch/templates.cpp +8 -0
- data/ext/torch/templates.hpp +59 -54
- data/lib/torch.rb +48 -19
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/native/function.rb +9 -1
- data/lib/torch/native/generator.rb +19 -25
- data/lib/torch/native/parser.rb +13 -1
- data/lib/torch/nn/functional.rb +11 -3
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +6 -1
- data/lib/torch/tensor.rb +32 -44
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +24 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f7b85027dfbb5a3d8de3741d00f4256fd13a6b5496a123564472a48d8a084c1b
|
4
|
+
data.tar.gz: 5c684e45ec115ce3b9cc5a3e223ee73cac3dce5fbacae1bd4d4faa7cf49adc5f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: acb727d9836709e5db4df21aeb6eec401e10f3c1910f95877493a9b1920cef4a0bd4914b906dab9b2ec18071fe95bf50de91a8a00a0914f3876ecb851e1c19c7
|
7
|
+
data.tar.gz: 7e69bde091825d7dcda81cfcfebd220c8072442322071a73eafd2849d9a899229bbcc2ce2b80f78e4b44e468e7a263ec83fd9a3fb3c1bf3073573596f40ec143
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.3.6 (2020-09-17)
|
2
|
+
|
3
|
+
- Added `inplace` option for leaky ReLU
|
4
|
+
- Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
|
5
|
+
- Fixed error with buffers on GPU
|
6
|
+
|
7
|
+
## 0.3.5 (2020-09-04)
|
8
|
+
|
9
|
+
- Fixed error with data loader (due to `dtype` of `randperm`)
|
10
|
+
|
11
|
+
## 0.3.4 (2020-08-26)
|
12
|
+
|
13
|
+
- Added `Torch.clamp` method
|
14
|
+
|
15
|
+
## 0.3.3 (2020-08-25)
|
16
|
+
|
17
|
+
- Added spectral ops
|
18
|
+
- Fixed tensor indexing
|
19
|
+
|
20
|
+
## 0.3.2 (2020-08-24)
|
21
|
+
|
22
|
+
- Added `enable_grad` method
|
23
|
+
- Added `random_split` method
|
24
|
+
- Added `collate_fn` option to `DataLoader`
|
25
|
+
- Added `grad=` method to `Tensor`
|
26
|
+
- Fixed error with `grad` method when empty
|
27
|
+
- Fixed `EmbeddingBag`
|
28
|
+
|
1
29
|
## 0.3.1 (2020-08-17)
|
2
30
|
|
3
31
|
- Added `create_graph` and `retain_graph` options to `backward` method
|
data/README.md
CHANGED
@@ -2,7 +2,11 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
Check out:
|
6
|
+
|
7
|
+
- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
|
8
|
+
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
|
+
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
6
10
|
|
7
11
|
[![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
|
8
12
|
|
@@ -411,7 +415,7 @@ Here’s the list of compatible versions.
|
|
411
415
|
|
412
416
|
Torch.rb | LibTorch
|
413
417
|
--- | ---
|
414
|
-
0.3.0-0.3.
|
418
|
+
0.3.0-0.3.4 | 1.6.0
|
415
419
|
0.2.0-0.2.7 | 1.5.0-1.5.1
|
416
420
|
0.1.8 | 1.4.0
|
417
421
|
0.1.0-0.1.7 | 1.3.1
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
#include "nn_functions.hpp"
|
17
17
|
|
18
18
|
using namespace Rice;
|
19
|
+
using torch::indexing::TensorIndex;
|
19
20
|
|
20
21
|
// need to make a distinction between parameters and tensors
|
21
22
|
class Parameter: public torch::autograd::Variable {
|
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
|
|
28
29
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
30
|
}
|
30
31
|
|
32
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
33
|
+
auto indices = std::vector<TensorIndex>();
|
34
|
+
indices.reserve(a.size());
|
35
|
+
for (size_t i = 0; i < a.size(); i++) {
|
36
|
+
indices.push_back(from_ruby<TensorIndex>(a[i]));
|
37
|
+
}
|
38
|
+
return indices;
|
39
|
+
}
|
40
|
+
|
31
41
|
extern "C"
|
32
42
|
void Init_ext()
|
33
43
|
{
|
@@ -58,6 +68,13 @@ void Init_ext()
|
|
58
68
|
return generator.seed();
|
59
69
|
});
|
60
70
|
|
71
|
+
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
+
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
+
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
+
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
+
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
+
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
+
|
61
78
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
62
79
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
63
80
|
.add_handler<torch::Error>(handle_error)
|
@@ -284,11 +301,6 @@ void Init_ext()
|
|
284
301
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
285
302
|
return torch::pickle_load(v);
|
286
303
|
})
|
287
|
-
.define_singleton_method(
|
288
|
-
"_binary_cross_entropy_with_logits",
|
289
|
-
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
290
|
-
return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
|
291
|
-
})
|
292
304
|
.define_singleton_method(
|
293
305
|
"_from_blob",
|
294
306
|
*[](String s, IntArrayRef size, const torch::TensorOptions &options) {
|
@@ -330,6 +342,18 @@ void Init_ext()
|
|
330
342
|
.define_method("numel", &torch::Tensor::numel)
|
331
343
|
.define_method("element_size", &torch::Tensor::element_size)
|
332
344
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
345
|
+
.define_method(
|
346
|
+
"_index",
|
347
|
+
*[](Tensor& self, Array indices) {
|
348
|
+
auto vec = index_vector(indices);
|
349
|
+
return self.index(vec);
|
350
|
+
})
|
351
|
+
.define_method(
|
352
|
+
"_index_put_custom",
|
353
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
354
|
+
auto vec = index_vector(indices);
|
355
|
+
return self.index_put_(vec, value);
|
356
|
+
})
|
333
357
|
.define_method(
|
334
358
|
"contiguous?",
|
335
359
|
*[](Tensor& self) {
|
@@ -350,15 +374,16 @@ void Init_ext()
|
|
350
374
|
*[](Tensor& self, bool requires_grad) {
|
351
375
|
return self.set_requires_grad(requires_grad);
|
352
376
|
})
|
353
|
-
.define_method(
|
354
|
-
"_backward",
|
355
|
-
*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
|
356
|
-
return self.backward(gradient, create_graph, retain_graph);
|
357
|
-
})
|
358
377
|
.define_method(
|
359
378
|
"grad",
|
360
379
|
*[](Tensor& self) {
|
361
|
-
|
380
|
+
auto grad = self.grad();
|
381
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
382
|
+
})
|
383
|
+
.define_method(
|
384
|
+
"grad=",
|
385
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
386
|
+
self.grad() = grad;
|
362
387
|
})
|
363
388
|
.define_method(
|
364
389
|
"_dtype",
|
@@ -502,6 +527,7 @@ void Init_ext()
|
|
502
527
|
});
|
503
528
|
|
504
529
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
530
|
+
.add_handler<torch::Error>(handle_error)
|
505
531
|
.define_singleton_method(
|
506
532
|
"_calculate_gain",
|
507
533
|
*[](NonlinearityType nonlinearity, double param) {
|
@@ -580,11 +606,16 @@ void Init_ext()
|
|
580
606
|
*[](Parameter& self) {
|
581
607
|
auto grad = self.grad();
|
582
608
|
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
609
|
+
})
|
610
|
+
.define_method(
|
611
|
+
"grad=",
|
612
|
+
*[](Parameter& self, torch::Tensor& grad) {
|
613
|
+
self.grad() = grad;
|
583
614
|
});
|
584
615
|
|
585
616
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
586
|
-
.define_constructor(Constructor<torch::Device, std::string>())
|
587
617
|
.add_handler<torch::Error>(handle_error)
|
618
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
588
619
|
.define_method("index", &torch::Device::index)
|
589
620
|
.define_method("index?", &torch::Device::has_index)
|
590
621
|
.define_method(
|
data/ext/torch/templates.cpp
CHANGED
@@ -53,3 +53,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
|
53
53
|
a.push(to_ruby<int64_t>(std::get<3>(x)));
|
54
54
|
return Object(a);
|
55
55
|
}
|
56
|
+
|
57
|
+
Object wrap(std::vector<torch::Tensor> x) {
|
58
|
+
Array a;
|
59
|
+
for (auto& t : x) {
|
60
|
+
a.push(to_ruby<torch::Tensor>(t));
|
61
|
+
}
|
62
|
+
return Object(a);
|
63
|
+
}
|
data/ext/torch/templates.hpp
CHANGED
@@ -9,6 +9,11 @@
|
|
9
9
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
|
+
using torch::Device;
|
13
|
+
using torch::Scalar;
|
14
|
+
using torch::ScalarType;
|
15
|
+
using torch::Tensor;
|
16
|
+
|
12
17
|
// need to wrap torch::IntArrayRef() since
|
13
18
|
// it doesn't own underlying data
|
14
19
|
class IntArrayRef {
|
@@ -32,30 +37,6 @@ IntArrayRef from_ruby<IntArrayRef>(Object x)
|
|
32
37
|
return IntArrayRef(x);
|
33
38
|
}
|
34
39
|
|
35
|
-
// for now
|
36
|
-
class Scalar {
|
37
|
-
torch::Scalar value;
|
38
|
-
public:
|
39
|
-
Scalar(Object o) {
|
40
|
-
// TODO cast based on Ruby type
|
41
|
-
if (o.rb_type() == T_FIXNUM) {
|
42
|
-
value = torch::Scalar(from_ruby<int64_t>(o));
|
43
|
-
} else {
|
44
|
-
value = torch::Scalar(from_ruby<float>(o));
|
45
|
-
}
|
46
|
-
}
|
47
|
-
operator torch::Scalar() {
|
48
|
-
return value;
|
49
|
-
}
|
50
|
-
};
|
51
|
-
|
52
|
-
template<>
|
53
|
-
inline
|
54
|
-
Scalar from_ruby<Scalar>(Object x)
|
55
|
-
{
|
56
|
-
return Scalar(x);
|
57
|
-
}
|
58
|
-
|
59
40
|
class TensorList {
|
60
41
|
std::vector<torch::Tensor> vec;
|
61
42
|
public:
|
@@ -174,8 +155,6 @@ MyReduction from_ruby<MyReduction>(Object x)
|
|
174
155
|
return MyReduction(x);
|
175
156
|
}
|
176
157
|
|
177
|
-
typedef torch::Tensor Tensor;
|
178
|
-
|
179
158
|
class OptionalTensor {
|
180
159
|
Object value;
|
181
160
|
public:
|
@@ -190,6 +169,17 @@ class OptionalTensor {
|
|
190
169
|
}
|
191
170
|
};
|
192
171
|
|
172
|
+
template<>
|
173
|
+
inline
|
174
|
+
Scalar from_ruby<Scalar>(Object x)
|
175
|
+
{
|
176
|
+
if (x.rb_type() == T_FIXNUM) {
|
177
|
+
return torch::Scalar(from_ruby<int64_t>(x));
|
178
|
+
} else {
|
179
|
+
return torch::Scalar(from_ruby<double>(x));
|
180
|
+
}
|
181
|
+
}
|
182
|
+
|
193
183
|
template<>
|
194
184
|
inline
|
195
185
|
OptionalTensor from_ruby<OptionalTensor>(Object x)
|
@@ -197,46 +187,60 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
|
|
197
187
|
return OptionalTensor(x);
|
198
188
|
}
|
199
189
|
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
}
|
190
|
+
template<>
|
191
|
+
inline
|
192
|
+
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
193
|
+
{
|
194
|
+
if (x.is_nil()) {
|
195
|
+
return torch::nullopt;
|
196
|
+
} else {
|
197
|
+
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
198
|
+
}
|
199
|
+
}
|
210
200
|
|
211
201
|
template<>
|
212
202
|
inline
|
213
|
-
|
203
|
+
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
214
204
|
{
|
215
|
-
|
205
|
+
if (x.is_nil()) {
|
206
|
+
return torch::nullopt;
|
207
|
+
} else {
|
208
|
+
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
209
|
+
}
|
216
210
|
}
|
217
211
|
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
return ScalarType(value);
|
229
|
-
}
|
230
|
-
};
|
212
|
+
template<>
|
213
|
+
inline
|
214
|
+
torch::optional<double> from_ruby<torch::optional<double>>(Object x)
|
215
|
+
{
|
216
|
+
if (x.is_nil()) {
|
217
|
+
return torch::nullopt;
|
218
|
+
} else {
|
219
|
+
return torch::optional<double>{from_ruby<double>(x)};
|
220
|
+
}
|
221
|
+
}
|
231
222
|
|
232
223
|
template<>
|
233
224
|
inline
|
234
|
-
|
225
|
+
torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
|
235
226
|
{
|
236
|
-
|
227
|
+
if (x.is_nil()) {
|
228
|
+
return torch::nullopt;
|
229
|
+
} else {
|
230
|
+
return torch::optional<bool>{from_ruby<bool>(x)};
|
231
|
+
}
|
237
232
|
}
|
238
233
|
|
239
|
-
|
234
|
+
template<>
|
235
|
+
inline
|
236
|
+
torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
237
|
+
{
|
238
|
+
if (x.is_nil()) {
|
239
|
+
return torch::nullopt;
|
240
|
+
} else {
|
241
|
+
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
|
242
|
+
}
|
243
|
+
}
|
240
244
|
|
241
245
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
242
246
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
@@ -244,3 +248,4 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
|
|
244
248
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
245
249
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
|
246
250
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
|
251
|
+
Object wrap(std::vector<torch::Tensor> x);
|
data/lib/torch.rb
CHANGED
@@ -179,8 +179,10 @@ require "torch/nn/functional"
|
|
179
179
|
require "torch/nn/init"
|
180
180
|
|
181
181
|
# utils
|
182
|
+
require "torch/utils/data"
|
182
183
|
require "torch/utils/data/data_loader"
|
183
184
|
require "torch/utils/data/dataset"
|
185
|
+
require "torch/utils/data/subset"
|
184
186
|
require "torch/utils/data/tensor_dataset"
|
185
187
|
|
186
188
|
# hub
|
@@ -237,25 +239,22 @@ module Torch
|
|
237
239
|
cls
|
238
240
|
end
|
239
241
|
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
CUDA::IntTensor = _make_tensor_class(:int32, true)
|
257
|
-
CUDA::LongTensor = _make_tensor_class(:int64, true)
|
258
|
-
CUDA::BoolTensor = _make_tensor_class(:bool, true)
|
242
|
+
DTYPE_TO_CLASS = {
|
243
|
+
float32: "FloatTensor",
|
244
|
+
float64: "DoubleTensor",
|
245
|
+
float16: "HalfTensor",
|
246
|
+
uint8: "ByteTensor",
|
247
|
+
int8: "CharTensor",
|
248
|
+
int16: "ShortTensor",
|
249
|
+
int32: "IntTensor",
|
250
|
+
int64: "LongTensor",
|
251
|
+
bool: "BoolTensor"
|
252
|
+
}
|
253
|
+
|
254
|
+
DTYPE_TO_CLASS.each do |dtype, class_name|
|
255
|
+
const_set(class_name, _make_tensor_class(dtype))
|
256
|
+
CUDA.const_set(class_name, _make_tensor_class(dtype, true))
|
257
|
+
end
|
259
258
|
|
260
259
|
class << self
|
261
260
|
# Torch.float, Torch.long, etc
|
@@ -316,6 +315,16 @@ module Torch
|
|
316
315
|
end
|
317
316
|
end
|
318
317
|
|
318
|
+
def enable_grad
|
319
|
+
previous_value = grad_enabled?
|
320
|
+
begin
|
321
|
+
_set_grad_enabled(true)
|
322
|
+
yield
|
323
|
+
ensure
|
324
|
+
_set_grad_enabled(previous_value)
|
325
|
+
end
|
326
|
+
end
|
327
|
+
|
319
328
|
def device(str)
|
320
329
|
Device.new(str)
|
321
330
|
end
|
@@ -376,6 +385,10 @@ module Torch
|
|
376
385
|
end
|
377
386
|
|
378
387
|
def randperm(n, **options)
|
388
|
+
# dtype hack in Python
|
389
|
+
# https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
|
390
|
+
options[:dtype] ||= :int64
|
391
|
+
|
379
392
|
_randperm(n, tensor_options(**options))
|
380
393
|
end
|
381
394
|
|
@@ -448,6 +461,22 @@ module Torch
|
|
448
461
|
zeros(input.size, **like_options(input, options))
|
449
462
|
end
|
450
463
|
|
464
|
+
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
|
465
|
+
if center
|
466
|
+
signal_dim = input.dim
|
467
|
+
extended_shape = [1] * (3 - signal_dim) + input.size
|
468
|
+
pad = n_fft.div(2).to_i
|
469
|
+
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
470
|
+
input = input.view(input.shape[-signal_dim..-1])
|
471
|
+
end
|
472
|
+
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
473
|
+
end
|
474
|
+
|
475
|
+
def clamp(tensor, min, max)
|
476
|
+
tensor = _clamp_min(tensor, min)
|
477
|
+
_clamp_max(tensor, max)
|
478
|
+
end
|
479
|
+
|
451
480
|
private
|
452
481
|
|
453
482
|
def to_ivalue(obj)
|
data/lib/torch/hub.rb
CHANGED
@@ -7,25 +7,26 @@ module Torch
|
|
7
7
|
|
8
8
|
def download_url_to_file(url, dst)
|
9
9
|
uri = URI(url)
|
10
|
-
tmp =
|
10
|
+
tmp = nil
|
11
11
|
location = nil
|
12
12
|
|
13
|
+
puts "Downloading #{url}..."
|
13
14
|
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
15
|
request = Net::HTTP::Get.new(uri)
|
15
16
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
17
|
+
http.request(request) do |response|
|
18
|
+
case response
|
19
|
+
when Net::HTTPRedirection
|
20
|
+
location = response["location"]
|
21
|
+
when Net::HTTPSuccess
|
22
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
23
|
+
File.open(tmp, "wb") do |f|
|
23
24
|
response.read_body do |chunk|
|
24
25
|
f.write(chunk)
|
25
26
|
end
|
26
|
-
else
|
27
|
-
raise Error, "Bad response"
|
28
27
|
end
|
28
|
+
else
|
29
|
+
raise Error, "Bad response"
|
29
30
|
end
|
30
31
|
end
|
31
32
|
end
|
@@ -1,10 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module Native
|
3
3
|
class Function
|
4
|
-
attr_reader :function
|
4
|
+
attr_reader :function, :tensor_options
|
5
5
|
|
6
6
|
def initialize(function)
|
7
7
|
@function = function
|
8
|
+
|
9
|
+
tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
10
|
+
@tensor_options = @function["func"].include?(tensor_options_str)
|
11
|
+
@function["func"].sub!(tensor_options_str, ")")
|
8
12
|
end
|
9
13
|
|
10
14
|
def func
|
@@ -82,6 +86,10 @@ module Torch
|
|
82
86
|
@ret_size ||= func.split("->").last.split(", ").size
|
83
87
|
end
|
84
88
|
|
89
|
+
def ret_array?
|
90
|
+
@ret_array ||= func.split("->").last.include?('[]')
|
91
|
+
end
|
92
|
+
|
85
93
|
def out?
|
86
94
|
out_size > 0 && base_name[-1] != "_"
|
87
95
|
end
|
@@ -18,12 +18,12 @@ module Torch
|
|
18
18
|
functions = functions()
|
19
19
|
|
20
20
|
# skip functions
|
21
|
-
skip_args = ["
|
21
|
+
skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
|
22
22
|
|
23
23
|
# remove functions
|
24
24
|
functions.reject! do |f|
|
25
25
|
f.ruby_name.start_with?("_") ||
|
26
|
-
f.ruby_name.
|
26
|
+
f.ruby_name.include?("_backward") ||
|
27
27
|
f.args.any? { |a| a[:type].include?("Dimname") }
|
28
28
|
end
|
29
29
|
|
@@ -31,32 +31,15 @@ module Torch
|
|
31
31
|
todo_functions, functions =
|
32
32
|
functions.partition do |f|
|
33
33
|
f.args.any? do |a|
|
34
|
-
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
34
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
35
|
+
# call to 'range' is ambiguous
|
36
|
+
f.cpp_name == "_range" ||
|
36
37
|
# native_functions.yaml is missing size argument for normal
|
37
38
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
38
39
|
(f.base_name == "normal" && !f.out?)
|
39
40
|
end
|
40
41
|
end
|
41
42
|
|
42
|
-
# generate additional functions for optional arguments
|
43
|
-
# there may be a better way to do this
|
44
|
-
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
45
|
-
optional_functions.each do |f|
|
46
|
-
next if f.ruby_name == "cross"
|
47
|
-
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
48
|
-
|
49
|
-
opt_args = f.args.select { |a| a[:type] == "int?" }
|
50
|
-
if opt_args.size == 1
|
51
|
-
sep = f.name.include?(".") ? "_" : "."
|
52
|
-
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
53
|
-
# TODO only remove some arguments
|
54
|
-
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
55
|
-
functions << f1
|
56
|
-
functions << f2
|
57
|
-
end
|
58
|
-
end
|
59
|
-
|
60
43
|
# todo_functions.each do |f|
|
61
44
|
# puts f.func
|
62
45
|
# puts
|
@@ -97,7 +80,8 @@ void add_%{type}_functions(Module m) {
|
|
97
80
|
|
98
81
|
cpp_defs = []
|
99
82
|
functions.sort_by(&:cpp_name).each do |func|
|
100
|
-
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
83
|
+
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
84
|
+
fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
|
101
85
|
|
102
86
|
cpp_args = []
|
103
87
|
fargs.each do |a|
|
@@ -109,7 +93,7 @@ void add_%{type}_functions(Module m) {
|
|
109
93
|
# TODO better signature
|
110
94
|
"OptionalTensor"
|
111
95
|
when "ScalarType?"
|
112
|
-
"
|
96
|
+
"torch::optional<ScalarType>"
|
113
97
|
when "Tensor[]"
|
114
98
|
"TensorList"
|
115
99
|
when "Tensor?[]"
|
@@ -117,6 +101,14 @@ void add_%{type}_functions(Module m) {
|
|
117
101
|
"TensorList"
|
118
102
|
when "int"
|
119
103
|
"int64_t"
|
104
|
+
when "int?"
|
105
|
+
"torch::optional<int64_t>"
|
106
|
+
when "float?"
|
107
|
+
"torch::optional<double>"
|
108
|
+
when "bool?"
|
109
|
+
"torch::optional<bool>"
|
110
|
+
when "Scalar?"
|
111
|
+
"torch::optional<torch::Scalar>"
|
120
112
|
when "float"
|
121
113
|
"double"
|
122
114
|
when /\Aint\[/
|
@@ -125,6 +117,8 @@ void add_%{type}_functions(Module m) {
|
|
125
117
|
"Tensor &"
|
126
118
|
when "str"
|
127
119
|
"std::string"
|
120
|
+
when "TensorOptions"
|
121
|
+
"const torch::TensorOptions &"
|
128
122
|
else
|
129
123
|
a[:type]
|
130
124
|
end
|
@@ -141,8 +135,8 @@ void add_%{type}_functions(Module m) {
|
|
141
135
|
prefix = def_method == :define_method ? "self." : "torch::"
|
142
136
|
|
143
137
|
body = "#{prefix}#{dispatch}(#{args.join(", ")})"
|
144
|
-
|
145
|
-
if func.ret_size > 1
|
138
|
+
|
139
|
+
if func.ret_size > 1 || func.ret_array?
|
146
140
|
body = "wrap(#{body})"
|
147
141
|
end
|
148
142
|
|
data/lib/torch/native/parser.rb
CHANGED
@@ -83,6 +83,12 @@ module Torch
|
|
83
83
|
else
|
84
84
|
v.is_a?(Integer)
|
85
85
|
end
|
86
|
+
when "int?"
|
87
|
+
v.is_a?(Integer) || v.nil?
|
88
|
+
when "float?"
|
89
|
+
v.is_a?(Numeric) || v.nil?
|
90
|
+
when "bool?"
|
91
|
+
v == true || v == false || v.nil?
|
86
92
|
when "float"
|
87
93
|
v.is_a?(Numeric)
|
88
94
|
when /int\[.*\]/
|
@@ -95,6 +101,10 @@ module Torch
|
|
95
101
|
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
|
96
102
|
when "Scalar"
|
97
103
|
v.is_a?(Numeric)
|
104
|
+
when "Scalar?"
|
105
|
+
v.is_a?(Numeric) || v.nil?
|
106
|
+
when "ScalarType"
|
107
|
+
false # not supported yet
|
98
108
|
when "ScalarType?"
|
99
109
|
v.nil?
|
100
110
|
when "bool"
|
@@ -126,9 +136,11 @@ module Torch
|
|
126
136
|
end
|
127
137
|
|
128
138
|
func = candidates.first
|
139
|
+
args = func.args.map { |a| final_values[a[:name]] }
|
140
|
+
args << TensorOptions.new.dtype(6) if func.tensor_options
|
129
141
|
{
|
130
142
|
name: func.cpp_name,
|
131
|
-
args:
|
143
|
+
args: args
|
132
144
|
}
|
133
145
|
end
|
134
146
|
end
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -178,8 +178,12 @@ module Torch
|
|
178
178
|
Torch.hardshrink(input, lambd)
|
179
179
|
end
|
180
180
|
|
181
|
-
def leaky_relu(input, negative_slope = 0.01)
|
182
|
-
|
181
|
+
def leaky_relu(input, negative_slope = 0.01, inplace: false)
|
182
|
+
if inplace
|
183
|
+
NN.leaky_relu!(input, negative_slope)
|
184
|
+
else
|
185
|
+
NN.leaky_relu(input, negative_slope)
|
186
|
+
end
|
183
187
|
end
|
184
188
|
|
185
189
|
def log_sigmoid(input)
|
@@ -373,7 +377,8 @@ module Torch
|
|
373
377
|
end
|
374
378
|
|
375
379
|
# weight and input swapped
|
376
|
-
Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
380
|
+
ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
381
|
+
ret
|
377
382
|
end
|
378
383
|
|
379
384
|
# distance functions
|
@@ -426,6 +431,9 @@ module Torch
|
|
426
431
|
end
|
427
432
|
|
428
433
|
def mse_loss(input, target, reduction: "mean")
|
434
|
+
if target.size != input.size
|
435
|
+
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
436
|
+
end
|
429
437
|
NN.mse_loss(input, target, reduction)
|
430
438
|
end
|
431
439
|
|
data/lib/torch/nn/leaky_relu.rb
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class LeakyReLU < Module
|
4
|
-
def initialize(negative_slope: 1e-2
|
4
|
+
def initialize(negative_slope: 1e-2, inplace: false)
|
5
5
|
super()
|
6
6
|
@negative_slope = negative_slope
|
7
|
-
|
7
|
+
@inplace = inplace
|
8
8
|
end
|
9
9
|
|
10
10
|
def forward(input)
|
11
|
-
F.leaky_relu(input, @negative_slope
|
11
|
+
F.leaky_relu(input, @negative_slope, inplace: @inplace)
|
12
12
|
end
|
13
13
|
|
14
14
|
def extra_inspect
|
data/lib/torch/nn/module.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -103,11 +103,6 @@ module Torch
|
|
103
103
|
Torch.empty(0, dtype: dtype)
|
104
104
|
end
|
105
105
|
|
106
|
-
def backward(gradient = nil, retain_graph: nil, create_graph: false)
|
107
|
-
retain_graph = create_graph if retain_graph.nil?
|
108
|
-
_backward(gradient, retain_graph, create_graph)
|
109
|
-
end
|
110
|
-
|
111
106
|
# TODO read directly from memory
|
112
107
|
def numo
|
113
108
|
cls = Torch._dtype_to_numo[dtype]
|
@@ -188,49 +183,15 @@ module Torch
|
|
188
183
|
# based on python_variable_indexing.cpp and
|
189
184
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
190
185
|
def [](*indexes)
|
191
|
-
|
192
|
-
dim = 0
|
193
|
-
indexes.each do |index|
|
194
|
-
if index.is_a?(Numeric)
|
195
|
-
result = result._select_int(dim, index)
|
196
|
-
elsif index.is_a?(Range)
|
197
|
-
finish = index.end
|
198
|
-
finish += 1 unless index.exclude_end?
|
199
|
-
result = result._slice_tensor(dim, index.begin, finish, 1)
|
200
|
-
dim += 1
|
201
|
-
elsif index.is_a?(Tensor)
|
202
|
-
result = result.index([index])
|
203
|
-
elsif index.nil?
|
204
|
-
result = result.unsqueeze(dim)
|
205
|
-
dim += 1
|
206
|
-
elsif index == true
|
207
|
-
result = result.unsqueeze(dim)
|
208
|
-
# TODO handle false
|
209
|
-
else
|
210
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
211
|
-
end
|
212
|
-
end
|
213
|
-
result
|
186
|
+
_index(tensor_indexes(indexes))
|
214
187
|
end
|
215
188
|
|
216
189
|
# based on python_variable_indexing.cpp and
|
217
190
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
218
|
-
def []=(
|
191
|
+
def []=(*indexes, value)
|
219
192
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
220
|
-
|
221
193
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
222
|
-
|
223
|
-
if index.is_a?(Numeric)
|
224
|
-
index_put!([Torch.tensor(index)], value)
|
225
|
-
elsif index.is_a?(Range)
|
226
|
-
finish = index.end
|
227
|
-
finish += 1 unless index.exclude_end?
|
228
|
-
_slice_tensor(0, index.begin, finish, 1).copy!(value)
|
229
|
-
elsif index.is_a?(Tensor)
|
230
|
-
index_put!([index], value)
|
231
|
-
else
|
232
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
233
|
-
end
|
194
|
+
_index_put_custom(tensor_indexes(indexes), value)
|
234
195
|
end
|
235
196
|
|
236
197
|
# native functions that need manually defined
|
@@ -244,13 +205,13 @@ module Torch
|
|
244
205
|
end
|
245
206
|
end
|
246
207
|
|
247
|
-
#
|
208
|
+
# parser can't handle overlap, so need to handle manually
|
248
209
|
def random!(*args)
|
249
210
|
case args.size
|
250
211
|
when 1
|
251
212
|
_random__to(*args)
|
252
213
|
when 2
|
253
|
-
|
214
|
+
_random__from(*args)
|
254
215
|
else
|
255
216
|
_random_(*args)
|
256
217
|
end
|
@@ -260,5 +221,32 @@ module Torch
|
|
260
221
|
_clamp_min_(min)
|
261
222
|
_clamp_max_(max)
|
262
223
|
end
|
224
|
+
|
225
|
+
private
|
226
|
+
|
227
|
+
def tensor_indexes(indexes)
|
228
|
+
indexes.map do |index|
|
229
|
+
case index
|
230
|
+
when Integer
|
231
|
+
TensorIndex.integer(index)
|
232
|
+
when Range
|
233
|
+
finish = index.end || -1
|
234
|
+
if finish == -1 && !index.exclude_end?
|
235
|
+
finish = nil
|
236
|
+
else
|
237
|
+
finish += 1 unless index.exclude_end?
|
238
|
+
end
|
239
|
+
TensorIndex.slice(index.begin, finish)
|
240
|
+
when Tensor
|
241
|
+
TensorIndex.tensor(index)
|
242
|
+
when nil
|
243
|
+
TensorIndex.none
|
244
|
+
when true, false
|
245
|
+
TensorIndex.boolean(index)
|
246
|
+
else
|
247
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
248
|
+
end
|
249
|
+
end
|
250
|
+
end
|
263
251
|
end
|
264
252
|
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class << self
|
5
|
+
def random_split(dataset, lengths)
|
6
|
+
if lengths.sum != dataset.length
|
7
|
+
raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
|
8
|
+
end
|
9
|
+
|
10
|
+
indices = Torch.randperm(lengths.sum).to_a
|
11
|
+
_accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
|
12
|
+
end
|
13
|
+
|
14
|
+
private
|
15
|
+
|
16
|
+
def _accumulate(iterable)
|
17
|
+
sum = 0
|
18
|
+
iterable.map { |x| sum += x }
|
19
|
+
end
|
20
|
+
end
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
@@ -6,10 +6,22 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1, shuffle: false)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
12
|
@shuffle = shuffle
|
13
|
+
|
14
|
+
@batch_sampler = nil
|
15
|
+
|
16
|
+
if collate_fn.nil?
|
17
|
+
if auto_collation?
|
18
|
+
collate_fn = method(:default_collate)
|
19
|
+
else
|
20
|
+
collate_fn = method(:default_convert)
|
21
|
+
end
|
22
|
+
end
|
23
|
+
|
24
|
+
@collate_fn = collate_fn
|
13
25
|
end
|
14
26
|
|
15
27
|
def each
|
@@ -25,18 +37,20 @@ module Torch
|
|
25
37
|
end
|
26
38
|
|
27
39
|
indexes.each_slice(@batch_size) do |idx|
|
28
|
-
|
29
|
-
yield
|
40
|
+
# TODO improve performance
|
41
|
+
yield @collate_fn.call(idx.map { |i| @dataset[i] })
|
30
42
|
end
|
31
43
|
end
|
32
44
|
|
33
45
|
def size
|
34
46
|
(@dataset.size / @batch_size.to_f).ceil
|
35
47
|
end
|
48
|
+
alias_method :length, :size
|
49
|
+
alias_method :count, :size
|
36
50
|
|
37
51
|
private
|
38
52
|
|
39
|
-
def
|
53
|
+
def default_convert(batch)
|
40
54
|
elem = batch[0]
|
41
55
|
case elem
|
42
56
|
when Tensor
|
@@ -44,11 +58,15 @@ module Torch
|
|
44
58
|
when Integer
|
45
59
|
Torch.tensor(batch)
|
46
60
|
when Array
|
47
|
-
batch.transpose.map { |v|
|
61
|
+
batch.transpose.map { |v| default_convert(v) }
|
48
62
|
else
|
49
|
-
raise
|
63
|
+
raise NotImplementedYet
|
50
64
|
end
|
51
65
|
end
|
66
|
+
|
67
|
+
def auto_collation?
|
68
|
+
!@batch_sampler.nil?
|
69
|
+
end
|
52
70
|
end
|
53
71
|
end
|
54
72
|
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class Subset < Dataset
|
5
|
+
def initialize(dataset, indices)
|
6
|
+
@dataset = dataset
|
7
|
+
@indices = indices
|
8
|
+
end
|
9
|
+
|
10
|
+
def [](idx)
|
11
|
+
@dataset[@indices[idx]]
|
12
|
+
end
|
13
|
+
|
14
|
+
def length
|
15
|
+
@indices.length
|
16
|
+
end
|
17
|
+
alias_method :size, :length
|
18
|
+
|
19
|
+
def to_a
|
20
|
+
@indices.map { |i| @dataset[i] }
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.
|
4
|
+
version: 0.3.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-09-18 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -259,8 +259,10 @@ files:
|
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
261
|
- lib/torch/tensor.rb
|
262
|
+
- lib/torch/utils/data.rb
|
262
263
|
- lib/torch/utils/data/data_loader.rb
|
263
264
|
- lib/torch/utils/data/dataset.rb
|
265
|
+
- lib/torch/utils/data/subset.rb
|
264
266
|
- lib/torch/utils/data/tensor_dataset.rb
|
265
267
|
- lib/torch/version.rb
|
266
268
|
homepage: https://github.com/ankane/torch.rb
|