torch-rb 0.3.1 → 0.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +6 -2
- data/ext/torch/ext.cpp +43 -12
- data/ext/torch/templates.cpp +8 -0
- data/ext/torch/templates.hpp +59 -54
- data/lib/torch.rb +48 -19
- data/lib/torch/hub.rb +11 -10
- data/lib/torch/native/function.rb +9 -1
- data/lib/torch/native/generator.rb +19 -25
- data/lib/torch/native/parser.rb +13 -1
- data/lib/torch/nn/functional.rb +11 -3
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +6 -1
- data/lib/torch/tensor.rb +32 -44
- data/lib/torch/utils/data.rb +23 -0
- data/lib/torch/utils/data/data_loader.rb +24 -6
- data/lib/torch/utils/data/subset.rb +25 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +4 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f7b85027dfbb5a3d8de3741d00f4256fd13a6b5496a123564472a48d8a084c1b
|
4
|
+
data.tar.gz: 5c684e45ec115ce3b9cc5a3e223ee73cac3dce5fbacae1bd4d4faa7cf49adc5f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: acb727d9836709e5db4df21aeb6eec401e10f3c1910f95877493a9b1920cef4a0bd4914b906dab9b2ec18071fe95bf50de91a8a00a0914f3876ecb851e1c19c7
|
7
|
+
data.tar.gz: 7e69bde091825d7dcda81cfcfebd220c8072442322071a73eafd2849d9a899229bbcc2ce2b80f78e4b44e468e7a263ec83fd9a3fb3c1bf3073573596f40ec143
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.3.6 (2020-09-17)
|
2
|
+
|
3
|
+
- Added `inplace` option for leaky ReLU
|
4
|
+
- Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
|
5
|
+
- Fixed error with buffers on GPU
|
6
|
+
|
7
|
+
## 0.3.5 (2020-09-04)
|
8
|
+
|
9
|
+
- Fixed error with data loader (due to `dtype` of `randperm`)
|
10
|
+
|
11
|
+
## 0.3.4 (2020-08-26)
|
12
|
+
|
13
|
+
- Added `Torch.clamp` method
|
14
|
+
|
15
|
+
## 0.3.3 (2020-08-25)
|
16
|
+
|
17
|
+
- Added spectral ops
|
18
|
+
- Fixed tensor indexing
|
19
|
+
|
20
|
+
## 0.3.2 (2020-08-24)
|
21
|
+
|
22
|
+
- Added `enable_grad` method
|
23
|
+
- Added `random_split` method
|
24
|
+
- Added `collate_fn` option to `DataLoader`
|
25
|
+
- Added `grad=` method to `Tensor`
|
26
|
+
- Fixed error with `grad` method when empty
|
27
|
+
- Fixed `EmbeddingBag`
|
28
|
+
|
1
29
|
## 0.3.1 (2020-08-17)
|
2
30
|
|
3
31
|
- Added `create_graph` and `retain_graph` options to `backward` method
|
data/README.md
CHANGED
@@ -2,7 +2,11 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
Check out:
|
6
|
+
|
7
|
+
- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
|
8
|
+
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
|
+
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
6
10
|
|
7
11
|
[](https://travis-ci.org/ankane/torch.rb)
|
8
12
|
|
@@ -411,7 +415,7 @@ Here’s the list of compatible versions.
|
|
411
415
|
|
412
416
|
Torch.rb | LibTorch
|
413
417
|
--- | ---
|
414
|
-
0.3.0-0.3.
|
418
|
+
0.3.0-0.3.4 | 1.6.0
|
415
419
|
0.2.0-0.2.7 | 1.5.0-1.5.1
|
416
420
|
0.1.8 | 1.4.0
|
417
421
|
0.1.0-0.1.7 | 1.3.1
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
#include "nn_functions.hpp"
|
17
17
|
|
18
18
|
using namespace Rice;
|
19
|
+
using torch::indexing::TensorIndex;
|
19
20
|
|
20
21
|
// need to make a distinction between parameters and tensors
|
21
22
|
class Parameter: public torch::autograd::Variable {
|
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
|
|
28
29
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
30
|
}
|
30
31
|
|
32
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
33
|
+
auto indices = std::vector<TensorIndex>();
|
34
|
+
indices.reserve(a.size());
|
35
|
+
for (size_t i = 0; i < a.size(); i++) {
|
36
|
+
indices.push_back(from_ruby<TensorIndex>(a[i]));
|
37
|
+
}
|
38
|
+
return indices;
|
39
|
+
}
|
40
|
+
|
31
41
|
extern "C"
|
32
42
|
void Init_ext()
|
33
43
|
{
|
@@ -58,6 +68,13 @@ void Init_ext()
|
|
58
68
|
return generator.seed();
|
59
69
|
});
|
60
70
|
|
71
|
+
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
+
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
+
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
+
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
+
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
+
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
+
|
61
78
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
62
79
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
63
80
|
.add_handler<torch::Error>(handle_error)
|
@@ -284,11 +301,6 @@ void Init_ext()
|
|
284
301
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
285
302
|
return torch::pickle_load(v);
|
286
303
|
})
|
287
|
-
.define_singleton_method(
|
288
|
-
"_binary_cross_entropy_with_logits",
|
289
|
-
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
290
|
-
return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
|
291
|
-
})
|
292
304
|
.define_singleton_method(
|
293
305
|
"_from_blob",
|
294
306
|
*[](String s, IntArrayRef size, const torch::TensorOptions &options) {
|
@@ -330,6 +342,18 @@ void Init_ext()
|
|
330
342
|
.define_method("numel", &torch::Tensor::numel)
|
331
343
|
.define_method("element_size", &torch::Tensor::element_size)
|
332
344
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
345
|
+
.define_method(
|
346
|
+
"_index",
|
347
|
+
*[](Tensor& self, Array indices) {
|
348
|
+
auto vec = index_vector(indices);
|
349
|
+
return self.index(vec);
|
350
|
+
})
|
351
|
+
.define_method(
|
352
|
+
"_index_put_custom",
|
353
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
354
|
+
auto vec = index_vector(indices);
|
355
|
+
return self.index_put_(vec, value);
|
356
|
+
})
|
333
357
|
.define_method(
|
334
358
|
"contiguous?",
|
335
359
|
*[](Tensor& self) {
|
@@ -350,15 +374,16 @@ void Init_ext()
|
|
350
374
|
*[](Tensor& self, bool requires_grad) {
|
351
375
|
return self.set_requires_grad(requires_grad);
|
352
376
|
})
|
353
|
-
.define_method(
|
354
|
-
"_backward",
|
355
|
-
*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
|
356
|
-
return self.backward(gradient, create_graph, retain_graph);
|
357
|
-
})
|
358
377
|
.define_method(
|
359
378
|
"grad",
|
360
379
|
*[](Tensor& self) {
|
361
|
-
|
380
|
+
auto grad = self.grad();
|
381
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
382
|
+
})
|
383
|
+
.define_method(
|
384
|
+
"grad=",
|
385
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
386
|
+
self.grad() = grad;
|
362
387
|
})
|
363
388
|
.define_method(
|
364
389
|
"_dtype",
|
@@ -502,6 +527,7 @@ void Init_ext()
|
|
502
527
|
});
|
503
528
|
|
504
529
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
530
|
+
.add_handler<torch::Error>(handle_error)
|
505
531
|
.define_singleton_method(
|
506
532
|
"_calculate_gain",
|
507
533
|
*[](NonlinearityType nonlinearity, double param) {
|
@@ -580,11 +606,16 @@ void Init_ext()
|
|
580
606
|
*[](Parameter& self) {
|
581
607
|
auto grad = self.grad();
|
582
608
|
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
609
|
+
})
|
610
|
+
.define_method(
|
611
|
+
"grad=",
|
612
|
+
*[](Parameter& self, torch::Tensor& grad) {
|
613
|
+
self.grad() = grad;
|
583
614
|
});
|
584
615
|
|
585
616
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
586
|
-
.define_constructor(Constructor<torch::Device, std::string>())
|
587
617
|
.add_handler<torch::Error>(handle_error)
|
618
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
588
619
|
.define_method("index", &torch::Device::index)
|
589
620
|
.define_method("index?", &torch::Device::has_index)
|
590
621
|
.define_method(
|
data/ext/torch/templates.cpp
CHANGED
@@ -53,3 +53,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
|
53
53
|
a.push(to_ruby<int64_t>(std::get<3>(x)));
|
54
54
|
return Object(a);
|
55
55
|
}
|
56
|
+
|
57
|
+
Object wrap(std::vector<torch::Tensor> x) {
|
58
|
+
Array a;
|
59
|
+
for (auto& t : x) {
|
60
|
+
a.push(to_ruby<torch::Tensor>(t));
|
61
|
+
}
|
62
|
+
return Object(a);
|
63
|
+
}
|
data/ext/torch/templates.hpp
CHANGED
@@ -9,6 +9,11 @@
|
|
9
9
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
|
+
using torch::Device;
|
13
|
+
using torch::Scalar;
|
14
|
+
using torch::ScalarType;
|
15
|
+
using torch::Tensor;
|
16
|
+
|
12
17
|
// need to wrap torch::IntArrayRef() since
|
13
18
|
// it doesn't own underlying data
|
14
19
|
class IntArrayRef {
|
@@ -32,30 +37,6 @@ IntArrayRef from_ruby<IntArrayRef>(Object x)
|
|
32
37
|
return IntArrayRef(x);
|
33
38
|
}
|
34
39
|
|
35
|
-
// for now
|
36
|
-
class Scalar {
|
37
|
-
torch::Scalar value;
|
38
|
-
public:
|
39
|
-
Scalar(Object o) {
|
40
|
-
// TODO cast based on Ruby type
|
41
|
-
if (o.rb_type() == T_FIXNUM) {
|
42
|
-
value = torch::Scalar(from_ruby<int64_t>(o));
|
43
|
-
} else {
|
44
|
-
value = torch::Scalar(from_ruby<float>(o));
|
45
|
-
}
|
46
|
-
}
|
47
|
-
operator torch::Scalar() {
|
48
|
-
return value;
|
49
|
-
}
|
50
|
-
};
|
51
|
-
|
52
|
-
template<>
|
53
|
-
inline
|
54
|
-
Scalar from_ruby<Scalar>(Object x)
|
55
|
-
{
|
56
|
-
return Scalar(x);
|
57
|
-
}
|
58
|
-
|
59
40
|
class TensorList {
|
60
41
|
std::vector<torch::Tensor> vec;
|
61
42
|
public:
|
@@ -174,8 +155,6 @@ MyReduction from_ruby<MyReduction>(Object x)
|
|
174
155
|
return MyReduction(x);
|
175
156
|
}
|
176
157
|
|
177
|
-
typedef torch::Tensor Tensor;
|
178
|
-
|
179
158
|
class OptionalTensor {
|
180
159
|
Object value;
|
181
160
|
public:
|
@@ -190,6 +169,17 @@ class OptionalTensor {
|
|
190
169
|
}
|
191
170
|
};
|
192
171
|
|
172
|
+
template<>
|
173
|
+
inline
|
174
|
+
Scalar from_ruby<Scalar>(Object x)
|
175
|
+
{
|
176
|
+
if (x.rb_type() == T_FIXNUM) {
|
177
|
+
return torch::Scalar(from_ruby<int64_t>(x));
|
178
|
+
} else {
|
179
|
+
return torch::Scalar(from_ruby<double>(x));
|
180
|
+
}
|
181
|
+
}
|
182
|
+
|
193
183
|
template<>
|
194
184
|
inline
|
195
185
|
OptionalTensor from_ruby<OptionalTensor>(Object x)
|
@@ -197,46 +187,60 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
|
|
197
187
|
return OptionalTensor(x);
|
198
188
|
}
|
199
189
|
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
}
|
190
|
+
template<>
|
191
|
+
inline
|
192
|
+
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
193
|
+
{
|
194
|
+
if (x.is_nil()) {
|
195
|
+
return torch::nullopt;
|
196
|
+
} else {
|
197
|
+
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
198
|
+
}
|
199
|
+
}
|
210
200
|
|
211
201
|
template<>
|
212
202
|
inline
|
213
|
-
|
203
|
+
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
214
204
|
{
|
215
|
-
|
205
|
+
if (x.is_nil()) {
|
206
|
+
return torch::nullopt;
|
207
|
+
} else {
|
208
|
+
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
209
|
+
}
|
216
210
|
}
|
217
211
|
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
return ScalarType(value);
|
229
|
-
}
|
230
|
-
};
|
212
|
+
template<>
|
213
|
+
inline
|
214
|
+
torch::optional<double> from_ruby<torch::optional<double>>(Object x)
|
215
|
+
{
|
216
|
+
if (x.is_nil()) {
|
217
|
+
return torch::nullopt;
|
218
|
+
} else {
|
219
|
+
return torch::optional<double>{from_ruby<double>(x)};
|
220
|
+
}
|
221
|
+
}
|
231
222
|
|
232
223
|
template<>
|
233
224
|
inline
|
234
|
-
|
225
|
+
torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
|
235
226
|
{
|
236
|
-
|
227
|
+
if (x.is_nil()) {
|
228
|
+
return torch::nullopt;
|
229
|
+
} else {
|
230
|
+
return torch::optional<bool>{from_ruby<bool>(x)};
|
231
|
+
}
|
237
232
|
}
|
238
233
|
|
239
|
-
|
234
|
+
template<>
|
235
|
+
inline
|
236
|
+
torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
237
|
+
{
|
238
|
+
if (x.is_nil()) {
|
239
|
+
return torch::nullopt;
|
240
|
+
} else {
|
241
|
+
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
|
242
|
+
}
|
243
|
+
}
|
240
244
|
|
241
245
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
242
246
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
@@ -244,3 +248,4 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
|
|
244
248
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
245
249
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
|
246
250
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
|
251
|
+
Object wrap(std::vector<torch::Tensor> x);
|
data/lib/torch.rb
CHANGED
@@ -179,8 +179,10 @@ require "torch/nn/functional"
|
|
179
179
|
require "torch/nn/init"
|
180
180
|
|
181
181
|
# utils
|
182
|
+
require "torch/utils/data"
|
182
183
|
require "torch/utils/data/data_loader"
|
183
184
|
require "torch/utils/data/dataset"
|
185
|
+
require "torch/utils/data/subset"
|
184
186
|
require "torch/utils/data/tensor_dataset"
|
185
187
|
|
186
188
|
# hub
|
@@ -237,25 +239,22 @@ module Torch
|
|
237
239
|
cls
|
238
240
|
end
|
239
241
|
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
CUDA::IntTensor = _make_tensor_class(:int32, true)
|
257
|
-
CUDA::LongTensor = _make_tensor_class(:int64, true)
|
258
|
-
CUDA::BoolTensor = _make_tensor_class(:bool, true)
|
242
|
+
DTYPE_TO_CLASS = {
|
243
|
+
float32: "FloatTensor",
|
244
|
+
float64: "DoubleTensor",
|
245
|
+
float16: "HalfTensor",
|
246
|
+
uint8: "ByteTensor",
|
247
|
+
int8: "CharTensor",
|
248
|
+
int16: "ShortTensor",
|
249
|
+
int32: "IntTensor",
|
250
|
+
int64: "LongTensor",
|
251
|
+
bool: "BoolTensor"
|
252
|
+
}
|
253
|
+
|
254
|
+
DTYPE_TO_CLASS.each do |dtype, class_name|
|
255
|
+
const_set(class_name, _make_tensor_class(dtype))
|
256
|
+
CUDA.const_set(class_name, _make_tensor_class(dtype, true))
|
257
|
+
end
|
259
258
|
|
260
259
|
class << self
|
261
260
|
# Torch.float, Torch.long, etc
|
@@ -316,6 +315,16 @@ module Torch
|
|
316
315
|
end
|
317
316
|
end
|
318
317
|
|
318
|
+
def enable_grad
|
319
|
+
previous_value = grad_enabled?
|
320
|
+
begin
|
321
|
+
_set_grad_enabled(true)
|
322
|
+
yield
|
323
|
+
ensure
|
324
|
+
_set_grad_enabled(previous_value)
|
325
|
+
end
|
326
|
+
end
|
327
|
+
|
319
328
|
def device(str)
|
320
329
|
Device.new(str)
|
321
330
|
end
|
@@ -376,6 +385,10 @@ module Torch
|
|
376
385
|
end
|
377
386
|
|
378
387
|
def randperm(n, **options)
|
388
|
+
# dtype hack in Python
|
389
|
+
# https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
|
390
|
+
options[:dtype] ||= :int64
|
391
|
+
|
379
392
|
_randperm(n, tensor_options(**options))
|
380
393
|
end
|
381
394
|
|
@@ -448,6 +461,22 @@ module Torch
|
|
448
461
|
zeros(input.size, **like_options(input, options))
|
449
462
|
end
|
450
463
|
|
464
|
+
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
|
465
|
+
if center
|
466
|
+
signal_dim = input.dim
|
467
|
+
extended_shape = [1] * (3 - signal_dim) + input.size
|
468
|
+
pad = n_fft.div(2).to_i
|
469
|
+
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
470
|
+
input = input.view(input.shape[-signal_dim..-1])
|
471
|
+
end
|
472
|
+
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
473
|
+
end
|
474
|
+
|
475
|
+
def clamp(tensor, min, max)
|
476
|
+
tensor = _clamp_min(tensor, min)
|
477
|
+
_clamp_max(tensor, max)
|
478
|
+
end
|
479
|
+
|
451
480
|
private
|
452
481
|
|
453
482
|
def to_ivalue(obj)
|
data/lib/torch/hub.rb
CHANGED
@@ -7,25 +7,26 @@ module Torch
|
|
7
7
|
|
8
8
|
def download_url_to_file(url, dst)
|
9
9
|
uri = URI(url)
|
10
|
-
tmp =
|
10
|
+
tmp = nil
|
11
11
|
location = nil
|
12
12
|
|
13
|
+
puts "Downloading #{url}..."
|
13
14
|
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
15
|
request = Net::HTTP::Get.new(uri)
|
15
16
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
17
|
+
http.request(request) do |response|
|
18
|
+
case response
|
19
|
+
when Net::HTTPRedirection
|
20
|
+
location = response["location"]
|
21
|
+
when Net::HTTPSuccess
|
22
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
23
|
+
File.open(tmp, "wb") do |f|
|
23
24
|
response.read_body do |chunk|
|
24
25
|
f.write(chunk)
|
25
26
|
end
|
26
|
-
else
|
27
|
-
raise Error, "Bad response"
|
28
27
|
end
|
28
|
+
else
|
29
|
+
raise Error, "Bad response"
|
29
30
|
end
|
30
31
|
end
|
31
32
|
end
|
@@ -1,10 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module Native
|
3
3
|
class Function
|
4
|
-
attr_reader :function
|
4
|
+
attr_reader :function, :tensor_options
|
5
5
|
|
6
6
|
def initialize(function)
|
7
7
|
@function = function
|
8
|
+
|
9
|
+
tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
10
|
+
@tensor_options = @function["func"].include?(tensor_options_str)
|
11
|
+
@function["func"].sub!(tensor_options_str, ")")
|
8
12
|
end
|
9
13
|
|
10
14
|
def func
|
@@ -82,6 +86,10 @@ module Torch
|
|
82
86
|
@ret_size ||= func.split("->").last.split(", ").size
|
83
87
|
end
|
84
88
|
|
89
|
+
def ret_array?
|
90
|
+
@ret_array ||= func.split("->").last.include?('[]')
|
91
|
+
end
|
92
|
+
|
85
93
|
def out?
|
86
94
|
out_size > 0 && base_name[-1] != "_"
|
87
95
|
end
|
@@ -18,12 +18,12 @@ module Torch
|
|
18
18
|
functions = functions()
|
19
19
|
|
20
20
|
# skip functions
|
21
|
-
skip_args = ["
|
21
|
+
skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
|
22
22
|
|
23
23
|
# remove functions
|
24
24
|
functions.reject! do |f|
|
25
25
|
f.ruby_name.start_with?("_") ||
|
26
|
-
f.ruby_name.
|
26
|
+
f.ruby_name.include?("_backward") ||
|
27
27
|
f.args.any? { |a| a[:type].include?("Dimname") }
|
28
28
|
end
|
29
29
|
|
@@ -31,32 +31,15 @@ module Torch
|
|
31
31
|
todo_functions, functions =
|
32
32
|
functions.partition do |f|
|
33
33
|
f.args.any? do |a|
|
34
|
-
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
34
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
35
|
+
# call to 'range' is ambiguous
|
36
|
+
f.cpp_name == "_range" ||
|
36
37
|
# native_functions.yaml is missing size argument for normal
|
37
38
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
38
39
|
(f.base_name == "normal" && !f.out?)
|
39
40
|
end
|
40
41
|
end
|
41
42
|
|
42
|
-
# generate additional functions for optional arguments
|
43
|
-
# there may be a better way to do this
|
44
|
-
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
45
|
-
optional_functions.each do |f|
|
46
|
-
next if f.ruby_name == "cross"
|
47
|
-
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
48
|
-
|
49
|
-
opt_args = f.args.select { |a| a[:type] == "int?" }
|
50
|
-
if opt_args.size == 1
|
51
|
-
sep = f.name.include?(".") ? "_" : "."
|
52
|
-
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
53
|
-
# TODO only remove some arguments
|
54
|
-
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
55
|
-
functions << f1
|
56
|
-
functions << f2
|
57
|
-
end
|
58
|
-
end
|
59
|
-
|
60
43
|
# todo_functions.each do |f|
|
61
44
|
# puts f.func
|
62
45
|
# puts
|
@@ -97,7 +80,8 @@ void add_%{type}_functions(Module m) {
|
|
97
80
|
|
98
81
|
cpp_defs = []
|
99
82
|
functions.sort_by(&:cpp_name).each do |func|
|
100
|
-
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
83
|
+
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
84
|
+
fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
|
101
85
|
|
102
86
|
cpp_args = []
|
103
87
|
fargs.each do |a|
|
@@ -109,7 +93,7 @@ void add_%{type}_functions(Module m) {
|
|
109
93
|
# TODO better signature
|
110
94
|
"OptionalTensor"
|
111
95
|
when "ScalarType?"
|
112
|
-
"
|
96
|
+
"torch::optional<ScalarType>"
|
113
97
|
when "Tensor[]"
|
114
98
|
"TensorList"
|
115
99
|
when "Tensor?[]"
|
@@ -117,6 +101,14 @@ void add_%{type}_functions(Module m) {
|
|
117
101
|
"TensorList"
|
118
102
|
when "int"
|
119
103
|
"int64_t"
|
104
|
+
when "int?"
|
105
|
+
"torch::optional<int64_t>"
|
106
|
+
when "float?"
|
107
|
+
"torch::optional<double>"
|
108
|
+
when "bool?"
|
109
|
+
"torch::optional<bool>"
|
110
|
+
when "Scalar?"
|
111
|
+
"torch::optional<torch::Scalar>"
|
120
112
|
when "float"
|
121
113
|
"double"
|
122
114
|
when /\Aint\[/
|
@@ -125,6 +117,8 @@ void add_%{type}_functions(Module m) {
|
|
125
117
|
"Tensor &"
|
126
118
|
when "str"
|
127
119
|
"std::string"
|
120
|
+
when "TensorOptions"
|
121
|
+
"const torch::TensorOptions &"
|
128
122
|
else
|
129
123
|
a[:type]
|
130
124
|
end
|
@@ -141,8 +135,8 @@ void add_%{type}_functions(Module m) {
|
|
141
135
|
prefix = def_method == :define_method ? "self." : "torch::"
|
142
136
|
|
143
137
|
body = "#{prefix}#{dispatch}(#{args.join(", ")})"
|
144
|
-
|
145
|
-
if func.ret_size > 1
|
138
|
+
|
139
|
+
if func.ret_size > 1 || func.ret_array?
|
146
140
|
body = "wrap(#{body})"
|
147
141
|
end
|
148
142
|
|
data/lib/torch/native/parser.rb
CHANGED
@@ -83,6 +83,12 @@ module Torch
|
|
83
83
|
else
|
84
84
|
v.is_a?(Integer)
|
85
85
|
end
|
86
|
+
when "int?"
|
87
|
+
v.is_a?(Integer) || v.nil?
|
88
|
+
when "float?"
|
89
|
+
v.is_a?(Numeric) || v.nil?
|
90
|
+
when "bool?"
|
91
|
+
v == true || v == false || v.nil?
|
86
92
|
when "float"
|
87
93
|
v.is_a?(Numeric)
|
88
94
|
when /int\[.*\]/
|
@@ -95,6 +101,10 @@ module Torch
|
|
95
101
|
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
|
96
102
|
when "Scalar"
|
97
103
|
v.is_a?(Numeric)
|
104
|
+
when "Scalar?"
|
105
|
+
v.is_a?(Numeric) || v.nil?
|
106
|
+
when "ScalarType"
|
107
|
+
false # not supported yet
|
98
108
|
when "ScalarType?"
|
99
109
|
v.nil?
|
100
110
|
when "bool"
|
@@ -126,9 +136,11 @@ module Torch
|
|
126
136
|
end
|
127
137
|
|
128
138
|
func = candidates.first
|
139
|
+
args = func.args.map { |a| final_values[a[:name]] }
|
140
|
+
args << TensorOptions.new.dtype(6) if func.tensor_options
|
129
141
|
{
|
130
142
|
name: func.cpp_name,
|
131
|
-
args:
|
143
|
+
args: args
|
132
144
|
}
|
133
145
|
end
|
134
146
|
end
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -178,8 +178,12 @@ module Torch
|
|
178
178
|
Torch.hardshrink(input, lambd)
|
179
179
|
end
|
180
180
|
|
181
|
-
def leaky_relu(input, negative_slope = 0.01)
|
182
|
-
|
181
|
+
def leaky_relu(input, negative_slope = 0.01, inplace: false)
|
182
|
+
if inplace
|
183
|
+
NN.leaky_relu!(input, negative_slope)
|
184
|
+
else
|
185
|
+
NN.leaky_relu(input, negative_slope)
|
186
|
+
end
|
183
187
|
end
|
184
188
|
|
185
189
|
def log_sigmoid(input)
|
@@ -373,7 +377,8 @@ module Torch
|
|
373
377
|
end
|
374
378
|
|
375
379
|
# weight and input swapped
|
376
|
-
Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
380
|
+
ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
|
381
|
+
ret
|
377
382
|
end
|
378
383
|
|
379
384
|
# distance functions
|
@@ -426,6 +431,9 @@ module Torch
|
|
426
431
|
end
|
427
432
|
|
428
433
|
def mse_loss(input, target, reduction: "mean")
|
434
|
+
if target.size != input.size
|
435
|
+
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
436
|
+
end
|
429
437
|
NN.mse_loss(input, target, reduction)
|
430
438
|
end
|
431
439
|
|
data/lib/torch/nn/leaky_relu.rb
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class LeakyReLU < Module
|
4
|
-
def initialize(negative_slope: 1e-2
|
4
|
+
def initialize(negative_slope: 1e-2, inplace: false)
|
5
5
|
super()
|
6
6
|
@negative_slope = negative_slope
|
7
|
-
|
7
|
+
@inplace = inplace
|
8
8
|
end
|
9
9
|
|
10
10
|
def forward(input)
|
11
|
-
F.leaky_relu(input, @negative_slope
|
11
|
+
F.leaky_relu(input, @negative_slope, inplace: @inplace)
|
12
12
|
end
|
13
13
|
|
14
14
|
def extra_inspect
|
data/lib/torch/nn/module.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -103,11 +103,6 @@ module Torch
|
|
103
103
|
Torch.empty(0, dtype: dtype)
|
104
104
|
end
|
105
105
|
|
106
|
-
def backward(gradient = nil, retain_graph: nil, create_graph: false)
|
107
|
-
retain_graph = create_graph if retain_graph.nil?
|
108
|
-
_backward(gradient, retain_graph, create_graph)
|
109
|
-
end
|
110
|
-
|
111
106
|
# TODO read directly from memory
|
112
107
|
def numo
|
113
108
|
cls = Torch._dtype_to_numo[dtype]
|
@@ -188,49 +183,15 @@ module Torch
|
|
188
183
|
# based on python_variable_indexing.cpp and
|
189
184
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
190
185
|
def [](*indexes)
|
191
|
-
|
192
|
-
dim = 0
|
193
|
-
indexes.each do |index|
|
194
|
-
if index.is_a?(Numeric)
|
195
|
-
result = result._select_int(dim, index)
|
196
|
-
elsif index.is_a?(Range)
|
197
|
-
finish = index.end
|
198
|
-
finish += 1 unless index.exclude_end?
|
199
|
-
result = result._slice_tensor(dim, index.begin, finish, 1)
|
200
|
-
dim += 1
|
201
|
-
elsif index.is_a?(Tensor)
|
202
|
-
result = result.index([index])
|
203
|
-
elsif index.nil?
|
204
|
-
result = result.unsqueeze(dim)
|
205
|
-
dim += 1
|
206
|
-
elsif index == true
|
207
|
-
result = result.unsqueeze(dim)
|
208
|
-
# TODO handle false
|
209
|
-
else
|
210
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
211
|
-
end
|
212
|
-
end
|
213
|
-
result
|
186
|
+
_index(tensor_indexes(indexes))
|
214
187
|
end
|
215
188
|
|
216
189
|
# based on python_variable_indexing.cpp and
|
217
190
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
218
|
-
def []=(
|
191
|
+
def []=(*indexes, value)
|
219
192
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
220
|
-
|
221
193
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
222
|
-
|
223
|
-
if index.is_a?(Numeric)
|
224
|
-
index_put!([Torch.tensor(index)], value)
|
225
|
-
elsif index.is_a?(Range)
|
226
|
-
finish = index.end
|
227
|
-
finish += 1 unless index.exclude_end?
|
228
|
-
_slice_tensor(0, index.begin, finish, 1).copy!(value)
|
229
|
-
elsif index.is_a?(Tensor)
|
230
|
-
index_put!([index], value)
|
231
|
-
else
|
232
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
233
|
-
end
|
194
|
+
_index_put_custom(tensor_indexes(indexes), value)
|
234
195
|
end
|
235
196
|
|
236
197
|
# native functions that need manually defined
|
@@ -244,13 +205,13 @@ module Torch
|
|
244
205
|
end
|
245
206
|
end
|
246
207
|
|
247
|
-
#
|
208
|
+
# parser can't handle overlap, so need to handle manually
|
248
209
|
def random!(*args)
|
249
210
|
case args.size
|
250
211
|
when 1
|
251
212
|
_random__to(*args)
|
252
213
|
when 2
|
253
|
-
|
214
|
+
_random__from(*args)
|
254
215
|
else
|
255
216
|
_random_(*args)
|
256
217
|
end
|
@@ -260,5 +221,32 @@ module Torch
|
|
260
221
|
_clamp_min_(min)
|
261
222
|
_clamp_max_(max)
|
262
223
|
end
|
224
|
+
|
225
|
+
private
|
226
|
+
|
227
|
+
def tensor_indexes(indexes)
|
228
|
+
indexes.map do |index|
|
229
|
+
case index
|
230
|
+
when Integer
|
231
|
+
TensorIndex.integer(index)
|
232
|
+
when Range
|
233
|
+
finish = index.end || -1
|
234
|
+
if finish == -1 && !index.exclude_end?
|
235
|
+
finish = nil
|
236
|
+
else
|
237
|
+
finish += 1 unless index.exclude_end?
|
238
|
+
end
|
239
|
+
TensorIndex.slice(index.begin, finish)
|
240
|
+
when Tensor
|
241
|
+
TensorIndex.tensor(index)
|
242
|
+
when nil
|
243
|
+
TensorIndex.none
|
244
|
+
when true, false
|
245
|
+
TensorIndex.boolean(index)
|
246
|
+
else
|
247
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
248
|
+
end
|
249
|
+
end
|
250
|
+
end
|
263
251
|
end
|
264
252
|
end
|
@@ -0,0 +1,23 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class << self
|
5
|
+
def random_split(dataset, lengths)
|
6
|
+
if lengths.sum != dataset.length
|
7
|
+
raise ArgumentError, "Sum of input lengths does not equal the length of the input dataset!"
|
8
|
+
end
|
9
|
+
|
10
|
+
indices = Torch.randperm(lengths.sum).to_a
|
11
|
+
_accumulate(lengths).zip(lengths).map { |offset, length| Subset.new(dataset, indices[(offset - length)...offset]) }
|
12
|
+
end
|
13
|
+
|
14
|
+
private
|
15
|
+
|
16
|
+
def _accumulate(iterable)
|
17
|
+
sum = 0
|
18
|
+
iterable.map { |x| sum += x }
|
19
|
+
end
|
20
|
+
end
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
@@ -6,10 +6,22 @@ module Torch
|
|
6
6
|
|
7
7
|
attr_reader :dataset
|
8
8
|
|
9
|
-
def initialize(dataset, batch_size: 1, shuffle: false)
|
9
|
+
def initialize(dataset, batch_size: 1, shuffle: false, collate_fn: nil)
|
10
10
|
@dataset = dataset
|
11
11
|
@batch_size = batch_size
|
12
12
|
@shuffle = shuffle
|
13
|
+
|
14
|
+
@batch_sampler = nil
|
15
|
+
|
16
|
+
if collate_fn.nil?
|
17
|
+
if auto_collation?
|
18
|
+
collate_fn = method(:default_collate)
|
19
|
+
else
|
20
|
+
collate_fn = method(:default_convert)
|
21
|
+
end
|
22
|
+
end
|
23
|
+
|
24
|
+
@collate_fn = collate_fn
|
13
25
|
end
|
14
26
|
|
15
27
|
def each
|
@@ -25,18 +37,20 @@ module Torch
|
|
25
37
|
end
|
26
38
|
|
27
39
|
indexes.each_slice(@batch_size) do |idx|
|
28
|
-
|
29
|
-
yield
|
40
|
+
# TODO improve performance
|
41
|
+
yield @collate_fn.call(idx.map { |i| @dataset[i] })
|
30
42
|
end
|
31
43
|
end
|
32
44
|
|
33
45
|
def size
|
34
46
|
(@dataset.size / @batch_size.to_f).ceil
|
35
47
|
end
|
48
|
+
alias_method :length, :size
|
49
|
+
alias_method :count, :size
|
36
50
|
|
37
51
|
private
|
38
52
|
|
39
|
-
def
|
53
|
+
def default_convert(batch)
|
40
54
|
elem = batch[0]
|
41
55
|
case elem
|
42
56
|
when Tensor
|
@@ -44,11 +58,15 @@ module Torch
|
|
44
58
|
when Integer
|
45
59
|
Torch.tensor(batch)
|
46
60
|
when Array
|
47
|
-
batch.transpose.map { |v|
|
61
|
+
batch.transpose.map { |v| default_convert(v) }
|
48
62
|
else
|
49
|
-
raise
|
63
|
+
raise NotImplementedYet
|
50
64
|
end
|
51
65
|
end
|
66
|
+
|
67
|
+
def auto_collation?
|
68
|
+
!@batch_sampler.nil?
|
69
|
+
end
|
52
70
|
end
|
53
71
|
end
|
54
72
|
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
|
2
|
+
module Utils
|
3
|
+
module Data
|
4
|
+
class Subset < Dataset
|
5
|
+
def initialize(dataset, indices)
|
6
|
+
@dataset = dataset
|
7
|
+
@indices = indices
|
8
|
+
end
|
9
|
+
|
10
|
+
def [](idx)
|
11
|
+
@dataset[@indices[idx]]
|
12
|
+
end
|
13
|
+
|
14
|
+
def length
|
15
|
+
@indices.length
|
16
|
+
end
|
17
|
+
alias_method :size, :length
|
18
|
+
|
19
|
+
def to_a
|
20
|
+
@indices.map { |i| @dataset[i] }
|
21
|
+
end
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.
|
4
|
+
version: 0.3.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-09-18 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -259,8 +259,10 @@ files:
|
|
259
259
|
- lib/torch/optim/rprop.rb
|
260
260
|
- lib/torch/optim/sgd.rb
|
261
261
|
- lib/torch/tensor.rb
|
262
|
+
- lib/torch/utils/data.rb
|
262
263
|
- lib/torch/utils/data/data_loader.rb
|
263
264
|
- lib/torch/utils/data/dataset.rb
|
265
|
+
- lib/torch/utils/data/subset.rb
|
264
266
|
- lib/torch/utils/data/tensor_dataset.rb
|
265
267
|
- lib/torch/version.rb
|
266
268
|
homepage: https://github.com/ankane/torch.rb
|