torch-rb 0.2.0 → 0.2.5
Sign up for free to protect your applications and get access to all features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/README.md +36 -6
- data/ext/torch/ext.cpp +197 -24
- data/ext/torch/extconf.rb +34 -21
- data/lib/torch.rb +102 -6
- data/lib/torch/hub.rb +52 -0
- data/lib/torch/inspector.rb +3 -3
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -1
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +45 -8
- data/lib/torch/tensor.rb +48 -26
- data/lib/torch/utils/data/data_loader.rb +32 -4
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +6 -13
- data/ext/torch/nn_functions.cpp +0 -560
- data/ext/torch/nn_functions.hpp +0 -6
- data/ext/torch/tensor_functions.cpp +0 -2085
- data/ext/torch/tensor_functions.hpp +0 -6
- data/ext/torch/torch_functions.cpp +0 -3175
- data/ext/torch/torch_functions.hpp +0 -6
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/random.rb +0 -10
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8f6ab78fb5cff27d0d60ddb9c08fb2f526bd60e241dd1011554b21716bdd2f43
|
4
|
+
data.tar.gz: 5568f53d8d5d688e3f29fb55ddbe9457e0b933dc69f59b422035c0cee249e396
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9c5dcfbf35382b37678662690677b2b90d0d544e8802703cf83dba6e10483a4df487f9687e4e898c9cc449c568f4e02f9d831daa982b5b3135af6f9ce176ec88
|
7
|
+
data.tar.gz: 34f142d874606e140661ae992a9f8cd4779f95c93c11d9a89a1864dd0bd53c5480c30d9aec5897f1955e2450bd0a3bc56ed0868e3b54d82ff4cdba40af379840
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,34 @@
|
|
1
|
+
## 0.2.5 (2020-06-07)
|
2
|
+
|
3
|
+
- Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
|
4
|
+
- Improved error messages
|
5
|
+
- Fixed tensor slicing
|
6
|
+
|
7
|
+
## 0.2.4 (2020-04-29)
|
8
|
+
|
9
|
+
- Added `to_i` and `to_f` to tensors
|
10
|
+
- Added `shuffle` option to data loader
|
11
|
+
- Fixed `modules` and `named_modules` for nested modules
|
12
|
+
|
13
|
+
## 0.2.3 (2020-04-28)
|
14
|
+
|
15
|
+
- Added `show_config` and `parallel_info` methods
|
16
|
+
- Added `initial_seed` and `seed` methods to `Random`
|
17
|
+
- Improved data loader
|
18
|
+
- Build with MKL-DNN and NNPACK when available
|
19
|
+
- Fixed `inspect` for modules
|
20
|
+
|
21
|
+
## 0.2.2 (2020-04-27)
|
22
|
+
|
23
|
+
- Added support for saving tensor lists
|
24
|
+
- Added `ndim` and `ndimension` methods to tensors
|
25
|
+
|
26
|
+
## 0.2.1 (2020-04-26)
|
27
|
+
|
28
|
+
- Added support for saving and loading models
|
29
|
+
- Improved error messages
|
30
|
+
- Reduced gem size
|
31
|
+
|
1
32
|
## 0.2.0 (2020-04-22)
|
2
33
|
|
3
34
|
- No longer experimental
|
data/README.md
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
+
For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
|
6
|
+
|
5
7
|
[![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
|
6
8
|
|
7
9
|
## Installation
|
@@ -22,6 +24,18 @@ It can take a few minutes to compile the extension.
|
|
22
24
|
|
23
25
|
## Getting Started
|
24
26
|
|
27
|
+
Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
|
28
|
+
|
29
|
+
We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
|
30
|
+
|
31
|
+
```text
|
32
|
+
ankane/ml-stack:torch-gpu
|
33
|
+
```
|
34
|
+
|
35
|
+
Leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
|
36
|
+
|
37
|
+
## API
|
38
|
+
|
25
39
|
This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
|
26
40
|
|
27
41
|
- Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
|
@@ -192,7 +206,7 @@ end
|
|
192
206
|
Define a neural network
|
193
207
|
|
194
208
|
```ruby
|
195
|
-
class
|
209
|
+
class MyNet < Torch::NN::Module
|
196
210
|
def initialize
|
197
211
|
super
|
198
212
|
@conv1 = Torch::NN::Conv2d.new(1, 6, 3)
|
@@ -226,7 +240,7 @@ end
|
|
226
240
|
Create an instance of it
|
227
241
|
|
228
242
|
```ruby
|
229
|
-
net =
|
243
|
+
net = MyNet.new
|
230
244
|
input = Torch.randn(1, 1, 32, 32)
|
231
245
|
net.call(input)
|
232
246
|
```
|
@@ -283,6 +297,22 @@ loss.backward
|
|
283
297
|
optimizer.step
|
284
298
|
```
|
285
299
|
|
300
|
+
### Saving and Loading Models
|
301
|
+
|
302
|
+
Save a model
|
303
|
+
|
304
|
+
```ruby
|
305
|
+
Torch.save(net.state_dict, "net.pth")
|
306
|
+
```
|
307
|
+
|
308
|
+
Load a model
|
309
|
+
|
310
|
+
```ruby
|
311
|
+
net = MyNet.new
|
312
|
+
net.load_state_dict(Torch.load("net.pth"))
|
313
|
+
net.eval
|
314
|
+
```
|
315
|
+
|
286
316
|
### Tensor Creation
|
287
317
|
|
288
318
|
Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
|
@@ -397,9 +427,7 @@ Then install the gem (no need for `bundle config`).
|
|
397
427
|
|
398
428
|
### Linux
|
399
429
|
|
400
|
-
Deep learning is significantly faster on
|
401
|
-
|
402
|
-
Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
430
|
+
Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
403
431
|
|
404
432
|
Check if CUDA is available
|
405
433
|
|
@@ -410,7 +438,7 @@ Torch::CUDA.available?
|
|
410
438
|
Move a neural network to a GPU
|
411
439
|
|
412
440
|
```ruby
|
413
|
-
net.
|
441
|
+
net.cuda
|
414
442
|
```
|
415
443
|
|
416
444
|
## rbenv
|
@@ -445,6 +473,8 @@ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
|
|
445
473
|
bundle exec rake test
|
446
474
|
```
|
447
475
|
|
476
|
+
You can use [this script](https://gist.github.com/ankane/9b2b5fcbd66d6e4ccfeb9d73e529abe7) to test on GPUs with the AWS Deep Learning Base AMI (Ubuntu 18.04).
|
477
|
+
|
448
478
|
Here are some good resources for contributors:
|
449
479
|
|
450
480
|
- [PyTorch API](https://pytorch.org/docs/stable/torch.html)
|
data/ext/torch/ext.cpp
CHANGED
@@ -5,6 +5,7 @@
|
|
5
5
|
#include <rice/Array.hpp>
|
6
6
|
#include <rice/Class.hpp>
|
7
7
|
#include <rice/Constructor.hpp>
|
8
|
+
#include <rice/Hash.hpp>
|
8
9
|
|
9
10
|
#include "templates.hpp"
|
10
11
|
|
@@ -22,18 +23,163 @@ class Parameter: public torch::autograd::Variable {
|
|
22
23
|
Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
|
23
24
|
};
|
24
25
|
|
26
|
+
void handle_error(torch::Error const & ex)
|
27
|
+
{
|
28
|
+
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
|
+
}
|
30
|
+
|
25
31
|
extern "C"
|
26
32
|
void Init_ext()
|
27
33
|
{
|
28
34
|
Module rb_mTorch = define_module("Torch");
|
35
|
+
rb_mTorch.add_handler<torch::Error>(handle_error);
|
29
36
|
add_torch_functions(rb_mTorch);
|
30
37
|
|
31
38
|
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
|
39
|
+
rb_cTensor.add_handler<torch::Error>(handle_error);
|
32
40
|
add_tensor_functions(rb_cTensor);
|
33
41
|
|
34
42
|
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
43
|
+
rb_mNN.add_handler<torch::Error>(handle_error);
|
35
44
|
add_nn_functions(rb_mNN);
|
36
45
|
|
46
|
+
Module rb_mRandom = define_module_under(rb_mTorch, "Random")
|
47
|
+
.add_handler<torch::Error>(handle_error)
|
48
|
+
.define_singleton_method(
|
49
|
+
"initial_seed",
|
50
|
+
*[]() {
|
51
|
+
return at::detail::getDefaultCPUGenerator()->current_seed();
|
52
|
+
})
|
53
|
+
.define_singleton_method(
|
54
|
+
"seed",
|
55
|
+
*[]() {
|
56
|
+
// TODO set for CUDA when available
|
57
|
+
return at::detail::getDefaultCPUGenerator()->seed();
|
58
|
+
});
|
59
|
+
|
60
|
+
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
61
|
+
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
62
|
+
.add_handler<torch::Error>(handle_error)
|
63
|
+
.define_constructor(Constructor<torch::IValue>())
|
64
|
+
.define_method("bool?", &torch::IValue::isBool)
|
65
|
+
.define_method("bool_list?", &torch::IValue::isBoolList)
|
66
|
+
.define_method("capsule?", &torch::IValue::isCapsule)
|
67
|
+
.define_method("custom_class?", &torch::IValue::isCustomClass)
|
68
|
+
.define_method("device?", &torch::IValue::isDevice)
|
69
|
+
.define_method("double?", &torch::IValue::isDouble)
|
70
|
+
.define_method("double_list?", &torch::IValue::isDoubleList)
|
71
|
+
.define_method("future?", &torch::IValue::isFuture)
|
72
|
+
// .define_method("generator?", &torch::IValue::isGenerator)
|
73
|
+
.define_method("generic_dict?", &torch::IValue::isGenericDict)
|
74
|
+
.define_method("list?", &torch::IValue::isList)
|
75
|
+
.define_method("int?", &torch::IValue::isInt)
|
76
|
+
.define_method("int_list?", &torch::IValue::isIntList)
|
77
|
+
.define_method("module?", &torch::IValue::isModule)
|
78
|
+
.define_method("none?", &torch::IValue::isNone)
|
79
|
+
.define_method("object?", &torch::IValue::isObject)
|
80
|
+
.define_method("ptr_type?", &torch::IValue::isPtrType)
|
81
|
+
.define_method("py_object?", &torch::IValue::isPyObject)
|
82
|
+
.define_method("r_ref?", &torch::IValue::isRRef)
|
83
|
+
.define_method("scalar?", &torch::IValue::isScalar)
|
84
|
+
.define_method("string?", &torch::IValue::isString)
|
85
|
+
.define_method("tensor?", &torch::IValue::isTensor)
|
86
|
+
.define_method("tensor_list?", &torch::IValue::isTensorList)
|
87
|
+
.define_method("tuple?", &torch::IValue::isTuple)
|
88
|
+
.define_method(
|
89
|
+
"to_bool",
|
90
|
+
*[](torch::IValue& self) {
|
91
|
+
return self.toBool();
|
92
|
+
})
|
93
|
+
.define_method(
|
94
|
+
"to_double",
|
95
|
+
*[](torch::IValue& self) {
|
96
|
+
return self.toDouble();
|
97
|
+
})
|
98
|
+
.define_method(
|
99
|
+
"to_int",
|
100
|
+
*[](torch::IValue& self) {
|
101
|
+
return self.toInt();
|
102
|
+
})
|
103
|
+
.define_method(
|
104
|
+
"to_list",
|
105
|
+
*[](torch::IValue& self) {
|
106
|
+
auto list = self.toListRef();
|
107
|
+
Array obj;
|
108
|
+
for (auto& elem : list) {
|
109
|
+
obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
|
110
|
+
}
|
111
|
+
return obj;
|
112
|
+
})
|
113
|
+
.define_method(
|
114
|
+
"to_string_ref",
|
115
|
+
*[](torch::IValue& self) {
|
116
|
+
return self.toStringRef();
|
117
|
+
})
|
118
|
+
.define_method(
|
119
|
+
"to_tensor",
|
120
|
+
*[](torch::IValue& self) {
|
121
|
+
return self.toTensor();
|
122
|
+
})
|
123
|
+
.define_method(
|
124
|
+
"to_generic_dict",
|
125
|
+
*[](torch::IValue& self) {
|
126
|
+
auto dict = self.toGenericDict();
|
127
|
+
Hash obj;
|
128
|
+
for (auto& pair : dict) {
|
129
|
+
obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
|
130
|
+
}
|
131
|
+
return obj;
|
132
|
+
})
|
133
|
+
.define_singleton_method(
|
134
|
+
"from_tensor",
|
135
|
+
*[](torch::Tensor& v) {
|
136
|
+
return torch::IValue(v);
|
137
|
+
})
|
138
|
+
// TODO create specialized list types?
|
139
|
+
.define_singleton_method(
|
140
|
+
"from_list",
|
141
|
+
*[](Array obj) {
|
142
|
+
c10::impl::GenericList list(c10::AnyType::get());
|
143
|
+
for (auto entry : obj) {
|
144
|
+
list.push_back(from_ruby<torch::IValue>(entry));
|
145
|
+
}
|
146
|
+
return torch::IValue(list);
|
147
|
+
})
|
148
|
+
.define_singleton_method(
|
149
|
+
"from_string",
|
150
|
+
*[](String v) {
|
151
|
+
return torch::IValue(v.str());
|
152
|
+
})
|
153
|
+
.define_singleton_method(
|
154
|
+
"from_int",
|
155
|
+
*[](int64_t v) {
|
156
|
+
return torch::IValue(v);
|
157
|
+
})
|
158
|
+
.define_singleton_method(
|
159
|
+
"from_double",
|
160
|
+
*[](double v) {
|
161
|
+
return torch::IValue(v);
|
162
|
+
})
|
163
|
+
.define_singleton_method(
|
164
|
+
"from_bool",
|
165
|
+
*[](bool v) {
|
166
|
+
return torch::IValue(v);
|
167
|
+
})
|
168
|
+
// see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
|
169
|
+
// createGenericDict and toIValue
|
170
|
+
.define_singleton_method(
|
171
|
+
"from_dict",
|
172
|
+
*[](Hash obj) {
|
173
|
+
auto key_type = c10::AnyType::get();
|
174
|
+
auto value_type = c10::AnyType::get();
|
175
|
+
c10::impl::GenericDict elems(key_type, value_type);
|
176
|
+
elems.reserve(obj.size());
|
177
|
+
for (auto entry : obj) {
|
178
|
+
elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
|
179
|
+
}
|
180
|
+
return torch::IValue(std::move(elems));
|
181
|
+
});
|
182
|
+
|
37
183
|
rb_mTorch.define_singleton_method(
|
38
184
|
"grad_enabled?",
|
39
185
|
*[]() {
|
@@ -49,6 +195,17 @@ void Init_ext()
|
|
49
195
|
*[](uint64_t seed) {
|
50
196
|
return torch::manual_seed(seed);
|
51
197
|
})
|
198
|
+
// config
|
199
|
+
.define_singleton_method(
|
200
|
+
"show_config",
|
201
|
+
*[] {
|
202
|
+
return torch::show_config();
|
203
|
+
})
|
204
|
+
.define_singleton_method(
|
205
|
+
"parallel_info",
|
206
|
+
*[] {
|
207
|
+
return torch::get_parallel_info();
|
208
|
+
})
|
52
209
|
// begin tensor creation
|
53
210
|
.define_singleton_method(
|
54
211
|
"_arange",
|
@@ -113,11 +270,19 @@ void Init_ext()
|
|
113
270
|
// begin operations
|
114
271
|
.define_singleton_method(
|
115
272
|
"_save",
|
116
|
-
*[](const
|
273
|
+
*[](const torch::IValue &value) {
|
117
274
|
auto v = torch::pickle_save(value);
|
118
275
|
std::string str(v.begin(), v.end());
|
119
276
|
return str;
|
120
277
|
})
|
278
|
+
.define_singleton_method(
|
279
|
+
"_load",
|
280
|
+
*[](const std::string &s) {
|
281
|
+
std::vector<char> v;
|
282
|
+
std::copy(s.begin(), s.end(), std::back_inserter(v));
|
283
|
+
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
284
|
+
return torch::pickle_load(v);
|
285
|
+
})
|
121
286
|
.define_singleton_method(
|
122
287
|
"_binary_cross_entropy_with_logits",
|
123
288
|
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
@@ -213,6 +378,21 @@ void Init_ext()
|
|
213
378
|
s << self.device();
|
214
379
|
return s.str();
|
215
380
|
})
|
381
|
+
.define_method(
|
382
|
+
"_data_str",
|
383
|
+
*[](Tensor& self) {
|
384
|
+
Tensor tensor = self;
|
385
|
+
|
386
|
+
// move to CPU to get data
|
387
|
+
if (tensor.device().type() != torch::kCPU) {
|
388
|
+
torch::Device device("cpu");
|
389
|
+
tensor = tensor.to(device);
|
390
|
+
}
|
391
|
+
|
392
|
+
auto data_ptr = (const char *) tensor.data_ptr();
|
393
|
+
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
394
|
+
})
|
395
|
+
// TODO figure out a better way to do this
|
216
396
|
.define_method(
|
217
397
|
"_flat_data",
|
218
398
|
*[](Tensor& self) {
|
@@ -227,46 +407,40 @@ void Init_ext()
|
|
227
407
|
Array a;
|
228
408
|
auto dtype = tensor.dtype();
|
229
409
|
|
410
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
411
|
+
|
230
412
|
// TODO DRY if someone knows C++
|
231
413
|
if (dtype == torch::kByte) {
|
232
|
-
uint8_t* data = tensor.data_ptr<uint8_t>();
|
233
414
|
for (int i = 0; i < tensor.numel(); i++) {
|
234
|
-
a.push(
|
415
|
+
a.push(view[i].item().to<uint8_t>());
|
235
416
|
}
|
236
417
|
} else if (dtype == torch::kChar) {
|
237
|
-
int8_t* data = tensor.data_ptr<int8_t>();
|
238
418
|
for (int i = 0; i < tensor.numel(); i++) {
|
239
|
-
a.push(to_ruby<int>(
|
419
|
+
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
240
420
|
}
|
241
421
|
} else if (dtype == torch::kShort) {
|
242
|
-
int16_t* data = tensor.data_ptr<int16_t>();
|
243
422
|
for (int i = 0; i < tensor.numel(); i++) {
|
244
|
-
a.push(
|
423
|
+
a.push(view[i].item().to<int16_t>());
|
245
424
|
}
|
246
425
|
} else if (dtype == torch::kInt) {
|
247
|
-
int32_t* data = tensor.data_ptr<int32_t>();
|
248
426
|
for (int i = 0; i < tensor.numel(); i++) {
|
249
|
-
a.push(
|
427
|
+
a.push(view[i].item().to<int32_t>());
|
250
428
|
}
|
251
429
|
} else if (dtype == torch::kLong) {
|
252
|
-
int64_t* data = tensor.data_ptr<int64_t>();
|
253
430
|
for (int i = 0; i < tensor.numel(); i++) {
|
254
|
-
a.push(
|
431
|
+
a.push(view[i].item().to<int64_t>());
|
255
432
|
}
|
256
433
|
} else if (dtype == torch::kFloat) {
|
257
|
-
float* data = tensor.data_ptr<float>();
|
258
434
|
for (int i = 0; i < tensor.numel(); i++) {
|
259
|
-
a.push(
|
435
|
+
a.push(view[i].item().to<float>());
|
260
436
|
}
|
261
437
|
} else if (dtype == torch::kDouble) {
|
262
|
-
double* data = tensor.data_ptr<double>();
|
263
438
|
for (int i = 0; i < tensor.numel(); i++) {
|
264
|
-
a.push(
|
439
|
+
a.push(view[i].item().to<double>());
|
265
440
|
}
|
266
441
|
} else if (dtype == torch::kBool) {
|
267
|
-
bool* data = tensor.data_ptr<bool>();
|
268
442
|
for (int i = 0; i < tensor.numel(); i++) {
|
269
|
-
a.push(
|
443
|
+
a.push(view[i].item().to<bool>() ? True : False);
|
270
444
|
}
|
271
445
|
} else {
|
272
446
|
throw std::runtime_error("Unsupported type");
|
@@ -288,6 +462,7 @@ void Init_ext()
|
|
288
462
|
});
|
289
463
|
|
290
464
|
Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
|
465
|
+
.add_handler<torch::Error>(handle_error)
|
291
466
|
.define_constructor(Constructor<torch::TensorOptions>())
|
292
467
|
.define_method(
|
293
468
|
"dtype",
|
@@ -311,13 +486,8 @@ void Init_ext()
|
|
311
486
|
.define_method(
|
312
487
|
"device",
|
313
488
|
*[](torch::TensorOptions& self, std::string device) {
|
314
|
-
|
315
|
-
|
316
|
-
torch::Device d(device);
|
317
|
-
return self.device(d);
|
318
|
-
} catch (const c10::Error& error) {
|
319
|
-
throw std::runtime_error(error.what_without_backtrace());
|
320
|
-
}
|
489
|
+
torch::Device d(device);
|
490
|
+
return self.device(d);
|
321
491
|
})
|
322
492
|
.define_method(
|
323
493
|
"requires_grad",
|
@@ -398,6 +568,7 @@ void Init_ext()
|
|
398
568
|
});
|
399
569
|
|
400
570
|
Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
|
571
|
+
.add_handler<torch::Error>(handle_error)
|
401
572
|
.define_method(
|
402
573
|
"grad",
|
403
574
|
*[](Parameter& self) {
|
@@ -407,6 +578,7 @@ void Init_ext()
|
|
407
578
|
|
408
579
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
409
580
|
.define_constructor(Constructor<torch::Device, std::string>())
|
581
|
+
.add_handler<torch::Error>(handle_error)
|
410
582
|
.define_method("index", &torch::Device::index)
|
411
583
|
.define_method("index?", &torch::Device::has_index)
|
412
584
|
.define_method(
|
@@ -418,6 +590,7 @@ void Init_ext()
|
|
418
590
|
});
|
419
591
|
|
420
592
|
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
|
593
|
+
.add_handler<torch::Error>(handle_error)
|
421
594
|
.define_singleton_method("available?", &torch::cuda::is_available)
|
422
595
|
.define_singleton_method("device_count", &torch::cuda::device_count);
|
423
596
|
}
|