torch-rb 0.2.1 → 0.2.6
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/README.md +20 -8
- data/ext/torch/ext.cpp +87 -22
- data/ext/torch/extconf.rb +21 -18
- data/lib/torch.rb +14 -5
- data/lib/torch/hub.rb +52 -0
- data/lib/torch/inspector.rb +3 -3
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -1
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +22 -6
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +58 -30
- data/lib/torch/utils/data/data_loader.rb +32 -4
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +6 -6
- data/lib/torch/random.rb +0 -10
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-metadata.gz:
-data.tar.gz:
+metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
+data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
 SHA512:
-metadata.gz:
-data.tar.gz:
+metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
+data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
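The values in checksums.yaml are hex digests of the two members of the `.gem` archive, `metadata.gz` and `data.tar.gz` (a `.gem` file is a plain tar archive). Below is a minimal Ruby sketch for recomputing them with the standard `digest` library. The `gem_checksums` helper is hypothetical and expects the raw bytes of each member:

```ruby
require "digest"

# Hypothetical helper: given { member_name => raw_bytes }, return the same
# SHA256/SHA512 layout that checksums.yaml uses.
def gem_checksums(members)
  {
    "SHA256" => members.transform_values { |bytes| Digest::SHA256.hexdigest(bytes) },
    "SHA512" => members.transform_values { |bytes| Digest::SHA512.hexdigest(bytes) }
  }
end
```

After unpacking the gem with `tar -xf`, you could call `gem_checksums("metadata.gz" => File.binread("metadata.gz"), "data.tar.gz" => File.binread("data.tar.gz"))` and compare the result against the `+` lines above.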
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,34 @@
+## 0.2.6 (2020-06-29)
+
+- Added support for indexing with tensors
+- Added `contiguous` methods
+- Fixed named parameters for nested parameters
+
+## 0.2.5 (2020-06-07)
+
+- Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
+- Improved error messages
+- Fixed tensor slicing
+
+## 0.2.4 (2020-04-29)
+
+- Added `to_i` and `to_f` to tensors
+- Added `shuffle` option to data loader
+- Fixed `modules` and `named_modules` for nested modules
+
+## 0.2.3 (2020-04-28)
+
+- Added `show_config` and `parallel_info` methods
+- Added `initial_seed` and `seed` methods to `Random`
+- Improved data loader
+- Build with MKL-DNN and NNPACK when available
+- Fixed `inspect` for modules
+
+## 0.2.2 (2020-04-27)
+
+- Added support for saving tensor lists
+- Added `ndim` and `ndimension` methods to tensors
+
 ## 0.2.1 (2020-04-26)
 
 - Added support for saving and loading models
data/README.md
CHANGED
@@ -2,6 +2,8 @@
 
 :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
 
+For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
+
 [![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
 
 ## Installation
@@ -22,6 +24,18 @@ It can take a few minutes to compile the extension.
 
 ## Getting Started
 
+Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
+
+We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
+
+```text
+ankane/ml-stack:torch-gpu
+```
+
+And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
+
+## API
+
 This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
 
 - Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
@@ -192,7 +206,7 @@ end
 Define a neural network
 
 ```ruby
-class
+class MyNet < Torch::NN::Module
 def initialize
 super
 @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
@@ -226,7 +240,7 @@ end
 Create an instance of it
 
 ```ruby
-net =
+net = MyNet.new
 input = Torch.randn(1, 1, 32, 32)
 net.call(input)
 ```
@@ -294,7 +308,7 @@ Torch.save(net.state_dict, "net.pth")
 Load a model
 
 ```ruby
-net =
+net = MyNet.new
 net.load_state_dict(Torch.load("net.pth"))
 net.eval
 ```
@@ -395,7 +409,7 @@ Here’s the list of compatible versions.
 
 Torch.rb | LibTorch
 --- | ---
-0.2.0 | 1.5.0
+0.2.0+ | 1.5.0+
 0.1.8 | 1.4.0
 0.1.0-0.1.7 | 1.3.1
 
@@ -413,9 +427,7 @@ Then install the gem (no need for `bundle config`).
 
 ### Linux
 
-Deep learning is significantly faster on
-
-Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
+Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
 
 Check if CUDA is available
 
@@ -426,7 +438,7 @@ Torch::CUDA.available?
 Move a neural network to a GPU
 
 ```ruby
-net.
+net.cuda
 ```
 
 ## rbenv
data/ext/torch/ext.cpp
CHANGED
@@ -23,7 +23,7 @@ class Parameter: public torch::autograd::Variable {
 Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
 };
 
-void handle_error(
+void handle_error(torch::Error const & ex)
 {
 throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
 }
@@ -32,16 +32,34 @@ extern "C"
 void Init_ext()
 {
 Module rb_mTorch = define_module("Torch");
+rb_mTorch.add_handler<torch::Error>(handle_error);
 add_torch_functions(rb_mTorch);
 
 Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+rb_cTensor.add_handler<torch::Error>(handle_error);
 add_tensor_functions(rb_cTensor);
 
 Module rb_mNN = define_module_under(rb_mTorch, "NN");
+rb_mNN.add_handler<torch::Error>(handle_error);
 add_nn_functions(rb_mNN);
 
+Module rb_mRandom = define_module_under(rb_mTorch, "Random")
+.add_handler<torch::Error>(handle_error)
+.define_singleton_method(
+"initial_seed",
+*[]() {
+return at::detail::getDefaultCPUGenerator()->current_seed();
+})
+.define_singleton_method(
+"seed",
+*[]() {
+// TODO set for CUDA when available
+return at::detail::getDefaultCPUGenerator()->seed();
+});
+
 // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
 Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
+.add_handler<torch::Error>(handle_error)
 .define_constructor(Constructor<torch::IValue>())
 .define_method("bool?", &torch::IValue::isBool)
 .define_method("bool_list?", &torch::IValue::isBoolList)
@@ -82,6 +100,16 @@ void Init_ext()
 *[](torch::IValue& self) {
 return self.toInt();
 })
+.define_method(
+"to_list",
+*[](torch::IValue& self) {
+auto list = self.toListRef();
+Array obj;
+for (auto& elem : list) {
+obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
+}
+return obj;
+})
 .define_method(
 "to_string_ref",
 *[](torch::IValue& self) {
@@ -96,17 +124,27 @@ void Init_ext()
 "to_generic_dict",
 *[](torch::IValue& self) {
 auto dict = self.toGenericDict();
-Hash
+Hash obj;
 for (auto& pair : dict) {
-
+obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
 }
-return
+return obj;
 })
 .define_singleton_method(
 "from_tensor",
 *[](torch::Tensor& v) {
 return torch::IValue(v);
 })
+// TODO create specialized list types?
+.define_singleton_method(
+"from_list",
+*[](Array obj) {
+c10::impl::GenericList list(c10::AnyType::get());
+for (auto entry : obj) {
+list.push_back(from_ruby<torch::IValue>(entry));
+}
+return torch::IValue(list);
+})
 .define_singleton_method(
 "from_string",
 *[](String v) {
@@ -157,6 +195,17 @@ void Init_ext()
 *[](uint64_t seed) {
 return torch::manual_seed(seed);
 })
+// config
+.define_singleton_method(
+"show_config",
+*[] {
+return torch::show_config();
+})
+.define_singleton_method(
+"parallel_info",
+*[] {
+return torch::get_parallel_info();
+})
 // begin tensor creation
 .define_singleton_method(
 "_arange",
@@ -273,7 +322,6 @@ void Init_ext()
 });
 
 rb_cTensor
-.add_handler<c10::Error>(handle_error)
 .define_method("cuda?", &torch::Tensor::is_cuda)
 .define_method("sparse?", &torch::Tensor::is_sparse)
 .define_method("quantized?", &torch::Tensor::is_quantized)
@@ -281,6 +329,11 @@ void Init_ext()
 .define_method("numel", &torch::Tensor::numel)
 .define_method("element_size", &torch::Tensor::element_size)
 .define_method("requires_grad", &torch::Tensor::requires_grad)
+.define_method(
+"contiguous?",
+*[](Tensor& self) {
+return self.is_contiguous();
+})
 .define_method(
 "addcmul!",
 *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
@@ -330,6 +383,21 @@ void Init_ext()
 s << self.device();
 return s.str();
 })
+.define_method(
+"_data_str",
+*[](Tensor& self) {
+Tensor tensor = self;
+
+// move to CPU to get data
+if (tensor.device().type() != torch::kCPU) {
+torch::Device device("cpu");
+tensor = tensor.to(device);
+}
+
+auto data_ptr = (const char *) tensor.data_ptr();
+return std::string(data_ptr, tensor.numel() * tensor.element_size());
+})
+// TODO figure out a better way to do this
 .define_method(
 "_flat_data",
 *[](Tensor& self) {
@@ -344,46 +412,40 @@ void Init_ext()
 Array a;
 auto dtype = tensor.dtype();
 
+Tensor view = tensor.reshape({tensor.numel()});
+
 // TODO DRY if someone knows C++
 if (dtype == torch::kByte) {
-uint8_t* data = tensor.data_ptr<uint8_t>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<uint8_t>());
 }
 } else if (dtype == torch::kChar) {
-int8_t* data = tensor.data_ptr<int8_t>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(to_ruby<int>(
+a.push(to_ruby<int>(view[i].item().to<int8_t>()));
 }
 } else if (dtype == torch::kShort) {
-int16_t* data = tensor.data_ptr<int16_t>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<int16_t>());
 }
 } else if (dtype == torch::kInt) {
-int32_t* data = tensor.data_ptr<int32_t>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<int32_t>());
 }
 } else if (dtype == torch::kLong) {
-int64_t* data = tensor.data_ptr<int64_t>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<int64_t>());
 }
 } else if (dtype == torch::kFloat) {
-float* data = tensor.data_ptr<float>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<float>());
 }
 } else if (dtype == torch::kDouble) {
-double* data = tensor.data_ptr<double>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<double>());
 }
 } else if (dtype == torch::kBool) {
-bool* data = tensor.data_ptr<bool>();
 for (int i = 0; i < tensor.numel(); i++) {
-a.push(
+a.push(view[i].item().to<bool>() ? True : False);
 }
 } else {
 throw std::runtime_error("Unsupported type");
@@ -405,7 +467,7 @@ void Init_ext()
 });
 
 Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
-.add_handler<
+.add_handler<torch::Error>(handle_error)
 .define_constructor(Constructor<torch::TensorOptions>())
 .define_method(
 "dtype",
@@ -511,6 +573,7 @@ void Init_ext()
 });
 
 Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
+.add_handler<torch::Error>(handle_error)
 .define_method(
 "grad",
 *[](Parameter& self) {
@@ -520,6 +583,7 @@ void Init_ext()
 
 Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
 .define_constructor(Constructor<torch::Device, std::string>())
+.add_handler<torch::Error>(handle_error)
 .define_method("index", &torch::Device::index)
 .define_method("index?", &torch::Device::has_index)
 .define_method(
@@ -531,6 +595,7 @@ void Init_ext()
 });
 
 Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+.add_handler<torch::Error>(handle_error)
 .define_singleton_method("available?", &torch::cuda::is_available)
 .define_singleton_method("device_count", &torch::cuda::device_count);
 }
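`_data_str` copies the tensor to the CPU when needed and returns its storage as a binary string of `numel * element_size` bytes. `Tensor#numo` then passes that string to Numo's `from_string` and reshapes the result. The Ruby sketch below shows how such a string maps back to values for a float32 tensor.

The helper names are hypothetical. The sketch assumes a little-endian host and a contiguous tensor, because the bytes are read straight from `data_ptr` and therefore follow logical element order only in the contiguous case.

```ruby
# Hypothetical helpers illustrating the byte layout of a float32 tensor's data
# string: 4 bytes per element, little-endian ("e" in pack/unpack).
FLOAT32_SIZE = 4

def encode_float32(values)
  values.pack("e*")
end

# Decode the byte string back into Ruby Floats, checking the element count.
def decode_float32(bytes, numel)
  unless (bytes.bytesize % FLOAT32_SIZE).zero?
    raise ArgumentError, "byte length #{bytes.bytesize} is not a multiple of #{FLOAT32_SIZE}"
  end
  values = bytes.unpack("e*")
  raise ArgumentError, "expected #{numel} elements, got #{values.size}" unless values.size == numel
  values
end
```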
data/ext/torch/extconf.rb
CHANGED
@@ -2,33 +2,33 @@ require "mkmf-rice"
 
 abort "Missing stdc++" unless have_library("stdc++")
 
-$CXXFLAGS
+$CXXFLAGS += " -std=c++14"
 
 # change to 0 for Linux pre-cxx11 ABI version
-$CXXFLAGS
+$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
 
 # TODO check compiler name
 clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
 
 # check omp first
 if have_library("omp") || have_library("gomp")
-$CXXFLAGS
-$CXXFLAGS
-$CXXFLAGS
+$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
+$CXXFLAGS += " -Xclang" if clang
+$CXXFLAGS += " -fopenmp"
 end
 
 if clang
 # silence ruby/intern.h warning
-$CXXFLAGS
+$CXXFLAGS += " -Wno-deprecated-register"
 
 # silence torch warnings
-$CXXFLAGS
+$CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
 else
 # silence rice warnings
-$CXXFLAGS
+$CXXFLAGS += " -Wno-noexcept-type"
 
 # silence torch warnings
-$CXXFLAGS
+$CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
 end
 
 inc, lib = dir_config("torch")
@@ -39,27 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
 cuda_inc ||= "/usr/local/cuda/include"
 cuda_lib ||= "/usr/local/cuda/lib64"
 
-$LDFLAGS
+$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
 abort "LibTorch not found" unless have_library("torch")
 
+have_library("mkldnn")
+have_library("nnpack")
+
 with_cuda = false
 if Dir["#{lib}/*torch_cuda*"].any?
-$LDFLAGS
+$LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
 with_cuda = have_library("cuda") && have_library("cudnn")
 end
 
-$INCFLAGS
-$INCFLAGS
+$INCFLAGS += " -I#{inc}"
+$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
 
-$LDFLAGS
-$LDFLAGS
+$LDFLAGS += " -Wl,-rpath,#{lib}"
+$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
 
 # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
-$LDFLAGS
+$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
 if with_cuda
-$LDFLAGS
+$LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
 # TODO figure out why this is needed
-$LDFLAGS
+$LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
 end
 
 # generate C++ functions
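The two rpath lines build a single `-Wl,-rpath` entry. Its value is a colon-separated list of directories the dynamic loader searches at run time: the LibTorch `lib` directory always, plus CUDA's `stubs` and `lib` directories when building with CUDA. Below is a small Ruby sketch of that string assembly. `rpath_flag` is a hypothetical name, and `/opt/libtorch/lib` is only an example path:

```ruby
# Hypothetical helper mirroring the rpath lines in extconf.rb.
# Each += builds a new string instead of mutating the original.
def rpath_flag(lib, cuda_lib, with_cuda)
  flag = " -Wl,-rpath,#{lib}"
  flag += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
  flag
end
```

For example, `rpath_flag("/opt/libtorch/lib", "/usr/local/cuda/lib64", true)` appends both CUDA directories after the LibTorch one.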
data/lib/torch.rb
CHANGED
@@ -1,6 +1,11 @@
 # ext
 require "torch/ext"
 
+# stdlib
+require "fileutils"
+require "net/http"
+require "tmpdir"
+
 # native functions
 require "torch/native/generator"
 require "torch/native/parser"
@@ -174,11 +179,9 @@ require "torch/nn/init"
 
 # utils
 require "torch/utils/data/data_loader"
+require "torch/utils/data/dataset"
 require "torch/utils/data/tensor_dataset"
 
-# random
-require "torch/random"
-
 # hub
 require "torch/hub"
 
@@ -466,6 +469,12 @@ module Torch
 IValue.from_bool(obj)
 when nil
 IValue.new
+when Array
+if obj.all? { |v| v.is_a?(Tensor) }
+IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
+else
+raise Error, "Unknown list type"
+end
 else
 raise Error, "Unknown type: #{obj.class.name}"
 end
@@ -490,6 +499,8 @@ module Torch
 dict[to_ruby(k)] = to_ruby(v)
 end
 dict
+elsif ivalue.list?
+ivalue.to_list.map { |v| to_ruby(v) }
 else
 type =
 if ivalue.capsule?
@@ -510,8 +521,6 @@ module Torch
 "BoolList"
 elsif ivalue.tensor_list?
 "TensorList"
-elsif ivalue.list?
-"List"
 elsif ivalue.object?
 "Object"
 elsif ivalue.module?
data/lib/torch/hub.rb
CHANGED
@@ -4,6 +4,58 @@ module Torch
 def list(github, force_reload: false)
 raise NotImplementedYet
 end
+
+def download_url_to_file(url, dst)
+uri = URI(url)
+tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+location = nil
+
+Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
+request = Net::HTTP::Get.new(uri)
+
+puts "Downloading #{url}..."
+File.open(tmp, "wb") do |f|
+http.request(request) do |response|
+case response
+when Net::HTTPRedirection
+location = response["location"]
+when Net::HTTPSuccess
+response.read_body do |chunk|
+f.write(chunk)
+end
+else
+raise Error, "Bad response"
+end
+end
+end
+end
+
+if location
+download_url_to_file(location, dst)
+else
+FileUtils.mv(tmp, dst)
+nil
+end
+end
+
+def load_state_dict_from_url(url, model_dir: nil)
+unless model_dir
+torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
+model_dir = File.join(torch_home, "checkpoints")
+end
+
+FileUtils.mkdir_p(model_dir)
+
+parts = URI(url)
+filename = File.basename(parts.path)
+cached_file = File.join(model_dir, filename)
+unless File.exist?(cached_file)
+# TODO support hash_prefix
+download_url_to_file(url, cached_file)
+end
+
+Torch.load(cached_file)
+end
 end
 end
 end
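`load_state_dict_from_url` resolves a cache directory and then downloads only when the cached file is missing. Below is a plain-Ruby sketch of the two path decisions, with hypothetical helper names and a Hash standing in for `ENV`. Because the cache key is only the basename of the URL path, two URLs that end in the same file name would share one cached file.

```ruby
require "uri"

# Default model_dir: $TORCH_HOME, else $XDG_CACHE_HOME/torch, else ~/.cache/torch,
# always followed by "checkpoints".
def default_model_dir(env = ENV)
  torch_home = env["TORCH_HOME"] || "#{env["XDG_CACHE_HOME"] || "#{env["HOME"]}/.cache"}/torch"
  File.join(torch_home, "checkpoints")
end

# The cached file is named after the last segment of the URL path
# (query strings are ignored because URI#path excludes them).
def cached_file_for(url, model_dir)
  File.join(model_dir, File.basename(URI(url).path))
end
```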
data/lib/torch/inspector.rb
CHANGED
@@ -1,6 +1,6 @@
 module Torch
 module Inspector
-# TODO make more
+# TODO make more performant, especially when summarizing
 # how? only read data that will be displayed
 def inspect
 data =
@@ -14,7 +14,7 @@ module Torch
 if dtype == :bool
 fmt = "%s"
 else
-values =
+values = _flat_data
 abs = values.select { |v| v != 0 }.map(&:abs)
 max = abs.max || 1
 min = abs.min || 1
@@ -25,7 +25,7 @@ module Torch
 end
 
 if floating_point?
-sci = max
+sci = max > 1e8 || max < 1e-4
 
 all_int = values.all? { |v| v.finite? && v == v.to_i }
 decimal = all_int ? 1 : 4
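For floating-point tensors, the Inspector now picks scientific notation from the largest nonzero magnitude alone. Below is a standalone Ruby sketch of that rule. The method name is hypothetical, and the body repeats the `abs`, `max`, and `sci` lines from the hunk:

```ruby
# Scientific notation when the largest nonzero magnitude is above 1e8 or
# below 1e-4; all-zero data falls back to max = 1 (plain notation).
def scientific_notation?(values)
  abs = values.select { |v| v != 0 }.map(&:abs)
  max = abs.max || 1
  max > 1e8 || max < 1e-4
end
```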
data/lib/torch/native/generator.rb
CHANGED
@@ -18,7 +18,7 @@ module Torch
 functions = functions()
 
 # skip functions
-skip_args = ["bool[3]", "Dimname", "
+skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
 
 # remove functions
 functions.reject! do |f|
@@ -31,7 +31,7 @@ module Torch
 todo_functions, functions =
 functions.partition do |f|
 f.args.any? do |a|
-a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
+a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
 skip_args.any? { |sa| a[:type].include?(sa) } ||
 # native_functions.yaml is missing size argument for normal
 # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
@@ -112,6 +112,9 @@ void add_%{type}_functions(Module m) {
 "OptionalScalarType"
 when "Tensor[]"
 "TensorList"
+when "Tensor?[]"
+# TODO make optional
+"TensorList"
 when "int"
 "int64_t"
 when "float"
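The generator splits native functions into generated and "todo" sets based on their argument types. Adding `"Tensor?[]"` to the accepted optional types lets through functions that take a list of optional tensors. This is presumably what makes the `index` and `index_put!` calls used by the new tensor indexing code available.

Below is a hypothetical extraction of just the two conditions shown in the hunk. The real check continues with further conditions, such as the `normal` special case, which are omitted here.

```ruby
SKIP_ARGS = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"].freeze
SUPPORTED_OPTIONALS = ["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].freeze

# True when an argument type would send its function to todo_functions:
# an optional type outside the supported list, or a skipped type name.
def todo_arg?(type)
  (type.include?("?") && !SUPPORTED_OPTIONALS.include?(type)) ||
    SKIP_ARGS.any? { |sa| type.include?(sa) }
end
```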
data/lib/torch/native/parser.rb
CHANGED
data/lib/torch/nn/batch_norm.rb
CHANGED
@@ -70,6 +70,11 @@ module Torch
 momentum: exponential_average_factor, eps: @eps
 )
 end
+
+def extra_inspect
+s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
+format(s, **dict)
+end
 end
 end
 end
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -20,7 +20,14 @@ module Torch
 
 # TODO add more parameters
 def extra_inspect
-
+s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
+s += ", padding: %{padding}" if @padding != [0] * @padding.size
+s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
+s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
+s += ", groups: %{groups}" if @groups != 1
+s += ", bias: false" unless @bias
+s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
+format(s, **dict)
 end
 end
 end
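`Conv2d#extra_inspect` builds a format string that mentions a setting only when it differs from its default, then fills the `%{name}` references with `format`. Below is a standalone sketch of that pattern. A plain Hash stands in for the module's `dict`, `conv2d_summary` is a hypothetical name, and `output_padding` and `padding_mode` are left out for brevity.

```ruby
def conv2d_summary(dict)
  s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
  # Append optional parts only when they differ from the defaults.
  s += ", padding: %{padding}" if dict[:padding] != [0] * dict[:padding].size
  s += ", dilation: %{dilation}" if dict[:dilation] != [1] * dict[:dilation].size
  s += ", groups: %{groups}" if dict[:groups] != 1
  s += ", bias: false" unless dict[:bias]
  # format looks up each %{name} in the hash; arrays render via to_s.
  format(s, **dict)
end
```

With `kernel_size: [3, 3]`, `stride: [1, 1]`, zero padding, unit dilation, one group and a bias, the summary reduces to `"1, 6, kernel_size: [3, 3], stride: [1, 1]"` for a 1-to-6-channel layer.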
data/lib/torch/nn/convnd.rb
CHANGED
data/lib/torch/nn/max_poolnd.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -145,7 +145,7 @@ module Torch
 params = {}
 if recurse
 named_children.each do |name, mod|
-params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
 end
 end
 instance_variables.each do |name|
@@ -186,8 +186,22 @@ module Torch
 named_modules.values
 end
 
-
-
+# TODO return enumerator?
+def named_modules(memo: nil, prefix: "")
+ret = {}
+memo ||= Set.new
+unless memo.include?(self)
+memo << self
+ret[prefix] = self
+named_children.each do |name, mod|
+next unless mod.is_a?(Module)
+submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
+mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
+ret[m[0]] = m[1]
+end
+end
+end
+ret
 end
 
 def train(mode = true)
@@ -224,13 +238,15 @@ module Torch
 
 def inspect
 name = self.class.name.split("::").last
-if
+if named_children.empty?
 "#{name}(#{extra_inspect})"
 else
 str = String.new
 str << "#{name}(\n"
-
-
+named_children.each do |name, mod|
+mod_str = mod.inspect
+mod_str = mod_str.lines.join(" ")
+str << " (#{name}): #{mod_str}\n"
 end
 str << ")"
 end
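Both module.rb fixes concern dotted prefixes. `named_parameters` now passes `"#{prefix}#{name}."` down, so a parameter two levels deep keeps its full path (for example `block.conv.weight` rather than `conv.weight`). The new `named_modules` builds the same kind of keys.

Below is a sketch of that traversal on a hypothetical `TinyModule` class that only tracks named children:

```ruby
require "set"

# Hypothetical stand-in for Torch::NN::Module, holding only named children.
class TinyModule
  def initialize(children = {})
    @children = children
  end

  def named_children
    @children
  end

  # Depth-first walk: root under "", children under dot-joined paths;
  # the memo Set keeps a module shared by two parents from appearing twice.
  def named_modules(memo: nil, prefix: "")
    ret = {}
    memo ||= Set.new
    unless memo.include?(self)
      memo << self
      ret[prefix] = self
      named_children.each do |name, mod|
        next unless mod.is_a?(TinyModule)
        submodule_prefix = prefix.empty? ? name : "#{prefix}.#{name}"
        ret.merge!(mod.named_modules(memo: memo, prefix: submodule_prefix))
      end
    end
    ret
  end
end
```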
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -4,6 +4,8 @@ module Torch
 include Inspector
 
 alias_method :requires_grad?, :requires_grad
+alias_method :ndim, :dim
+alias_method :ndimension, :dim
 
 def self.new(*args)
 FloatTensor.new(*args)
@@ -23,8 +25,17 @@ module Torch
 inspect
 end
 
+# TODO make more performant
 def to_a
-
+arr = _flat_data
+if shape.empty?
+arr
+else
+shape[1..-1].reverse.each do |dim|
+arr = arr.each_slice(dim)
+end
+arr.to_a
+end
 end
 
 # TODO support dtype
@@ -62,7 +73,15 @@ module Torch
 if numel != 1
 raise Error, "only one element tensors can be converted to Ruby scalars"
 end
-
+to_a.first
+end
+
+def to_i
+item.to_i
+end
+
+def to_f
+item.to_f
 end
 
 # unsure if this is correct
@@ -78,7 +97,7 @@ module Torch
 def numo
 cls = Torch._dtype_to_numo[dtype]
 raise Error, "Cannot convert #{dtype} to Numo" unless cls
-cls.
+cls.from_string(_data_str).reshape(*shape)
 end
 
 def new_ones(*size, **options)
@@ -106,15 +125,6 @@ module Torch
 _view(size)
 end
 
-# value and other are swapped for some methods
-def add!(value = 1, other)
-if other.is_a?(Numeric)
-_add__scalar(other, value)
-else
-_add__tensor(other, value)
-end
-end
-
 def +(other)
 add(other)
 end
@@ -143,11 +153,13 @@ module Torch
 neg
 end
 
+# TODO better compare?
 def <=>(other)
 item <=> other
 end
 
-# based on python_variable_indexing.cpp
+# based on python_variable_indexing.cpp and
+# https://pytorch.org/cppdocs/notes/tensor_indexing.html
 def [](*indexes)
 result = self
 dim = 0
@@ -159,6 +171,8 @@ module Torch
 finish += 1 unless index.exclude_end?
 result = result._slice_tensor(dim, index.begin, finish, 1)
 dim += 1
+elsif index.is_a?(Tensor)
+result = result.index([index])
 elsif index.nil?
 result = result.unsqueeze(dim)
 dim += 1
@@ -172,12 +186,12 @@ module Torch
 result
 end
 
-#
-#
+# based on python_variable_indexing.cpp and
+# https://pytorch.org/cppdocs/notes/tensor_indexing.html
 def []=(index, value)
 raise ArgumentError, "Tensor does not support deleting items" if value.nil?
 
-value = Torch.tensor(value) unless value.is_a?(Tensor)
+value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
 
 if index.is_a?(Numeric)
 copy_to(_select_int(0, index), value)
@@ -185,13 +199,39 @@ module Torch
 finish = index.end
 finish += 1 unless index.exclude_end?
 copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+elsif index.is_a?(Tensor)
+index_put!([index], value)
 else
 raise Error, "Unsupported index type: #{index.class.name}"
 end
 end
 
-
-
+# native functions that need manually defined
+
+# value and other are swapped for some methods
+def add!(value = 1, other)
+if other.is_a?(Numeric)
+_add__scalar(other, value)
+else
+_add__tensor(other, value)
+end
+end
+
+# native functions overlap, so need to handle manually
+def random!(*args)
+case args.size
+when 1
+_random__to(*args)
+when 2
+_random__from_to(*args)
+else
+_random_(*args)
+end
+end
+
+def clamp!(min, max)
+_clamp_min_(min)
+_clamp_max_(max)
 end
 
 private
@@ -199,17 +239,5 @@ module Torch
 def copy_to(dst, src)
 dst.copy!(src)
 end
-
-def reshape_arr(arr, dims)
-if dims.empty?
-arr
-else
-arr = arr.flatten
-dims[1..-1].reverse.each do |dim|
-arr = arr.each_slice(dim)
-end
-arr.to_a
-end
-end
 end
 end
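`to_a` now rebuilds nested arrays directly from the flat, row-major data that `_flat_data` returns. The removed private `reshape_arr` helper did the same after flattening. Below is the slicing step in isolation, as a plain-Ruby sketch with a hypothetical `unflatten` name:

```ruby
# Slice on each trailing dimension, innermost first. A 0-dim shape returns the
# flat array unchanged, which is what lets `item` use `to_a.first`.
def unflatten(flat, shape)
  return flat if shape.empty?

  arr = flat
  shape[1..-1].reverse.each do |dim|
    arr = arr.each_slice(dim)
  end
  arr.to_a
end
```

For example, `unflatten([1, 2, 3, 4, 5, 6], [2, 3])` groups the data into two rows of three.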
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -6,21 +6,49 @@ module Torch
 
 attr_reader :dataset
 
-def initialize(dataset, batch_size: 1)
+def initialize(dataset, batch_size: 1, shuffle: false)
 @dataset = dataset
 @batch_size = batch_size
+@shuffle = shuffle
 end
 
 def each
-
-
-
+# try to keep the random number generator in sync with Python
+# this makes it easy to compare results
+base_seed = Torch.empty([], dtype: :int64).random!.item
+
+indexes =
+if @shuffle
+Torch.randperm(@dataset.size).to_a
+else
+@dataset.size.times
+end
+
+indexes.each_slice(@batch_size) do |idx|
+batch = idx.map { |i| @dataset[i] }
+yield collate(batch)
 end
 end
 
 def size
 (@dataset.size / @batch_size.to_f).ceil
 end
+
+private
+
+def collate(batch)
+elem = batch[0]
+case elem
+when Tensor
+Torch.stack(batch, 0)
+when Integer
+Torch.tensor(batch)
+when Array
+batch.transpose.map { |v| collate(v) }
+else
+raise NotImpelmentYet
+end
+end
 end
 end
 end
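`DataLoader#each` now builds an index order (a `Torch.randperm` permutation when `shuffle: true`), slices it into batches, and runs each batch through `collate`. Below is a plain-Ruby sketch of that flow. Arrays stand in for tensors, and the Tensor branch is omitted. `Array#shuffle` replaces `Torch.randperm`, so orderings will not match torch-rb.

In the diff, the fallback branch raises `NotImpelmentYet`, which looks like a misspelling and would most likely surface as a `NameError`. The sketch uses Ruby's built-in `NotImplementedError` instead. Both method names are hypothetical.

```ruby
# Integer batches are returned as-is (torch-rb calls Torch.tensor(batch) here);
# array elements are transposed so each field is collated separately.
def collate(batch)
  elem = batch[0]
  case elem
  when Integer
    batch
  when Array
    # [[x1, y1], [x2, y2]] -> [[x1, x2], [y1, y2]]
    batch.transpose.map { |field| collate(field) }
  else
    raise NotImplementedError, "cannot collate #{elem.class}"
  end
end

def each_batch(dataset, batch_size:, shuffle: false, random: Random.new)
  indexes = shuffle ? (0...dataset.size).to_a.shuffle(random: random) : dataset.size.times
  indexes.each_slice(batch_size) do |idx|
    yield collate(idx.map { |i| dataset[i] })
  end
end
```

With `[[1, 10], [2, 20], [3, 30]]` and `batch_size: 2`, the last batch is smaller, matching how `size` rounds up with `ceil`.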
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-version: 0.2.
+version: 0.2.6
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-
+date: 2020-06-29 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
 name: rice
@@ -95,19 +95,19 @@ dependencies:
 - !ruby/object:Gem::Version
 version: '0'
 - !ruby/object:Gem::Dependency
-name:
+name: torchvision
 requirement: !ruby/object:Gem::Requirement
 requirements:
 - - ">="
 - !ruby/object:Gem::Version
-version:
+version: 0.1.1
 type: :development
 prerelease: false
 version_requirements: !ruby/object:Gem::Requirement
 requirements:
 - - ">="
 - !ruby/object:Gem::Version
-version:
+version: 0.1.1
 description:
 email: andrew@chartkick.com
 executables: []
@@ -258,9 +258,9 @@ files:
 - lib/torch/optim/rmsprop.rb
 - lib/torch/optim/rprop.rb
 - lib/torch/optim/sgd.rb
-- lib/torch/random.rb
 - lib/torch/tensor.rb
 - lib/torch/utils/data/data_loader.rb
+- lib/torch/utils/data/dataset.rb
 - lib/torch/utils/data/tensor_dataset.rb
 - lib/torch/version.rb
 homepage: https://github.com/ankane/torch.rb