torch-rb 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +18 -6
- data/ext/torch/ext.cpp +35 -19
- data/lib/torch.rb +5 -0
- data/lib/torch/hub.rb +48 -4
- data/lib/torch/inspector.rb +2 -2
- data/lib/torch/tensor.rb +23 -24
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8f6ab78fb5cff27d0d60ddb9c08fb2f526bd60e241dd1011554b21716bdd2f43
|
4
|
+
data.tar.gz: 5568f53d8d5d688e3f29fb55ddbe9457e0b933dc69f59b422035c0cee249e396
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9c5dcfbf35382b37678662690677b2b90d0d544e8802703cf83dba6e10483a4df487f9687e4e898c9cc449c568f4e02f9d831daa982b5b3135af6f9ce176ec88
|
7
|
+
data.tar.gz: 34f142d874606e140661ae992a9f8cd4779f95c93c11d9a89a1864dd0bd53c5480c30d9aec5897f1955e2450bd0a3bc56ed0868e3b54d82ff4cdba40af379840
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
+
For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
|
6
|
+
|
5
7
|
[![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
|
6
8
|
|
7
9
|
## Installation
|
@@ -22,6 +24,18 @@ It can take a few minutes to compile the extension.
|
|
22
24
|
|
23
25
|
## Getting Started
|
24
26
|
|
27
|
+
Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
|
28
|
+
|
29
|
+
We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
|
30
|
+
|
31
|
+
```text
|
32
|
+
ankane/ml-stack:torch-gpu
|
33
|
+
```
|
34
|
+
|
35
|
+
And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
|
36
|
+
|
37
|
+
## API
|
38
|
+
|
25
39
|
This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
|
26
40
|
|
27
41
|
- Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
|
@@ -192,7 +206,7 @@ end
|
|
192
206
|
Define a neural network
|
193
207
|
|
194
208
|
```ruby
|
195
|
-
class Net < Torch::NN::Module
|
209
|
+
class MyNet < Torch::NN::Module
|
196
210
|
def initialize
|
197
211
|
super
|
198
212
|
@conv1 = Torch::NN::Conv2d.new(1, 6, 3)
|
@@ -226,7 +240,7 @@ end
|
|
226
240
|
Create an instance of it
|
227
241
|
|
228
242
|
```ruby
|
229
|
-
net = Net.new
|
243
|
+
net = MyNet.new
|
230
244
|
input = Torch.randn(1, 1, 32, 32)
|
231
245
|
net.call(input)
|
232
246
|
```
|
@@ -294,7 +308,7 @@ Torch.save(net.state_dict, "net.pth")
|
|
294
308
|
Load a model
|
295
309
|
|
296
310
|
```ruby
|
297
|
-
net = Net.new
|
311
|
+
net = MyNet.new
|
298
312
|
net.load_state_dict(Torch.load("net.pth"))
|
299
313
|
net.eval
|
300
314
|
```
|
@@ -413,9 +427,7 @@ Then install the gem (no need for `bundle config`).
|
|
413
427
|
|
414
428
|
### Linux
|
415
429
|
|
416
|
-
Deep learning is significantly faster on GPUs.
|
417
|
-
|
418
|
-
Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
430
|
+
Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
419
431
|
|
420
432
|
Check if CUDA is available
|
421
433
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -23,7 +23,7 @@ class Parameter: public torch::autograd::Variable {
|
|
23
23
|
Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
|
24
24
|
};
|
25
25
|
|
26
|
-
void handle_error(c10::Error const & ex)
|
26
|
+
void handle_error(torch::Error const & ex)
|
27
27
|
{
|
28
28
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
29
|
}
|
@@ -32,15 +32,19 @@ extern "C"
|
|
32
32
|
void Init_ext()
|
33
33
|
{
|
34
34
|
Module rb_mTorch = define_module("Torch");
|
35
|
+
rb_mTorch.add_handler<torch::Error>(handle_error);
|
35
36
|
add_torch_functions(rb_mTorch);
|
36
37
|
|
37
38
|
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
|
39
|
+
rb_cTensor.add_handler<torch::Error>(handle_error);
|
38
40
|
add_tensor_functions(rb_cTensor);
|
39
41
|
|
40
42
|
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
43
|
+
rb_mNN.add_handler<torch::Error>(handle_error);
|
41
44
|
add_nn_functions(rb_mNN);
|
42
45
|
|
43
46
|
Module rb_mRandom = define_module_under(rb_mTorch, "Random")
|
47
|
+
.add_handler<torch::Error>(handle_error)
|
44
48
|
.define_singleton_method(
|
45
49
|
"initial_seed",
|
46
50
|
*[]() {
|
@@ -55,6 +59,7 @@ void Init_ext()
|
|
55
59
|
|
56
60
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
57
61
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
62
|
+
.add_handler<torch::Error>(handle_error)
|
58
63
|
.define_constructor(Constructor<torch::IValue>())
|
59
64
|
.define_method("bool?", &torch::IValue::isBool)
|
60
65
|
.define_method("bool_list?", &torch::IValue::isBoolList)
|
@@ -317,7 +322,6 @@ void Init_ext()
|
|
317
322
|
});
|
318
323
|
|
319
324
|
rb_cTensor
|
320
|
-
.add_handler<c10::Error>(handle_error)
|
321
325
|
.define_method("cuda?", &torch::Tensor::is_cuda)
|
322
326
|
.define_method("sparse?", &torch::Tensor::is_sparse)
|
323
327
|
.define_method("quantized?", &torch::Tensor::is_quantized)
|
@@ -374,6 +378,21 @@ void Init_ext()
|
|
374
378
|
s << self.device();
|
375
379
|
return s.str();
|
376
380
|
})
|
381
|
+
.define_method(
|
382
|
+
"_data_str",
|
383
|
+
*[](Tensor& self) {
|
384
|
+
Tensor tensor = self;
|
385
|
+
|
386
|
+
// move to CPU to get data
|
387
|
+
if (tensor.device().type() != torch::kCPU) {
|
388
|
+
torch::Device device("cpu");
|
389
|
+
tensor = tensor.to(device);
|
390
|
+
}
|
391
|
+
|
392
|
+
auto data_ptr = (const char *) tensor.data_ptr();
|
393
|
+
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
394
|
+
})
|
395
|
+
// TODO figure out a better way to do this
|
377
396
|
.define_method(
|
378
397
|
"_flat_data",
|
379
398
|
*[](Tensor& self) {
|
@@ -388,46 +407,40 @@ void Init_ext()
|
|
388
407
|
Array a;
|
389
408
|
auto dtype = tensor.dtype();
|
390
409
|
|
410
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
411
|
+
|
391
412
|
// TODO DRY if someone knows C++
|
392
413
|
if (dtype == torch::kByte) {
|
393
|
-
uint8_t* data = tensor.data_ptr<uint8_t>();
|
394
414
|
for (int i = 0; i < tensor.numel(); i++) {
|
395
|
-
a.push(data[i]);
|
415
|
+
a.push(view[i].item().to<uint8_t>());
|
396
416
|
}
|
397
417
|
} else if (dtype == torch::kChar) {
|
398
|
-
int8_t* data = tensor.data_ptr<int8_t>();
|
399
418
|
for (int i = 0; i < tensor.numel(); i++) {
|
400
|
-
a.push(to_ruby<int>(data[i]));
|
419
|
+
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
401
420
|
}
|
402
421
|
} else if (dtype == torch::kShort) {
|
403
|
-
int16_t* data = tensor.data_ptr<int16_t>();
|
404
422
|
for (int i = 0; i < tensor.numel(); i++) {
|
405
|
-
a.push(data[i]);
|
423
|
+
a.push(view[i].item().to<int16_t>());
|
406
424
|
}
|
407
425
|
} else if (dtype == torch::kInt) {
|
408
|
-
int32_t* data = tensor.data_ptr<int32_t>();
|
409
426
|
for (int i = 0; i < tensor.numel(); i++) {
|
410
|
-
a.push(data[i]);
|
427
|
+
a.push(view[i].item().to<int32_t>());
|
411
428
|
}
|
412
429
|
} else if (dtype == torch::kLong) {
|
413
|
-
int64_t* data = tensor.data_ptr<int64_t>();
|
414
430
|
for (int i = 0; i < tensor.numel(); i++) {
|
415
|
-
a.push(data[i]);
|
431
|
+
a.push(view[i].item().to<int64_t>());
|
416
432
|
}
|
417
433
|
} else if (dtype == torch::kFloat) {
|
418
|
-
float* data = tensor.data_ptr<float>();
|
419
434
|
for (int i = 0; i < tensor.numel(); i++) {
|
420
|
-
a.push(data[i]);
|
435
|
+
a.push(view[i].item().to<float>());
|
421
436
|
}
|
422
437
|
} else if (dtype == torch::kDouble) {
|
423
|
-
double* data = tensor.data_ptr<double>();
|
424
438
|
for (int i = 0; i < tensor.numel(); i++) {
|
425
|
-
a.push(data[i]);
|
439
|
+
a.push(view[i].item().to<double>());
|
426
440
|
}
|
427
441
|
} else if (dtype == torch::kBool) {
|
428
|
-
bool* data = tensor.data_ptr<bool>();
|
429
442
|
for (int i = 0; i < tensor.numel(); i++) {
|
430
|
-
a.push(data[i] ? True : False);
|
443
|
+
a.push(view[i].item().to<bool>() ? True : False);
|
431
444
|
}
|
432
445
|
} else {
|
433
446
|
throw std::runtime_error("Unsupported type");
|
@@ -449,7 +462,7 @@ void Init_ext()
|
|
449
462
|
});
|
450
463
|
|
451
464
|
Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
|
452
|
-
.add_handler<c10::Error>(handle_error)
|
465
|
+
.add_handler<torch::Error>(handle_error)
|
453
466
|
.define_constructor(Constructor<torch::TensorOptions>())
|
454
467
|
.define_method(
|
455
468
|
"dtype",
|
@@ -555,6 +568,7 @@ void Init_ext()
|
|
555
568
|
});
|
556
569
|
|
557
570
|
Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
|
571
|
+
.add_handler<torch::Error>(handle_error)
|
558
572
|
.define_method(
|
559
573
|
"grad",
|
560
574
|
*[](Parameter& self) {
|
@@ -564,6 +578,7 @@ void Init_ext()
|
|
564
578
|
|
565
579
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
566
580
|
.define_constructor(Constructor<torch::Device, std::string>())
|
581
|
+
.add_handler<torch::Error>(handle_error)
|
567
582
|
.define_method("index", &torch::Device::index)
|
568
583
|
.define_method("index?", &torch::Device::has_index)
|
569
584
|
.define_method(
|
@@ -575,6 +590,7 @@ void Init_ext()
|
|
575
590
|
});
|
576
591
|
|
577
592
|
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
|
593
|
+
.add_handler<torch::Error>(handle_error)
|
578
594
|
.define_singleton_method("available?", &torch::cuda::is_available)
|
579
595
|
.define_singleton_method("device_count", &torch::cuda::device_count);
|
580
596
|
}
|
data/lib/torch.rb
CHANGED
data/lib/torch/hub.rb
CHANGED
@@ -5,12 +5,56 @@ module Torch
|
|
5
5
|
raise NotImplementedYet
|
6
6
|
end
|
7
7
|
|
8
|
-
def download_url_to_file(url)
|
9
|
-
|
8
|
+
def download_url_to_file(url, dst)
|
9
|
+
uri = URI(url)
|
10
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
11
|
+
location = nil
|
12
|
+
|
13
|
+
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
|
+
request = Net::HTTP::Get.new(uri)
|
15
|
+
|
16
|
+
puts "Downloading #{url}..."
|
17
|
+
File.open(tmp, "wb") do |f|
|
18
|
+
http.request(request) do |response|
|
19
|
+
case response
|
20
|
+
when Net::HTTPRedirection
|
21
|
+
location = response["location"]
|
22
|
+
when Net::HTTPSuccess
|
23
|
+
response.read_body do |chunk|
|
24
|
+
f.write(chunk)
|
25
|
+
end
|
26
|
+
else
|
27
|
+
raise Error, "Bad response"
|
28
|
+
end
|
29
|
+
end
|
30
|
+
end
|
31
|
+
end
|
32
|
+
|
33
|
+
if location
|
34
|
+
download_url_to_file(location, dst)
|
35
|
+
else
|
36
|
+
FileUtils.mv(tmp, dst)
|
37
|
+
nil
|
38
|
+
end
|
10
39
|
end
|
11
40
|
|
12
|
-
def load_state_dict_from_url(url)
|
13
|
-
|
41
|
+
def load_state_dict_from_url(url, model_dir: nil)
|
42
|
+
unless model_dir
|
43
|
+
torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
|
44
|
+
model_dir = File.join(torch_home, "checkpoints")
|
45
|
+
end
|
46
|
+
|
47
|
+
FileUtils.mkdir_p(model_dir)
|
48
|
+
|
49
|
+
parts = URI(url)
|
50
|
+
filename = File.basename(parts.path)
|
51
|
+
cached_file = File.join(model_dir, filename)
|
52
|
+
unless File.exist?(cached_file)
|
53
|
+
# TODO support hash_prefix
|
54
|
+
download_url_to_file(url, cached_file)
|
55
|
+
end
|
56
|
+
|
57
|
+
Torch.load(cached_file)
|
14
58
|
end
|
15
59
|
end
|
16
60
|
end
|
data/lib/torch/inspector.rb
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
module Torch
|
2
2
|
module Inspector
|
3
|
-
# TODO make more
|
3
|
+
# TODO make more performant, especially when summarizing
|
4
4
|
# how? only read data that will be displayed
|
5
5
|
def inspect
|
6
6
|
data =
|
@@ -14,7 +14,7 @@ module Torch
|
|
14
14
|
if dtype == :bool
|
15
15
|
fmt = "%s"
|
16
16
|
else
|
17
|
-
values =
|
17
|
+
values = _flat_data
|
18
18
|
abs = values.select { |v| v != 0 }.map(&:abs)
|
19
19
|
max = abs.max || 1
|
20
20
|
min = abs.min || 1
|
data/lib/torch/tensor.rb
CHANGED
@@ -25,8 +25,17 @@ module Torch
|
|
25
25
|
inspect
|
26
26
|
end
|
27
27
|
|
28
|
+
# TODO make more performant
|
28
29
|
def to_a
|
29
|
-
|
30
|
+
arr = _flat_data
|
31
|
+
if shape.empty?
|
32
|
+
arr
|
33
|
+
else
|
34
|
+
shape[1..-1].reverse.each do |dim|
|
35
|
+
arr = arr.each_slice(dim)
|
36
|
+
end
|
37
|
+
arr.to_a
|
38
|
+
end
|
30
39
|
end
|
31
40
|
|
32
41
|
# TODO support dtype
|
@@ -64,7 +73,7 @@ module Torch
|
|
64
73
|
if numel != 1
|
65
74
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
66
75
|
end
|
67
|
-
|
76
|
+
to_a.first
|
68
77
|
end
|
69
78
|
|
70
79
|
def to_i
|
@@ -88,7 +97,7 @@ module Torch
|
|
88
97
|
def numo
|
89
98
|
cls = Torch._dtype_to_numo[dtype]
|
90
99
|
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
91
|
-
cls.
|
100
|
+
cls.from_string(_data_str).reshape(*shape)
|
92
101
|
end
|
93
102
|
|
94
103
|
def new_ones(*size, **options)
|
@@ -116,15 +125,6 @@ module Torch
|
|
116
125
|
_view(size)
|
117
126
|
end
|
118
127
|
|
119
|
-
# value and other are swapped for some methods
|
120
|
-
def add!(value = 1, other)
|
121
|
-
if other.is_a?(Numeric)
|
122
|
-
_add__scalar(other, value)
|
123
|
-
else
|
124
|
-
_add__tensor(other, value)
|
125
|
-
end
|
126
|
-
end
|
127
|
-
|
128
128
|
def +(other)
|
129
129
|
add(other)
|
130
130
|
end
|
@@ -201,6 +201,17 @@ module Torch
|
|
201
201
|
end
|
202
202
|
end
|
203
203
|
|
204
|
+
# native functions that need manually defined
|
205
|
+
|
206
|
+
# value and other are swapped for some methods
|
207
|
+
def add!(value = 1, other)
|
208
|
+
if other.is_a?(Numeric)
|
209
|
+
_add__scalar(other, value)
|
210
|
+
else
|
211
|
+
_add__tensor(other, value)
|
212
|
+
end
|
213
|
+
end
|
214
|
+
|
204
215
|
# native functions overlap, so need to handle manually
|
205
216
|
def random!(*args)
|
206
217
|
case args.size
|
@@ -218,17 +229,5 @@ module Torch
|
|
218
229
|
def copy_to(dst, src)
|
219
230
|
dst.copy!(src)
|
220
231
|
end
|
221
|
-
|
222
|
-
def reshape_arr(arr, dims)
|
223
|
-
if dims.empty?
|
224
|
-
arr
|
225
|
-
else
|
226
|
-
arr = arr.flatten
|
227
|
-
dims[1..-1].reverse.each do |dim|
|
228
|
-
arr = arr.each_slice(dim)
|
229
|
-
end
|
230
|
-
arr.to_a
|
231
|
-
end
|
232
|
-
end
|
233
232
|
end
|
234
233
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.4
|
4
|
+
version: 0.2.5
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-06-07 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|