torch-rb 0.2.4 → 0.2.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +18 -6
- data/ext/torch/ext.cpp +35 -19
- data/lib/torch.rb +5 -0
- data/lib/torch/hub.rb +48 -4
- data/lib/torch/inspector.rb +2 -2
- data/lib/torch/tensor.rb +23 -24
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8f6ab78fb5cff27d0d60ddb9c08fb2f526bd60e241dd1011554b21716bdd2f43
|
4
|
+
data.tar.gz: 5568f53d8d5d688e3f29fb55ddbe9457e0b933dc69f59b422035c0cee249e396
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9c5dcfbf35382b37678662690677b2b90d0d544e8802703cf83dba6e10483a4df487f9687e4e898c9cc449c568f4e02f9d831daa982b5b3135af6f9ce176ec88
|
7
|
+
data.tar.gz: 34f142d874606e140661ae992a9f8cd4779f95c93c11d9a89a1864dd0bd53c5480c30d9aec5897f1955e2450bd0a3bc56ed0868e3b54d82ff4cdba40af379840
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
+
For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
|
6
|
+
|
5
7
|
[![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
|
6
8
|
|
7
9
|
## Installation
|
@@ -22,6 +24,18 @@ It can take a few minutes to compile the extension.
|
|
22
24
|
|
23
25
|
## Getting Started
|
24
26
|
|
27
|
+
Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
|
28
|
+
|
29
|
+
We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
|
30
|
+
|
31
|
+
```text
|
32
|
+
ankane/ml-stack:torch-gpu
|
33
|
+
```
|
34
|
+
|
35
|
+
And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
|
36
|
+
|
37
|
+
## API
|
38
|
+
|
25
39
|
This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
|
26
40
|
|
27
41
|
- Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
|
@@ -192,7 +206,7 @@ end
|
|
192
206
|
Define a neural network
|
193
207
|
|
194
208
|
```ruby
|
195
|
-
class
|
209
|
+
class MyNet < Torch::NN::Module
|
196
210
|
def initialize
|
197
211
|
super
|
198
212
|
@conv1 = Torch::NN::Conv2d.new(1, 6, 3)
|
@@ -226,7 +240,7 @@ end
|
|
226
240
|
Create an instance of it
|
227
241
|
|
228
242
|
```ruby
|
229
|
-
net =
|
243
|
+
net = MyNet.new
|
230
244
|
input = Torch.randn(1, 1, 32, 32)
|
231
245
|
net.call(input)
|
232
246
|
```
|
@@ -294,7 +308,7 @@ Torch.save(net.state_dict, "net.pth")
|
|
294
308
|
Load a model
|
295
309
|
|
296
310
|
```ruby
|
297
|
-
net =
|
311
|
+
net = MyNet.new
|
298
312
|
net.load_state_dict(Torch.load("net.pth"))
|
299
313
|
net.eval
|
300
314
|
```
|
@@ -413,9 +427,7 @@ Then install the gem (no need for `bundle config`).
|
|
413
427
|
|
414
428
|
### Linux
|
415
429
|
|
416
|
-
Deep learning is significantly faster on
|
417
|
-
|
418
|
-
Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
430
|
+
Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
419
431
|
|
420
432
|
Check if CUDA is available
|
421
433
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -23,7 +23,7 @@ class Parameter: public torch::autograd::Variable {
|
|
23
23
|
Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
|
24
24
|
};
|
25
25
|
|
26
|
-
void handle_error(c10::Error const & ex)
|
26
|
+
void handle_error(torch::Error const & ex)
|
27
27
|
{
|
28
28
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
29
|
}
|
@@ -32,15 +32,19 @@ extern "C"
|
|
32
32
|
void Init_ext()
|
33
33
|
{
|
34
34
|
Module rb_mTorch = define_module("Torch");
|
35
|
+
rb_mTorch.add_handler<torch::Error>(handle_error);
|
35
36
|
add_torch_functions(rb_mTorch);
|
36
37
|
|
37
38
|
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
|
39
|
+
rb_cTensor.add_handler<torch::Error>(handle_error);
|
38
40
|
add_tensor_functions(rb_cTensor);
|
39
41
|
|
40
42
|
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
43
|
+
rb_mNN.add_handler<torch::Error>(handle_error);
|
41
44
|
add_nn_functions(rb_mNN);
|
42
45
|
|
43
46
|
Module rb_mRandom = define_module_under(rb_mTorch, "Random")
|
47
|
+
.add_handler<torch::Error>(handle_error)
|
44
48
|
.define_singleton_method(
|
45
49
|
"initial_seed",
|
46
50
|
*[]() {
|
@@ -55,6 +59,7 @@ void Init_ext()
|
|
55
59
|
|
56
60
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
57
61
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
62
|
+
.add_handler<torch::Error>(handle_error)
|
58
63
|
.define_constructor(Constructor<torch::IValue>())
|
59
64
|
.define_method("bool?", &torch::IValue::isBool)
|
60
65
|
.define_method("bool_list?", &torch::IValue::isBoolList)
|
@@ -317,7 +322,6 @@ void Init_ext()
|
|
317
322
|
});
|
318
323
|
|
319
324
|
rb_cTensor
|
320
|
-
.add_handler<c10::Error>(handle_error)
|
321
325
|
.define_method("cuda?", &torch::Tensor::is_cuda)
|
322
326
|
.define_method("sparse?", &torch::Tensor::is_sparse)
|
323
327
|
.define_method("quantized?", &torch::Tensor::is_quantized)
|
@@ -374,6 +378,21 @@ void Init_ext()
|
|
374
378
|
s << self.device();
|
375
379
|
return s.str();
|
376
380
|
})
|
381
|
+
.define_method(
|
382
|
+
"_data_str",
|
383
|
+
*[](Tensor& self) {
|
384
|
+
Tensor tensor = self;
|
385
|
+
|
386
|
+
// move to CPU to get data
|
387
|
+
if (tensor.device().type() != torch::kCPU) {
|
388
|
+
torch::Device device("cpu");
|
389
|
+
tensor = tensor.to(device);
|
390
|
+
}
|
391
|
+
|
392
|
+
auto data_ptr = (const char *) tensor.data_ptr();
|
393
|
+
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
394
|
+
})
|
395
|
+
// TODO figure out a better way to do this
|
377
396
|
.define_method(
|
378
397
|
"_flat_data",
|
379
398
|
*[](Tensor& self) {
|
@@ -388,46 +407,40 @@ void Init_ext()
|
|
388
407
|
Array a;
|
389
408
|
auto dtype = tensor.dtype();
|
390
409
|
|
410
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
411
|
+
|
391
412
|
// TODO DRY if someone knows C++
|
392
413
|
if (dtype == torch::kByte) {
|
393
|
-
uint8_t* data = tensor.data_ptr<uint8_t>();
|
394
414
|
for (int i = 0; i < tensor.numel(); i++) {
|
395
|
-
a.push(
|
415
|
+
a.push(view[i].item().to<uint8_t>());
|
396
416
|
}
|
397
417
|
} else if (dtype == torch::kChar) {
|
398
|
-
int8_t* data = tensor.data_ptr<int8_t>();
|
399
418
|
for (int i = 0; i < tensor.numel(); i++) {
|
400
|
-
a.push(to_ruby<int>(
|
419
|
+
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
401
420
|
}
|
402
421
|
} else if (dtype == torch::kShort) {
|
403
|
-
int16_t* data = tensor.data_ptr<int16_t>();
|
404
422
|
for (int i = 0; i < tensor.numel(); i++) {
|
405
|
-
a.push(
|
423
|
+
a.push(view[i].item().to<int16_t>());
|
406
424
|
}
|
407
425
|
} else if (dtype == torch::kInt) {
|
408
|
-
int32_t* data = tensor.data_ptr<int32_t>();
|
409
426
|
for (int i = 0; i < tensor.numel(); i++) {
|
410
|
-
a.push(
|
427
|
+
a.push(view[i].item().to<int32_t>());
|
411
428
|
}
|
412
429
|
} else if (dtype == torch::kLong) {
|
413
|
-
int64_t* data = tensor.data_ptr<int64_t>();
|
414
430
|
for (int i = 0; i < tensor.numel(); i++) {
|
415
|
-
a.push(
|
431
|
+
a.push(view[i].item().to<int64_t>());
|
416
432
|
}
|
417
433
|
} else if (dtype == torch::kFloat) {
|
418
|
-
float* data = tensor.data_ptr<float>();
|
419
434
|
for (int i = 0; i < tensor.numel(); i++) {
|
420
|
-
a.push(
|
435
|
+
a.push(view[i].item().to<float>());
|
421
436
|
}
|
422
437
|
} else if (dtype == torch::kDouble) {
|
423
|
-
double* data = tensor.data_ptr<double>();
|
424
438
|
for (int i = 0; i < tensor.numel(); i++) {
|
425
|
-
a.push(
|
439
|
+
a.push(view[i].item().to<double>());
|
426
440
|
}
|
427
441
|
} else if (dtype == torch::kBool) {
|
428
|
-
bool* data = tensor.data_ptr<bool>();
|
429
442
|
for (int i = 0; i < tensor.numel(); i++) {
|
430
|
-
a.push(
|
443
|
+
a.push(view[i].item().to<bool>() ? True : False);
|
431
444
|
}
|
432
445
|
} else {
|
433
446
|
throw std::runtime_error("Unsupported type");
|
@@ -449,7 +462,7 @@ void Init_ext()
|
|
449
462
|
});
|
450
463
|
|
451
464
|
Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
|
452
|
-
.add_handler<c10::Error>(handle_error)
|
465
|
+
.add_handler<torch::Error>(handle_error)
|
453
466
|
.define_constructor(Constructor<torch::TensorOptions>())
|
454
467
|
.define_method(
|
455
468
|
"dtype",
|
@@ -555,6 +568,7 @@ void Init_ext()
|
|
555
568
|
});
|
556
569
|
|
557
570
|
Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
|
571
|
+
.add_handler<torch::Error>(handle_error)
|
558
572
|
.define_method(
|
559
573
|
"grad",
|
560
574
|
*[](Parameter& self) {
|
@@ -564,6 +578,7 @@ void Init_ext()
|
|
564
578
|
|
565
579
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
566
580
|
.define_constructor(Constructor<torch::Device, std::string>())
|
581
|
+
.add_handler<torch::Error>(handle_error)
|
567
582
|
.define_method("index", &torch::Device::index)
|
568
583
|
.define_method("index?", &torch::Device::has_index)
|
569
584
|
.define_method(
|
@@ -575,6 +590,7 @@ void Init_ext()
|
|
575
590
|
});
|
576
591
|
|
577
592
|
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
|
593
|
+
.add_handler<torch::Error>(handle_error)
|
578
594
|
.define_singleton_method("available?", &torch::cuda::is_available)
|
579
595
|
.define_singleton_method("device_count", &torch::cuda::device_count);
|
580
596
|
}
|
data/lib/torch.rb
CHANGED
data/lib/torch/hub.rb
CHANGED
@@ -5,12 +5,56 @@ module Torch
|
|
5
5
|
raise NotImplementedYet
|
6
6
|
end
|
7
7
|
|
8
|
-
def download_url_to_file(url)
|
9
|
-
|
8
|
+
def download_url_to_file(url, dst)
|
9
|
+
uri = URI(url)
|
10
|
+
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
11
|
+
location = nil
|
12
|
+
|
13
|
+
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
14
|
+
request = Net::HTTP::Get.new(uri)
|
15
|
+
|
16
|
+
puts "Downloading #{url}..."
|
17
|
+
File.open(tmp, "wb") do |f|
|
18
|
+
http.request(request) do |response|
|
19
|
+
case response
|
20
|
+
when Net::HTTPRedirection
|
21
|
+
location = response["location"]
|
22
|
+
when Net::HTTPSuccess
|
23
|
+
response.read_body do |chunk|
|
24
|
+
f.write(chunk)
|
25
|
+
end
|
26
|
+
else
|
27
|
+
raise Error, "Bad response"
|
28
|
+
end
|
29
|
+
end
|
30
|
+
end
|
31
|
+
end
|
32
|
+
|
33
|
+
if location
|
34
|
+
download_url_to_file(location, dst)
|
35
|
+
else
|
36
|
+
FileUtils.mv(tmp, dst)
|
37
|
+
nil
|
38
|
+
end
|
10
39
|
end
|
11
40
|
|
12
|
-
def load_state_dict_from_url(url)
|
13
|
-
|
41
|
+
def load_state_dict_from_url(url, model_dir: nil)
|
42
|
+
unless model_dir
|
43
|
+
torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
|
44
|
+
model_dir = File.join(torch_home, "checkpoints")
|
45
|
+
end
|
46
|
+
|
47
|
+
FileUtils.mkdir_p(model_dir)
|
48
|
+
|
49
|
+
parts = URI(url)
|
50
|
+
filename = File.basename(parts.path)
|
51
|
+
cached_file = File.join(model_dir, filename)
|
52
|
+
unless File.exist?(cached_file)
|
53
|
+
# TODO support hash_prefix
|
54
|
+
download_url_to_file(url, cached_file)
|
55
|
+
end
|
56
|
+
|
57
|
+
Torch.load(cached_file)
|
14
58
|
end
|
15
59
|
end
|
16
60
|
end
|
data/lib/torch/inspector.rb
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
module Torch
|
2
2
|
module Inspector
|
3
|
-
# TODO make more
|
3
|
+
# TODO make more performant, especially when summarizing
|
4
4
|
# how? only read data that will be displayed
|
5
5
|
def inspect
|
6
6
|
data =
|
@@ -14,7 +14,7 @@ module Torch
|
|
14
14
|
if dtype == :bool
|
15
15
|
fmt = "%s"
|
16
16
|
else
|
17
|
-
values =
|
17
|
+
values = _flat_data
|
18
18
|
abs = values.select { |v| v != 0 }.map(&:abs)
|
19
19
|
max = abs.max || 1
|
20
20
|
min = abs.min || 1
|
data/lib/torch/tensor.rb
CHANGED
@@ -25,8 +25,17 @@ module Torch
|
|
25
25
|
inspect
|
26
26
|
end
|
27
27
|
|
28
|
+
# TODO make more performant
|
28
29
|
def to_a
|
29
|
-
|
30
|
+
arr = _flat_data
|
31
|
+
if shape.empty?
|
32
|
+
arr
|
33
|
+
else
|
34
|
+
shape[1..-1].reverse.each do |dim|
|
35
|
+
arr = arr.each_slice(dim)
|
36
|
+
end
|
37
|
+
arr.to_a
|
38
|
+
end
|
30
39
|
end
|
31
40
|
|
32
41
|
# TODO support dtype
|
@@ -64,7 +73,7 @@ module Torch
|
|
64
73
|
if numel != 1
|
65
74
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
66
75
|
end
|
67
|
-
|
76
|
+
to_a.first
|
68
77
|
end
|
69
78
|
|
70
79
|
def to_i
|
@@ -88,7 +97,7 @@ module Torch
|
|
88
97
|
def numo
|
89
98
|
cls = Torch._dtype_to_numo[dtype]
|
90
99
|
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
91
|
-
cls.
|
100
|
+
cls.from_string(_data_str).reshape(*shape)
|
92
101
|
end
|
93
102
|
|
94
103
|
def new_ones(*size, **options)
|
@@ -116,15 +125,6 @@ module Torch
|
|
116
125
|
_view(size)
|
117
126
|
end
|
118
127
|
|
119
|
-
# value and other are swapped for some methods
|
120
|
-
def add!(value = 1, other)
|
121
|
-
if other.is_a?(Numeric)
|
122
|
-
_add__scalar(other, value)
|
123
|
-
else
|
124
|
-
_add__tensor(other, value)
|
125
|
-
end
|
126
|
-
end
|
127
|
-
|
128
128
|
def +(other)
|
129
129
|
add(other)
|
130
130
|
end
|
@@ -201,6 +201,17 @@ module Torch
|
|
201
201
|
end
|
202
202
|
end
|
203
203
|
|
204
|
+
# native functions that need to be manually defined
|
205
|
+
|
206
|
+
# value and other are swapped for some methods
|
207
|
+
def add!(value = 1, other)
|
208
|
+
if other.is_a?(Numeric)
|
209
|
+
_add__scalar(other, value)
|
210
|
+
else
|
211
|
+
_add__tensor(other, value)
|
212
|
+
end
|
213
|
+
end
|
214
|
+
|
204
215
|
# native functions overlap, so need to handle manually
|
205
216
|
def random!(*args)
|
206
217
|
case args.size
|
@@ -218,17 +229,5 @@ module Torch
|
|
218
229
|
def copy_to(dst, src)
|
219
230
|
dst.copy!(src)
|
220
231
|
end
|
221
|
-
|
222
|
-
def reshape_arr(arr, dims)
|
223
|
-
if dims.empty?
|
224
|
-
arr
|
225
|
-
else
|
226
|
-
arr = arr.flatten
|
227
|
-
dims[1..-1].reverse.each do |dim|
|
228
|
-
arr = arr.each_slice(dim)
|
229
|
-
end
|
230
|
-
arr.to_a
|
231
|
-
end
|
232
|
-
end
|
233
232
|
end
|
234
233
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.4
|
4
|
+
version: 0.2.5
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-06-07 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|