torch-rb 0.2.1 → 0.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/README.md +20 -8
- data/ext/torch/ext.cpp +87 -22
- data/ext/torch/extconf.rb +21 -18
- data/lib/torch.rb +14 -5
- data/lib/torch/hub.rb +52 -0
- data/lib/torch/inspector.rb +3 -3
- data/lib/torch/native/function.rb +1 -0
- data/lib/torch/native/generator.rb +5 -2
- data/lib/torch/native/parser.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -1
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +22 -6
- data/lib/torch/optim/rprop.rb +0 -3
- data/lib/torch/tensor.rb +58 -30
- data/lib/torch/utils/data/data_loader.rb +32 -4
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +6 -6
- data/lib/torch/random.rb +0 -10
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
+  data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
+  data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,34 @@
+## 0.2.6 (2020-06-29)
+
+- Added support for indexing with tensors
+- Added `contiguous` methods
+- Fixed named parameters for nested parameters
+
+## 0.2.5 (2020-06-07)
+
+- Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
+- Improved error messages
+- Fixed tensor slicing
+
+## 0.2.4 (2020-04-29)
+
+- Added `to_i` and `to_f` to tensors
+- Added `shuffle` option to data loader
+- Fixed `modules` and `named_modules` for nested modules
+
+## 0.2.3 (2020-04-28)
+
+- Added `show_config` and `parallel_info` methods
+- Added `initial_seed` and `seed` methods to `Random`
+- Improved data loader
+- Build with MKL-DNN and NNPACK when available
+- Fixed `inspect` for modules
+
+## 0.2.2 (2020-04-27)
+
+- Added support for saving tensor lists
+- Added `ndim` and `ndimension` methods to tensors
+
 ## 0.2.1 (2020-04-26)
 
 - Added support for saving and loading models
data/README.md
CHANGED
@@ -2,6 +2,8 @@
 
 :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
 
+For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
+
 [](https://travis-ci.org/ankane/torch.rb)
 
 ## Installation
@@ -22,6 +24,18 @@ It can take a few minutes to compile the extension.
 
 ## Getting Started
 
+Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
+
+We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
+
+```text
+ankane/ml-stack:torch-gpu
+```
+
+And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
+
+## API
+
 This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
 
 - Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
@@ -192,7 +206,7 @@ end
 Define a neural network
 
 ```ruby
-class
+class MyNet < Torch::NN::Module
   def initialize
     super
     @conv1 = Torch::NN::Conv2d.new(1, 6, 3)
@@ -226,7 +240,7 @@ end
 Create an instance of it
 
 ```ruby
-net =
+net = MyNet.new
 input = Torch.randn(1, 1, 32, 32)
 net.call(input)
 ```
@@ -294,7 +308,7 @@ Torch.save(net.state_dict, "net.pth")
 Load a model
 
 ```ruby
-net =
+net = MyNet.new
 net.load_state_dict(Torch.load("net.pth"))
 net.eval
 ```
@@ -395,7 +409,7 @@ Here’s the list of compatible versions.
 
 Torch.rb | LibTorch
 --- | ---
-0.2.0 | 1.5.0
+0.2.0+ | 1.5.0+
 0.1.8 | 1.4.0
 0.1.0-0.1.7 | 1.3.1
 
@@ -413,9 +427,7 @@ Then install the gem (no need for `bundle config`).
 
 ### Linux
 
-Deep learning is significantly faster on
-
-Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
+Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
 
 Check if CUDA is available
 
@@ -426,7 +438,7 @@ Torch::CUDA.available?
 Move a neural network to a GPU
 
 ```ruby
-net.
+net.cuda
 ```
 
 ## rbenv
data/ext/torch/ext.cpp
CHANGED
@@ -23,7 +23,7 @@ class Parameter: public torch::autograd::Variable {
     Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
 };
 
-void handle_error(
+void handle_error(torch::Error const & ex)
 {
   throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
 }
@@ -32,16 +32,34 @@ extern "C"
 void Init_ext()
 {
   Module rb_mTorch = define_module("Torch");
+  rb_mTorch.add_handler<torch::Error>(handle_error);
   add_torch_functions(rb_mTorch);
 
   Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+  rb_cTensor.add_handler<torch::Error>(handle_error);
   add_tensor_functions(rb_cTensor);
 
   Module rb_mNN = define_module_under(rb_mTorch, "NN");
+  rb_mNN.add_handler<torch::Error>(handle_error);
   add_nn_functions(rb_mNN);
 
+  Module rb_mRandom = define_module_under(rb_mTorch, "Random")
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_method(
+      "initial_seed",
+      *[]() {
+        return at::detail::getDefaultCPUGenerator()->current_seed();
+      })
+    .define_singleton_method(
+      "seed",
+      *[]() {
+        // TODO set for CUDA when available
+        return at::detail::getDefaultCPUGenerator()->seed();
+      });
+
   // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
   Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
+    .add_handler<torch::Error>(handle_error)
     .define_constructor(Constructor<torch::IValue>())
     .define_method("bool?", &torch::IValue::isBool)
     .define_method("bool_list?", &torch::IValue::isBoolList)
@@ -82,6 +100,16 @@ void Init_ext()
       *[](torch::IValue& self) {
         return self.toInt();
       })
+    .define_method(
+      "to_list",
+      *[](torch::IValue& self) {
+        auto list = self.toListRef();
+        Array obj;
+        for (auto& elem : list) {
+          obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
+        }
+        return obj;
+      })
     .define_method(
       "to_string_ref",
       *[](torch::IValue& self) {
@@ -96,17 +124,27 @@ void Init_ext()
       "to_generic_dict",
       *[](torch::IValue& self) {
         auto dict = self.toGenericDict();
-        Hash
+        Hash obj;
         for (auto& pair : dict) {
-
+          obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
         }
-        return
+        return obj;
       })
     .define_singleton_method(
       "from_tensor",
       *[](torch::Tensor& v) {
         return torch::IValue(v);
       })
+    // TODO create specialized list types?
+    .define_singleton_method(
+      "from_list",
+      *[](Array obj) {
+        c10::impl::GenericList list(c10::AnyType::get());
+        for (auto entry : obj) {
+          list.push_back(from_ruby<torch::IValue>(entry));
+        }
+        return torch::IValue(list);
+      })
     .define_singleton_method(
       "from_string",
       *[](String v) {
@@ -157,6 +195,17 @@ void Init_ext()
       *[](uint64_t seed) {
         return torch::manual_seed(seed);
       })
+    // config
+    .define_singleton_method(
+      "show_config",
+      *[] {
+        return torch::show_config();
+      })
+    .define_singleton_method(
+      "parallel_info",
+      *[] {
+        return torch::get_parallel_info();
+      })
     // begin tensor creation
     .define_singleton_method(
       "_arange",
@@ -273,7 +322,6 @@ void Init_ext()
     });
 
   rb_cTensor
-    .add_handler<c10::Error>(handle_error)
     .define_method("cuda?", &torch::Tensor::is_cuda)
     .define_method("sparse?", &torch::Tensor::is_sparse)
     .define_method("quantized?", &torch::Tensor::is_quantized)
@@ -281,6 +329,11 @@ void Init_ext()
     .define_method("numel", &torch::Tensor::numel)
     .define_method("element_size", &torch::Tensor::element_size)
     .define_method("requires_grad", &torch::Tensor::requires_grad)
+    .define_method(
+      "contiguous?",
+      *[](Tensor& self) {
+        return self.is_contiguous();
+      })
     .define_method(
       "addcmul!",
       *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
@@ -330,6 +383,21 @@ void Init_ext()
         s << self.device();
         return s.str();
       })
+    .define_method(
+      "_data_str",
+      *[](Tensor& self) {
+        Tensor tensor = self;
+
+        // move to CPU to get data
+        if (tensor.device().type() != torch::kCPU) {
+          torch::Device device("cpu");
+          tensor = tensor.to(device);
+        }
+
+        auto data_ptr = (const char *) tensor.data_ptr();
+        return std::string(data_ptr, tensor.numel() * tensor.element_size());
+      })
+    // TODO figure out a better way to do this
     .define_method(
       "_flat_data",
       *[](Tensor& self) {
@@ -344,46 +412,40 @@ void Init_ext()
         Array a;
         auto dtype = tensor.dtype();
 
+        Tensor view = tensor.reshape({tensor.numel()});
+
         // TODO DRY if someone knows C++
         if (dtype == torch::kByte) {
-          uint8_t* data = tensor.data_ptr<uint8_t>();
           for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<uint8_t>());
           }
         } else if (dtype == torch::kChar) {
-          int8_t* data = tensor.data_ptr<int8_t>();
           for (int i = 0; i < tensor.numel(); i++) {
-            a.push(to_ruby<int>(
+            a.push(to_ruby<int>(view[i].item().to<int8_t>()));
           }
         } else if (dtype == torch::kShort) {
-          int16_t* data = tensor.data_ptr<int16_t>();
           for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<int16_t>());
           }
         } else if (dtype == torch::kInt) {
-          int32_t* data = tensor.data_ptr<int32_t>();
           for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<int32_t>());
           }
         } else if (dtype == torch::kLong) {
-          int64_t* data = tensor.data_ptr<int64_t>();
           for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<int64_t>());
           }
         } else if (dtype == torch::kFloat) {
-          float* data = tensor.data_ptr<float>();
           for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<float>());
           }
         } else if (dtype == torch::kDouble) {
-          double* data = tensor.data_ptr<double>();
           for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<double>());
           }
         } else if (dtype == torch::kBool) {
-          bool* data = tensor.data_ptr<bool>();
           for (int i = 0; i < tensor.numel(); i++) {
-            a.push(
+            a.push(view[i].item().to<bool>() ? True : False);
           }
         } else {
           throw std::runtime_error("Unsupported type");
@@ -405,7 +467,7 @@ void Init_ext()
       });
 
   Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
-    .add_handler<
+    .add_handler<torch::Error>(handle_error)
     .define_constructor(Constructor<torch::TensorOptions>())
     .define_method(
       "dtype",
@@ -511,6 +573,7 @@ void Init_ext()
       });
 
   Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
+    .add_handler<torch::Error>(handle_error)
     .define_method(
       "grad",
       *[](Parameter& self) {
@@ -520,6 +583,7 @@ void Init_ext()
 
   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
     .define_constructor(Constructor<torch::Device, std::string>())
+    .add_handler<torch::Error>(handle_error)
     .define_method("index", &torch::Device::index)
     .define_method("index?", &torch::Device::has_index)
     .define_method(
@@ -531,6 +595,7 @@ void Init_ext()
       });
 
   Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+    .add_handler<torch::Error>(handle_error)
     .define_singleton_method("available?", &torch::cuda::is_available)
     .define_singleton_method("device_count", &torch::cuda::device_count);
 }
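In the `_flat_data` hunk above, the per-dtype loops stop reading `data_ptr<T>()` directly and instead index `tensor.reshape({tensor.numel()})` with `view[i].item()`. The diff does not state why; one plausible reason is that a raw pointer walk visits elements in storage order and ignores strides, so a transposed or column-sliced tensor would come back in the wrong order, while indexing a reshaped view follows logical order. The pure-Ruby sketch below illustrates that difference with a hypothetical `logical_flat` helper over a flat storage array; it is an illustration of strided reads, not torch-rb or LibTorch code.

```ruby
# Hypothetical helper (not part of torch-rb): read a strided view of `storage`
# in logical row-major order, given its shape, strides, and starting offset.
def logical_flat(storage, shape, strides, offset = 0)
  return [storage[offset]] if shape.empty? # reached a single element

  (0...shape[0]).flat_map do |i|
    logical_flat(storage, shape[1..], strides[1..], offset + i * strides[0])
  end
end

storage = [1, 2, 3, 4, 5, 6] # a contiguous 2x3 tensor: shape [2, 3], strides [3, 1]

p logical_flat(storage, [2, 3], [3, 1]) # => [1, 2, 3, 4, 5, 6]
p logical_flat(storage, [3, 2], [1, 3]) # transposed view => [1, 4, 2, 5, 3, 6]
```

A pointer walk over the transposed tensor would still see `[1, 2, 3, 4, 5, 6]`, because transposing only swaps shape and strides over the same storage.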
data/ext/torch/extconf.rb
CHANGED
@@ -2,33 +2,33 @@ require "mkmf-rice"
 
 abort "Missing stdc++" unless have_library("stdc++")
 
-$CXXFLAGS
+$CXXFLAGS += " -std=c++14"
 
 # change to 0 for Linux pre-cxx11 ABI version
-$CXXFLAGS
+$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
 
 # TODO check compiler name
 clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
 
 # check omp first
 if have_library("omp") || have_library("gomp")
-  $CXXFLAGS
-  $CXXFLAGS
-  $CXXFLAGS
+  $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
+  $CXXFLAGS += " -Xclang" if clang
+  $CXXFLAGS += " -fopenmp"
 end
 
 if clang
   # silence ruby/intern.h warning
-  $CXXFLAGS
+  $CXXFLAGS += " -Wno-deprecated-register"
 
   # silence torch warnings
-  $CXXFLAGS
+  $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
 else
   # silence rice warnings
-  $CXXFLAGS
+  $CXXFLAGS += " -Wno-noexcept-type"
 
   # silence torch warnings
-  $CXXFLAGS
+  $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
 end
 
 inc, lib = dir_config("torch")
@@ -39,27 +39,30 @@ cuda_inc, cuda_lib = dir_config("cuda")
 cuda_inc ||= "/usr/local/cuda/include"
 cuda_lib ||= "/usr/local/cuda/lib64"
 
-$LDFLAGS
+$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
 abort "LibTorch not found" unless have_library("torch")
 
+have_library("mkldnn")
+have_library("nnpack")
+
 with_cuda = false
 if Dir["#{lib}/*torch_cuda*"].any?
-  $LDFLAGS
+  $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
   with_cuda = have_library("cuda") && have_library("cudnn")
 end
 
-$INCFLAGS
-$INCFLAGS
+$INCFLAGS += " -I#{inc}"
+$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
 
-$LDFLAGS
-$LDFLAGS
+$LDFLAGS += " -Wl,-rpath,#{lib}"
+$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
 
 # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
-$LDFLAGS
+$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
 if with_cuda
-  $LDFLAGS
+  $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
   # TODO figure out why this is needed
-  $LDFLAGS
+  $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
 end
 
 # generate C++ functions
data/lib/torch.rb
CHANGED
@@ -1,6 +1,11 @@
 # ext
 require "torch/ext"
 
+# stdlib
+require "fileutils"
+require "net/http"
+require "tmpdir"
+
 # native functions
 require "torch/native/generator"
 require "torch/native/parser"
@@ -174,11 +179,9 @@ require "torch/nn/init"
 
 # utils
 require "torch/utils/data/data_loader"
+require "torch/utils/data/dataset"
 require "torch/utils/data/tensor_dataset"
 
-# random
-require "torch/random"
-
 # hub
 require "torch/hub"
 
@@ -466,6 +469,12 @@ module Torch
         IValue.from_bool(obj)
       when nil
         IValue.new
+      when Array
+        if obj.all? { |v| v.is_a?(Tensor) }
+          IValue.from_list(obj.map { |v| IValue.from_tensor(v) })
+        else
+          raise Error, "Unknown list type"
+        end
       else
         raise Error, "Unknown type: #{obj.class.name}"
       end
@@ -490,6 +499,8 @@ module Torch
           dict[to_ruby(k)] = to_ruby(v)
         end
         dict
+      elsif ivalue.list?
+        ivalue.to_list.map { |v| to_ruby(v) }
       else
         type =
           if ivalue.capsule?
@@ -510,8 +521,6 @@ module Torch
             "BoolList"
           elsif ivalue.tensor_list?
             "TensorList"
-          elsif ivalue.list?
-            "List"
           elsif ivalue.object?
             "Object"
           elsif ivalue.module?
data/lib/torch/hub.rb
CHANGED
@@ -4,6 +4,58 @@ module Torch
     def list(github, force_reload: false)
       raise NotImplementedYet
     end
+
+    def download_url_to_file(url, dst)
+      uri = URI(url)
+      tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+      location = nil
+
+      Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
+        request = Net::HTTP::Get.new(uri)
+
+        puts "Downloading #{url}..."
+        File.open(tmp, "wb") do |f|
+          http.request(request) do |response|
+            case response
+            when Net::HTTPRedirection
+              location = response["location"]
+            when Net::HTTPSuccess
+              response.read_body do |chunk|
+                f.write(chunk)
+              end
+            else
+              raise Error, "Bad response"
+            end
+          end
+        end
+      end
+
+      if location
+        download_url_to_file(location, dst)
+      else
+        FileUtils.mv(tmp, dst)
+        nil
+      end
+    end
+
+    def load_state_dict_from_url(url, model_dir: nil)
+      unless model_dir
+        torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
+        model_dir = File.join(torch_home, "checkpoints")
+      end
+
+      FileUtils.mkdir_p(model_dir)
+
+      parts = URI(url)
+      filename = File.basename(parts.path)
+      cached_file = File.join(model_dir, filename)
+      unless File.exist?(cached_file)
+        # TODO support hash_prefix
+        download_url_to_file(url, cached_file)
+      end
+
+      Torch.load(cached_file)
+    end
   end
 end
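`load_state_dict_from_url` resolves its cache directory from `TORCH_HOME`, then `$XDG_CACHE_HOME/torch`, then `$HOME/.cache/torch`, and stores each file under `checkpoints/` using the basename of the URL path. The sketch below isolates only that path logic, with the environment passed in as a hash so it can be exercised without touching `ENV`, the file system, or the network; `checkpoint_path` is a hypothetical name, not part of `Torch::Hub`.

```ruby
require "uri"

# Hypothetical helper mirroring the lookup order in Torch::Hub.load_state_dict_from_url.
# It only computes the cached file path; it does not create directories or download.
def checkpoint_path(url, env = ENV)
  torch_home = env["TORCH_HOME"] || "#{env["XDG_CACHE_HOME"] || "#{env["HOME"]}/.cache"}/torch"
  model_dir = File.join(torch_home, "checkpoints")
  File.join(model_dir, File.basename(URI(url).path))
end

url = "https://example.com/models/net.pth"
p checkpoint_path(url, { "HOME" => "/home/me" })
p checkpoint_path(url, { "TORCH_HOME" => "/opt/torch", "HOME" => "/home/me" })
```

Because the cache key is only the file name, two URLs that end in the same basename map to the same cached file; the `# TODO support hash_prefix` comment in the hunk suggests no integrity check was performed yet.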
data/lib/torch/inspector.rb
CHANGED
@@ -1,6 +1,6 @@
 module Torch
   module Inspector
-    # TODO make more
+    # TODO make more performant, especially when summarizing
     # how? only read data that will be displayed
     def inspect
       data =
@@ -14,7 +14,7 @@ module Torch
           if dtype == :bool
             fmt = "%s"
           else
-            values =
+            values = _flat_data
             abs = values.select { |v| v != 0 }.map(&:abs)
             max = abs.max || 1
             min = abs.min || 1
@@ -25,7 +25,7 @@ module Torch
             end
 
             if floating_point?
-              sci = max
+              sci = max > 1e8 || max < 1e-4
 
               all_int = values.all? { |v| v.finite? && v == v.to_i }
               decimal = all_int ? 1 : 4
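In the inspector hunks, `max` is the largest absolute nonzero value (falling back to 1 when every value is zero), and the new condition switches floating-point output to scientific notation when that value is above `1e8` or below `1e-4`. A small sketch of just that decision follows; `use_scientific?` is a hypothetical name, and the rest of the formatting (the `decimal` width and the final format string) is not reproduced.

```ruby
# Hypothetical extraction of the scientific-notation check in Torch::Inspector#inspect.
def use_scientific?(values)
  abs = values.select { |v| v != 0 }.map(&:abs) # zeros never force scientific notation
  max = abs.max || 1
  max > 1e8 || max < 1e-4
end

p use_scientific?([1.5, -2.0, 3.25]) # => false
p use_scientific?([2.0e9, 1.0])      # => true (largest magnitude above 1e8)
p use_scientific?([0.0, 3.0e-5])     # => true (largest nonzero magnitude below 1e-4)
```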
data/lib/torch/native/generator.rb
CHANGED
@@ -18,7 +18,7 @@ module Torch
         functions = functions()
 
         # skip functions
-        skip_args = ["bool[3]", "Dimname", "
+        skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
 
         # remove functions
         functions.reject! do |f|
@@ -31,7 +31,7 @@ module Torch
         todo_functions, functions =
           functions.partition do |f|
             f.args.any? do |a|
-              a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
+              a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
               skip_args.any? { |sa| a[:type].include?(sa) } ||
               # native_functions.yaml is missing size argument for normal
               # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
@@ -112,6 +112,9 @@ void add_%{type}_functions(Module m) {
             "OptionalScalarType"
           when "Tensor[]"
             "TensorList"
+          when "Tensor?[]"
+            # TODO make optional
+            "TensorList"
           when "int"
             "int64_t"
           when "float"
data/lib/torch/native/parser.rb
CHANGED
data/lib/torch/nn/batch_norm.rb
CHANGED
@@ -70,6 +70,11 @@ module Torch
           momentum: exponential_average_factor, eps: @eps
         )
       end
+
+      def extra_inspect
+        s = "%{num_features}, eps: %{eps}, momentum: %{momentum}, affine: %{affine}, track_running_stats: %{track_running_stats}"
+        format(s, **dict)
+      end
     end
   end
 end
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -20,7 +20,14 @@ module Torch
 
       # TODO add more parameters
       def extra_inspect
-
+        s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
+        s += ", padding: %{padding}" if @padding != [0] * @padding.size
+        s += ", dilation: %{dilation}" if @dilation != [1] * @dilation.size
+        s += ", output_padding: %{output_padding}" if @output_padding != [0] * @output_padding.size
+        s += ", groups: %{groups}" if @groups != 1
+        s += ", bias: false" unless @bias
+        s += ", padding_mode: %{padding_mode}" if @padding_mode != "zeros"
+        format(s, **dict)
       end
     end
   end
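The new `extra_inspect` methods build a template of `%{name}` references and fill it with `format(s, **dict)`. Here `dict` is assumed to be a hash of the module's settings keyed by symbol (its definition is not part of this diff). The standalone sketch below reproduces the Conv2d version with the settings passed in explicitly and `bias` as a plain boolean; `conv2d_summary` is a hypothetical name.

```ruby
# Hypothetical standalone version of Conv2d#extra_inspect. `dict` stands in for
# the module's settings hash; optional parts are appended only when they differ
# from the defaults, and format fills every %{name} from the hash.
def conv2d_summary(dict)
  s = String.new("%{in_channels}, %{out_channels}, kernel_size: %{kernel_size}, stride: %{stride}")
  s += ", padding: %{padding}" if dict[:padding] != [0] * dict[:padding].size
  s += ", dilation: %{dilation}" if dict[:dilation] != [1] * dict[:dilation].size
  s += ", output_padding: %{output_padding}" if dict[:output_padding] != [0] * dict[:output_padding].size
  s += ", groups: %{groups}" if dict[:groups] != 1
  s += ", bias: false" unless dict[:bias]
  s += ", padding_mode: %{padding_mode}" if dict[:padding_mode] != "zeros"
  format(s, **dict)
end

defaults = {
  in_channels: 1, out_channels: 6, kernel_size: [3, 3], stride: [1, 1],
  padding: [0, 0], dilation: [1, 1], output_padding: [0, 0],
  groups: 1, bias: true, padding_mode: "zeros"
}
puts conv2d_summary(defaults)
# 1, 6, kernel_size: [3, 3], stride: [1, 1]
puts conv2d_summary(defaults.merge(padding: [1, 1], bias: false))
# 1, 6, kernel_size: [3, 3], stride: [1, 1], padding: [1, 1], bias: false
```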
data/lib/torch/nn/convnd.rb
CHANGED
data/lib/torch/nn/max_poolnd.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -145,7 +145,7 @@ module Torch
       params = {}
       if recurse
         named_children.each do |name, mod|
-          params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+          params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
         end
       end
       instance_variables.each do |name|
@@ -186,8 +186,22 @@ module Torch
         named_modules.values
       end
 
-
-
+      # TODO return enumerator?
+      def named_modules(memo: nil, prefix: "")
+        ret = {}
+        memo ||= Set.new
+        unless memo.include?(self)
+          memo << self
+          ret[prefix] = self
+          named_children.each do |name, mod|
+            next unless mod.is_a?(Module)
+            submodule_prefix = prefix + (!prefix.empty? ? "." : "") + name
+            mod.named_modules(memo: memo, prefix: submodule_prefix).each do |m|
+              ret[m[0]] = m[1]
+            end
+          end
+        end
+        ret
       end
 
       def train(mode = true)
@@ -224,13 +238,15 @@ module Torch
 
       def inspect
         name = self.class.name.split("::").last
-        if
+        if named_children.empty?
           "#{name}(#{extra_inspect})"
         else
           str = String.new
           str << "#{name}(\n"
-
-
+          named_children.each do |name, mod|
+            mod_str = mod.inspect
+            mod_str = mod_str.lines.join(" ")
+            str << " (#{name}): #{mod_str}\n"
           end
           str << ")"
         end
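Two of the `module.rb` changes are about prefixes in nested modules: `named_parameters` now passes `"#{prefix}#{name}."` so a parameter two levels down keeps its full path, and the new `named_modules` walks children with a memo set so a module reachable through several paths is listed only once. The sketch below mimics both traversals on a hypothetical `Node` class (a stand-in for `Torch::NN::Module` with explicit `params` and `children` hashes); it is not torch-rb code.

```ruby
require "set"

# Hypothetical stand-in for Torch::NN::Module: named parameters plus named children.
class Node
  attr_reader :params, :children

  def initialize(params: {}, children: {})
    @params = params
    @children = children
  end

  # Children first, then own parameters; the child prefix includes our own prefix,
  # which is what keeps "block.conv.weight" instead of "conv.weight".
  def named_parameters(prefix: "")
    out = {}
    children.each do |name, mod|
      out.merge!(mod.named_parameters(prefix: "#{prefix}#{name}."))
    end
    params.each { |name, value| out["#{prefix}#{name}"] = value }
    out
  end

  # Same shape as the new named_modules: depth-first, dotted prefixes,
  # and a shared memo so each module appears once under the first path found.
  def named_modules(memo: nil, prefix: "")
    ret = {}
    memo ||= Set.new
    unless memo.include?(self)
      memo << self
      ret[prefix] = self
      children.each do |name, mod|
        submodule_prefix = prefix + (prefix.empty? ? "" : ".") + name
        ret.merge!(mod.named_modules(memo: memo, prefix: submodule_prefix))
      end
    end
    ret
  end
end

conv = Node.new(params: { "weight" => :w })
net = Node.new(children: { "block" => Node.new(children: { "conv" => conv }) })
p net.named_parameters.keys # => ["block.conv.weight"]
p net.named_modules.keys    # => ["", "block", "block.conv"]
```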
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -4,6 +4,8 @@ module Torch
     include Inspector
 
     alias_method :requires_grad?, :requires_grad
+    alias_method :ndim, :dim
+    alias_method :ndimension, :dim
 
     def self.new(*args)
       FloatTensor.new(*args)
@@ -23,8 +25,17 @@ module Torch
       inspect
     end
 
+    # TODO make more performant
     def to_a
-
+      arr = _flat_data
+      if shape.empty?
+        arr
+      else
+        shape[1..-1].reverse.each do |dim|
+          arr = arr.each_slice(dim)
+        end
+        arr.to_a
+      end
     end
 
     # TODO support dtype
@@ -62,7 +73,15 @@ module Torch
       if numel != 1
         raise Error, "only one element tensors can be converted to Ruby scalars"
       end
-
+      to_a.first
+    end
+
+    def to_i
+      item.to_i
+    end
+
+    def to_f
+      item.to_f
     end
 
     # unsure if this is correct
@@ -78,7 +97,7 @@ module Torch
     def numo
       cls = Torch._dtype_to_numo[dtype]
       raise Error, "Cannot convert #{dtype} to Numo" unless cls
-      cls.
+      cls.from_string(_data_str).reshape(*shape)
     end
 
     def new_ones(*size, **options)
@@ -106,15 +125,6 @@ module Torch
       _view(size)
     end
 
-    # value and other are swapped for some methods
-    def add!(value = 1, other)
-      if other.is_a?(Numeric)
-        _add__scalar(other, value)
-      else
-        _add__tensor(other, value)
-      end
-    end
-
     def +(other)
       add(other)
     end
@@ -143,11 +153,13 @@ module Torch
       neg
     end
 
+    # TODO better compare?
     def <=>(other)
       item <=> other
     end
 
-    # based on python_variable_indexing.cpp
+    # based on python_variable_indexing.cpp and
+    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def [](*indexes)
       result = self
       dim = 0
@@ -159,6 +171,8 @@ module Torch
           finish += 1 unless index.exclude_end?
           result = result._slice_tensor(dim, index.begin, finish, 1)
           dim += 1
+        elsif index.is_a?(Tensor)
+          result = result.index([index])
         elsif index.nil?
           result = result.unsqueeze(dim)
           dim += 1
@@ -172,12 +186,12 @@ module Torch
       result
     end
 
-    #
-    #
+    # based on python_variable_indexing.cpp and
+    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
     def []=(index, value)
       raise ArgumentError, "Tensor does not support deleting items" if value.nil?
 
-      value = Torch.tensor(value) unless value.is_a?(Tensor)
+      value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
 
       if index.is_a?(Numeric)
         copy_to(_select_int(0, index), value)
@@ -185,13 +199,39 @@ module Torch
         finish = index.end
         finish += 1 unless index.exclude_end?
         copy_to(_slice_tensor(0, index.begin, finish, 1), value)
+      elsif index.is_a?(Tensor)
+        index_put!([index], value)
       else
         raise Error, "Unsupported index type: #{index.class.name}"
       end
     end
 
-
-
+    # native functions that need manually defined
+
+    # value and other are swapped for some methods
+    def add!(value = 1, other)
+      if other.is_a?(Numeric)
+        _add__scalar(other, value)
+      else
+        _add__tensor(other, value)
+      end
+    end
+
+    # native functions overlap, so need to handle manually
+    def random!(*args)
+      case args.size
+      when 1
+        _random__to(*args)
+      when 2
+        _random__from_to(*args)
+      else
+        _random_(*args)
+      end
+    end
+
+    def clamp!(min, max)
+      _clamp_min_(min)
+      _clamp_max_(max)
     end
 
     private
@@ -199,17 +239,5 @@ module Torch
     def copy_to(dst, src)
       dst.copy!(src)
     end
-
-    def reshape_arr(arr, dims)
-      if dims.empty?
-        arr
-      else
-        arr = arr.flatten
-        dims[1..-1].reverse.each do |dim|
-          arr = arr.each_slice(dim)
-        end
-        arr.to_a
-      end
-    end
   end
 end
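The new `Tensor#to_a` rebuilds nested arrays from the flat `_flat_data` list by slicing from the innermost dimension outward: for shape `[2, 3]` it groups the flat list into threes; for `[2, 2, 2]` it groups into pairs and then pairs of pairs. A 0-dim tensor returns its one-element flat list, which is why `item` can use `to_a.first`. The sketch below applies the same loop to a plain Ruby array; `nest` is a hypothetical helper and `flat` stands in for `_flat_data`.

```ruby
# Hypothetical helper using the same slicing loop as Tensor#to_a in this release.
def nest(flat, shape)
  return flat if shape.empty? # 0-dim: keep the one-element flat list

  arr = flat
  shape[1..-1].reverse.each do |dim|
    arr = arr.each_slice(dim) # wraps the previous level in an Enumerator of dim-sized chunks
  end
  arr.to_a
end

p nest([1, 2, 3, 4, 5, 6], [2, 3]) # => [[1, 2, 3], [4, 5, 6]]
p nest((1..8).to_a, [2, 2, 2])     # => [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
p nest([7], [])                    # => [7]
```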
data/lib/torch/utils/data/data_loader.rb
CHANGED
@@ -6,21 +6,49 @@ module Torch
 
         attr_reader :dataset
 
-        def initialize(dataset, batch_size: 1)
+        def initialize(dataset, batch_size: 1, shuffle: false)
           @dataset = dataset
           @batch_size = batch_size
+          @shuffle = shuffle
         end
 
         def each
-
-
-
+          # try to keep the random number generator in sync with Python
+          # this makes it easy to compare results
+          base_seed = Torch.empty([], dtype: :int64).random!.item
+
+          indexes =
+            if @shuffle
+              Torch.randperm(@dataset.size).to_a
+            else
+              @dataset.size.times
+            end
+
+          indexes.each_slice(@batch_size) do |idx|
+            batch = idx.map { |i| @dataset[i] }
+            yield collate(batch)
           end
         end
 
         def size
           (@dataset.size / @batch_size.to_f).ceil
         end
+
+        private
+
+        def collate(batch)
+          elem = batch[0]
+          case elem
+          when Tensor
+            Torch.stack(batch, 0)
+          when Integer
+            Torch.tensor(batch)
+          when Array
+            batch.transpose.map { |v| collate(v) }
+          else
+            raise NotImpelmentYet
+          end
+        end
       end
     end
   end
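The rewritten `DataLoader#each` picks an index order (a permutation from `Torch.randperm` when `shuffle: true`, otherwise `0...size`), slices it into groups of `batch_size`, and passes each group through `collate`, which stacks tensors, turns integers into a tensor, and transposes arrays of fields. The sketch below keeps that control flow in plain Ruby: `Array#shuffle` with an explicit `Random` stands in for `Torch.randperm`, and collated leaves stay arrays instead of becoming tensors. `batches` and `collate_sketch` are hypothetical names. Note that the fallback branch in the released code raises `NotImpelmentYet`, which looks like a misspelling of `NotImplementedYet`; the sketch raises `NotImplementedError` instead.

```ruby
# Hypothetical pure-Ruby sketch of DataLoader#each plus collate from this release.
def collate_sketch(batch)
  case batch[0]
  when Integer
    batch # the real collate returns Torch.tensor(batch)
  when Array
    # a batch of [x, y] pairs becomes [[x1, x2, ...], [y1, y2, ...]]
    batch.transpose.map { |v| collate_sketch(v) }
  else
    raise NotImplementedError, "unsupported element: #{batch[0].class}"
  end
end

def batches(dataset, batch_size: 1, shuffle: false, random: Random.new)
  indexes =
    if shuffle
      (0...dataset.size).to_a.shuffle(random: random) # stands in for Torch.randperm
    else
      dataset.size.times
    end
  indexes.each_slice(batch_size).map do |idx|
    collate_sketch(idx.map { |i| dataset[i] })
  end
end

dataset = [[1, 10], [2, 20], [3, 30]]
p batches(dataset, batch_size: 2)
# => [[[1, 2], [10, 20]], [[3], [30]]]
```

The last batch is short when the dataset size is not a multiple of `batch_size`, which matches `size` rounding up with `ceil`.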
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.2.
+  version: 0.2.6
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-
+date: 2020-06-29 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -95,19 +95,19 @@ dependencies:
       - !ruby/object:Gem::Version
         version: '0'
 - !ruby/object:Gem::Dependency
-  name:
+  name: torchvision
   requirement: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version:
+        version: 0.1.1
   type: :development
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version:
+        version: 0.1.1
 description:
 email: andrew@chartkick.com
 executables: []
@@ -258,9 +258,9 @@ files:
 - lib/torch/optim/rmsprop.rb
 - lib/torch/optim/rprop.rb
 - lib/torch/optim/sgd.rb
-- lib/torch/random.rb
 - lib/torch/tensor.rb
 - lib/torch/utils/data/data_loader.rb
+- lib/torch/utils/data/dataset.rb
 - lib/torch/utils/data/tensor_dataset.rb
 - lib/torch/version.rb
 homepage: https://github.com/ankane/torch.rb