torch-rb 0.2.0 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +31 -0
- data/README.md +36 -6
- data/ext/torch/ext.cpp +197 -24
- data/ext/torch/extconf.rb +34 -21
- data/lib/torch.rb +102 -6
- data/lib/torch/hub.rb +52 -0
- data/lib/torch/inspector.rb +3 -3
- data/lib/torch/nn/batch_norm.rb +5 -0
- data/lib/torch/nn/conv2d.rb +8 -1
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/max_poolnd.rb +2 -1
- data/lib/torch/nn/module.rb +45 -8
- data/lib/torch/tensor.rb +48 -26
- data/lib/torch/utils/data/data_loader.rb +32 -4
- data/lib/torch/utils/data/dataset.rb +8 -0
- data/lib/torch/utils/data/tensor_dataset.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +6 -13
- data/ext/torch/nn_functions.cpp +0 -560
- data/ext/torch/nn_functions.hpp +0 -6
- data/ext/torch/tensor_functions.cpp +0 -2085
- data/ext/torch/tensor_functions.hpp +0 -6
- data/ext/torch/torch_functions.cpp +0 -3175
- data/ext/torch/torch_functions.hpp +0 -6
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/random.rb +0 -10
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8f6ab78fb5cff27d0d60ddb9c08fb2f526bd60e241dd1011554b21716bdd2f43
|
4
|
+
data.tar.gz: 5568f53d8d5d688e3f29fb55ddbe9457e0b933dc69f59b422035c0cee249e396
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9c5dcfbf35382b37678662690677b2b90d0d544e8802703cf83dba6e10483a4df487f9687e4e898c9cc449c568f4e02f9d831daa982b5b3135af6f9ce176ec88
|
7
|
+
data.tar.gz: 34f142d874606e140661ae992a9f8cd4779f95c93c11d9a89a1864dd0bd53c5480c30d9aec5897f1955e2450bd0a3bc56ed0868e3b54d82ff4cdba40af379840
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,34 @@
|
|
1
|
+
## 0.2.5 (2020-06-07)
|
2
|
+
|
3
|
+
- Added `download_url_to_file` and `load_state_dict_from_url` to `Torch::Hub`
|
4
|
+
- Improved error messages
|
5
|
+
- Fixed tensor slicing
|
6
|
+
|
7
|
+
## 0.2.4 (2020-04-29)
|
8
|
+
|
9
|
+
- Added `to_i` and `to_f` to tensors
|
10
|
+
- Added `shuffle` option to data loader
|
11
|
+
- Fixed `modules` and `named_modules` for nested modules
|
12
|
+
|
13
|
+
## 0.2.3 (2020-04-28)
|
14
|
+
|
15
|
+
- Added `show_config` and `parallel_info` methods
|
16
|
+
- Added `initial_seed` and `seed` methods to `Random`
|
17
|
+
- Improved data loader
|
18
|
+
- Build with MKL-DNN and NNPACK when available
|
19
|
+
- Fixed `inspect` for modules
|
20
|
+
|
21
|
+
## 0.2.2 (2020-04-27)
|
22
|
+
|
23
|
+
- Added support for saving tensor lists
|
24
|
+
- Added `ndim` and `ndimension` methods to tensors
|
25
|
+
|
26
|
+
## 0.2.1 (2020-04-26)
|
27
|
+
|
28
|
+
- Added support for saving and loading models
|
29
|
+
- Improved error messages
|
30
|
+
- Reduced gem size
|
31
|
+
|
1
32
|
## 0.2.0 (2020-04-22)
|
2
33
|
|
3
34
|
- No longer experimental
|
data/README.md
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
+
For computer vision tasks, also check out [TorchVision](https://github.com/ankane/torchvision)
|
6
|
+
|
5
7
|
[![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
|
6
8
|
|
7
9
|
## Installation
|
@@ -22,6 +24,18 @@ It can take a few minutes to compile the extension.
|
|
22
24
|
|
23
25
|
## Getting Started
|
24
26
|
|
27
|
+
Deep learning is significantly faster with a GPU. If you don’t have an NVIDIA GPU, we recommend using a cloud service. [Paperspace](https://www.paperspace.com/) has a great free plan.
|
28
|
+
|
29
|
+
We’ve put together a [Docker image](https://github.com/ankane/ml-stack) to make it easy to get started. On Paperspace, create a notebook with a custom container. Set the container name to:
|
30
|
+
|
31
|
+
```text
|
32
|
+
ankane/ml-stack:torch-gpu
|
33
|
+
```
|
34
|
+
|
35
|
+
And leave the other fields in that section blank. Once the notebook is running, you can run the [MNIST example](https://github.com/ankane/ml-stack/blob/master/torch-gpu/MNIST.ipynb).
|
36
|
+
|
37
|
+
## API
|
38
|
+
|
25
39
|
This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
|
26
40
|
|
27
41
|
- Methods that perform in-place modifications end with `!` instead of `_` (`add!` instead of `add_`)
|
@@ -192,7 +206,7 @@ end
|
|
192
206
|
Define a neural network
|
193
207
|
|
194
208
|
```ruby
|
195
|
-
class
|
209
|
+
class MyNet < Torch::NN::Module
|
196
210
|
def initialize
|
197
211
|
super
|
198
212
|
@conv1 = Torch::NN::Conv2d.new(1, 6, 3)
|
@@ -226,7 +240,7 @@ end
|
|
226
240
|
Create an instance of it
|
227
241
|
|
228
242
|
```ruby
|
229
|
-
net =
|
243
|
+
net = MyNet.new
|
230
244
|
input = Torch.randn(1, 1, 32, 32)
|
231
245
|
net.call(input)
|
232
246
|
```
|
@@ -283,6 +297,22 @@ loss.backward
|
|
283
297
|
optimizer.step
|
284
298
|
```
|
285
299
|
|
300
|
+
### Saving and Loading Models
|
301
|
+
|
302
|
+
Save a model
|
303
|
+
|
304
|
+
```ruby
|
305
|
+
Torch.save(net.state_dict, "net.pth")
|
306
|
+
```
|
307
|
+
|
308
|
+
Load a model
|
309
|
+
|
310
|
+
```ruby
|
311
|
+
net = MyNet.new
|
312
|
+
net.load_state_dict(Torch.load("net.pth"))
|
313
|
+
net.eval
|
314
|
+
```
|
315
|
+
|
286
316
|
### Tensor Creation
|
287
317
|
|
288
318
|
Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
|
@@ -397,9 +427,7 @@ Then install the gem (no need for `bundle config`).
|
|
397
427
|
|
398
428
|
### Linux
|
399
429
|
|
400
|
-
Deep learning is significantly faster on
|
401
|
-
|
402
|
-
Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
430
|
+
Deep learning is significantly faster on a GPU. Install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
403
431
|
|
404
432
|
Check if CUDA is available
|
405
433
|
|
@@ -410,7 +438,7 @@ Torch::CUDA.available?
|
|
410
438
|
Move a neural network to a GPU
|
411
439
|
|
412
440
|
```ruby
|
413
|
-
net.
|
441
|
+
net.cuda
|
414
442
|
```
|
415
443
|
|
416
444
|
## rbenv
|
@@ -445,6 +473,8 @@ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
|
|
445
473
|
bundle exec rake test
|
446
474
|
```
|
447
475
|
|
476
|
+
You can use [this script](https://gist.github.com/ankane/9b2b5fcbd66d6e4ccfeb9d73e529abe7) to test on GPUs with the AWS Deep Learning Base AMI (Ubuntu 18.04).
|
477
|
+
|
448
478
|
Here are some good resources for contributors:
|
449
479
|
|
450
480
|
- [PyTorch API](https://pytorch.org/docs/stable/torch.html)
|
data/ext/torch/ext.cpp
CHANGED
@@ -5,6 +5,7 @@
|
|
5
5
|
#include <rice/Array.hpp>
|
6
6
|
#include <rice/Class.hpp>
|
7
7
|
#include <rice/Constructor.hpp>
|
8
|
+
#include <rice/Hash.hpp>
|
8
9
|
|
9
10
|
#include "templates.hpp"
|
10
11
|
|
@@ -22,18 +23,163 @@ class Parameter: public torch::autograd::Variable {
|
|
22
23
|
Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
|
23
24
|
};
|
24
25
|
|
26
|
+
void handle_error(torch::Error const & ex)
|
27
|
+
{
|
28
|
+
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
|
+
}
|
30
|
+
|
25
31
|
extern "C"
|
26
32
|
void Init_ext()
|
27
33
|
{
|
28
34
|
Module rb_mTorch = define_module("Torch");
|
35
|
+
rb_mTorch.add_handler<torch::Error>(handle_error);
|
29
36
|
add_torch_functions(rb_mTorch);
|
30
37
|
|
31
38
|
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
|
39
|
+
rb_cTensor.add_handler<torch::Error>(handle_error);
|
32
40
|
add_tensor_functions(rb_cTensor);
|
33
41
|
|
34
42
|
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
43
|
+
rb_mNN.add_handler<torch::Error>(handle_error);
|
35
44
|
add_nn_functions(rb_mNN);
|
36
45
|
|
46
|
+
Module rb_mRandom = define_module_under(rb_mTorch, "Random")
|
47
|
+
.add_handler<torch::Error>(handle_error)
|
48
|
+
.define_singleton_method(
|
49
|
+
"initial_seed",
|
50
|
+
*[]() {
|
51
|
+
return at::detail::getDefaultCPUGenerator()->current_seed();
|
52
|
+
})
|
53
|
+
.define_singleton_method(
|
54
|
+
"seed",
|
55
|
+
*[]() {
|
56
|
+
// TODO set for CUDA when available
|
57
|
+
return at::detail::getDefaultCPUGenerator()->seed();
|
58
|
+
});
|
59
|
+
|
60
|
+
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
61
|
+
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
62
|
+
.add_handler<torch::Error>(handle_error)
|
63
|
+
.define_constructor(Constructor<torch::IValue>())
|
64
|
+
.define_method("bool?", &torch::IValue::isBool)
|
65
|
+
.define_method("bool_list?", &torch::IValue::isBoolList)
|
66
|
+
.define_method("capsule?", &torch::IValue::isCapsule)
|
67
|
+
.define_method("custom_class?", &torch::IValue::isCustomClass)
|
68
|
+
.define_method("device?", &torch::IValue::isDevice)
|
69
|
+
.define_method("double?", &torch::IValue::isDouble)
|
70
|
+
.define_method("double_list?", &torch::IValue::isDoubleList)
|
71
|
+
.define_method("future?", &torch::IValue::isFuture)
|
72
|
+
// .define_method("generator?", &torch::IValue::isGenerator)
|
73
|
+
.define_method("generic_dict?", &torch::IValue::isGenericDict)
|
74
|
+
.define_method("list?", &torch::IValue::isList)
|
75
|
+
.define_method("int?", &torch::IValue::isInt)
|
76
|
+
.define_method("int_list?", &torch::IValue::isIntList)
|
77
|
+
.define_method("module?", &torch::IValue::isModule)
|
78
|
+
.define_method("none?", &torch::IValue::isNone)
|
79
|
+
.define_method("object?", &torch::IValue::isObject)
|
80
|
+
.define_method("ptr_type?", &torch::IValue::isPtrType)
|
81
|
+
.define_method("py_object?", &torch::IValue::isPyObject)
|
82
|
+
.define_method("r_ref?", &torch::IValue::isRRef)
|
83
|
+
.define_method("scalar?", &torch::IValue::isScalar)
|
84
|
+
.define_method("string?", &torch::IValue::isString)
|
85
|
+
.define_method("tensor?", &torch::IValue::isTensor)
|
86
|
+
.define_method("tensor_list?", &torch::IValue::isTensorList)
|
87
|
+
.define_method("tuple?", &torch::IValue::isTuple)
|
88
|
+
.define_method(
|
89
|
+
"to_bool",
|
90
|
+
*[](torch::IValue& self) {
|
91
|
+
return self.toBool();
|
92
|
+
})
|
93
|
+
.define_method(
|
94
|
+
"to_double",
|
95
|
+
*[](torch::IValue& self) {
|
96
|
+
return self.toDouble();
|
97
|
+
})
|
98
|
+
.define_method(
|
99
|
+
"to_int",
|
100
|
+
*[](torch::IValue& self) {
|
101
|
+
return self.toInt();
|
102
|
+
})
|
103
|
+
.define_method(
|
104
|
+
"to_list",
|
105
|
+
*[](torch::IValue& self) {
|
106
|
+
auto list = self.toListRef();
|
107
|
+
Array obj;
|
108
|
+
for (auto& elem : list) {
|
109
|
+
obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
|
110
|
+
}
|
111
|
+
return obj;
|
112
|
+
})
|
113
|
+
.define_method(
|
114
|
+
"to_string_ref",
|
115
|
+
*[](torch::IValue& self) {
|
116
|
+
return self.toStringRef();
|
117
|
+
})
|
118
|
+
.define_method(
|
119
|
+
"to_tensor",
|
120
|
+
*[](torch::IValue& self) {
|
121
|
+
return self.toTensor();
|
122
|
+
})
|
123
|
+
.define_method(
|
124
|
+
"to_generic_dict",
|
125
|
+
*[](torch::IValue& self) {
|
126
|
+
auto dict = self.toGenericDict();
|
127
|
+
Hash obj;
|
128
|
+
for (auto& pair : dict) {
|
129
|
+
obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
|
130
|
+
}
|
131
|
+
return obj;
|
132
|
+
})
|
133
|
+
.define_singleton_method(
|
134
|
+
"from_tensor",
|
135
|
+
*[](torch::Tensor& v) {
|
136
|
+
return torch::IValue(v);
|
137
|
+
})
|
138
|
+
// TODO create specialized list types?
|
139
|
+
.define_singleton_method(
|
140
|
+
"from_list",
|
141
|
+
*[](Array obj) {
|
142
|
+
c10::impl::GenericList list(c10::AnyType::get());
|
143
|
+
for (auto entry : obj) {
|
144
|
+
list.push_back(from_ruby<torch::IValue>(entry));
|
145
|
+
}
|
146
|
+
return torch::IValue(list);
|
147
|
+
})
|
148
|
+
.define_singleton_method(
|
149
|
+
"from_string",
|
150
|
+
*[](String v) {
|
151
|
+
return torch::IValue(v.str());
|
152
|
+
})
|
153
|
+
.define_singleton_method(
|
154
|
+
"from_int",
|
155
|
+
*[](int64_t v) {
|
156
|
+
return torch::IValue(v);
|
157
|
+
})
|
158
|
+
.define_singleton_method(
|
159
|
+
"from_double",
|
160
|
+
*[](double v) {
|
161
|
+
return torch::IValue(v);
|
162
|
+
})
|
163
|
+
.define_singleton_method(
|
164
|
+
"from_bool",
|
165
|
+
*[](bool v) {
|
166
|
+
return torch::IValue(v);
|
167
|
+
})
|
168
|
+
// see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
|
169
|
+
// createGenericDict and toIValue
|
170
|
+
.define_singleton_method(
|
171
|
+
"from_dict",
|
172
|
+
*[](Hash obj) {
|
173
|
+
auto key_type = c10::AnyType::get();
|
174
|
+
auto value_type = c10::AnyType::get();
|
175
|
+
c10::impl::GenericDict elems(key_type, value_type);
|
176
|
+
elems.reserve(obj.size());
|
177
|
+
for (auto entry : obj) {
|
178
|
+
elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
|
179
|
+
}
|
180
|
+
return torch::IValue(std::move(elems));
|
181
|
+
});
|
182
|
+
|
37
183
|
rb_mTorch.define_singleton_method(
|
38
184
|
"grad_enabled?",
|
39
185
|
*[]() {
|
@@ -49,6 +195,17 @@ void Init_ext()
|
|
49
195
|
*[](uint64_t seed) {
|
50
196
|
return torch::manual_seed(seed);
|
51
197
|
})
|
198
|
+
// config
|
199
|
+
.define_singleton_method(
|
200
|
+
"show_config",
|
201
|
+
*[] {
|
202
|
+
return torch::show_config();
|
203
|
+
})
|
204
|
+
.define_singleton_method(
|
205
|
+
"parallel_info",
|
206
|
+
*[] {
|
207
|
+
return torch::get_parallel_info();
|
208
|
+
})
|
52
209
|
// begin tensor creation
|
53
210
|
.define_singleton_method(
|
54
211
|
"_arange",
|
@@ -113,11 +270,19 @@ void Init_ext()
|
|
113
270
|
// begin operations
|
114
271
|
.define_singleton_method(
|
115
272
|
"_save",
|
116
|
-
*[](const Tensor &value) {
|
273
|
+
*[](const torch::IValue &value) {
|
117
274
|
auto v = torch::pickle_save(value);
|
118
275
|
std::string str(v.begin(), v.end());
|
119
276
|
return str;
|
120
277
|
})
|
278
|
+
.define_singleton_method(
|
279
|
+
"_load",
|
280
|
+
*[](const std::string &s) {
|
281
|
+
std::vector<char> v;
|
282
|
+
std::copy(s.begin(), s.end(), std::back_inserter(v));
|
283
|
+
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
284
|
+
return torch::pickle_load(v);
|
285
|
+
})
|
121
286
|
.define_singleton_method(
|
122
287
|
"_binary_cross_entropy_with_logits",
|
123
288
|
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
@@ -213,6 +378,21 @@ void Init_ext()
|
|
213
378
|
s << self.device();
|
214
379
|
return s.str();
|
215
380
|
})
|
381
|
+
.define_method(
|
382
|
+
"_data_str",
|
383
|
+
*[](Tensor& self) {
|
384
|
+
Tensor tensor = self;
|
385
|
+
|
386
|
+
// move to CPU to get data
|
387
|
+
if (tensor.device().type() != torch::kCPU) {
|
388
|
+
torch::Device device("cpu");
|
389
|
+
tensor = tensor.to(device);
|
390
|
+
}
|
391
|
+
|
392
|
+
auto data_ptr = (const char *) tensor.data_ptr();
|
393
|
+
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
394
|
+
})
|
395
|
+
// TODO figure out a better way to do this
|
216
396
|
.define_method(
|
217
397
|
"_flat_data",
|
218
398
|
*[](Tensor& self) {
|
@@ -227,46 +407,40 @@ void Init_ext()
|
|
227
407
|
Array a;
|
228
408
|
auto dtype = tensor.dtype();
|
229
409
|
|
410
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
411
|
+
|
230
412
|
// TODO DRY if someone knows C++
|
231
413
|
if (dtype == torch::kByte) {
|
232
|
-
uint8_t* data = tensor.data_ptr<uint8_t>();
|
233
414
|
for (int i = 0; i < tensor.numel(); i++) {
|
234
|
-
a.push(data[i]);
|
415
|
+
a.push(view[i].item().to<uint8_t>());
|
235
416
|
}
|
236
417
|
} else if (dtype == torch::kChar) {
|
237
|
-
int8_t* data = tensor.data_ptr<int8_t>();
|
238
418
|
for (int i = 0; i < tensor.numel(); i++) {
|
239
|
-
a.push(to_ruby<int>(data[i]));
|
419
|
+
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
240
420
|
}
|
241
421
|
} else if (dtype == torch::kShort) {
|
242
|
-
int16_t* data = tensor.data_ptr<int16_t>();
|
243
422
|
for (int i = 0; i < tensor.numel(); i++) {
|
244
|
-
a.push(data[i]);
|
423
|
+
a.push(view[i].item().to<int16_t>());
|
245
424
|
}
|
246
425
|
} else if (dtype == torch::kInt) {
|
247
|
-
int32_t* data = tensor.data_ptr<int32_t>();
|
248
426
|
for (int i = 0; i < tensor.numel(); i++) {
|
249
|
-
a.push(data[i]);
|
427
|
+
a.push(view[i].item().to<int32_t>());
|
250
428
|
}
|
251
429
|
} else if (dtype == torch::kLong) {
|
252
|
-
int64_t* data = tensor.data_ptr<int64_t>();
|
253
430
|
for (int i = 0; i < tensor.numel(); i++) {
|
254
|
-
a.push(data[i]);
|
431
|
+
a.push(view[i].item().to<int64_t>());
|
255
432
|
}
|
256
433
|
} else if (dtype == torch::kFloat) {
|
257
|
-
float* data = tensor.data_ptr<float>();
|
258
434
|
for (int i = 0; i < tensor.numel(); i++) {
|
259
|
-
a.push(data[i]);
|
435
|
+
a.push(view[i].item().to<float>());
|
260
436
|
}
|
261
437
|
} else if (dtype == torch::kDouble) {
|
262
|
-
double* data = tensor.data_ptr<double>();
|
263
438
|
for (int i = 0; i < tensor.numel(); i++) {
|
264
|
-
a.push(data[i]);
|
439
|
+
a.push(view[i].item().to<double>());
|
265
440
|
}
|
266
441
|
} else if (dtype == torch::kBool) {
|
267
|
-
bool* data = tensor.data_ptr<bool>();
|
268
442
|
for (int i = 0; i < tensor.numel(); i++) {
|
269
|
-
a.push(data[i] ? True : False);
|
443
|
+
a.push(view[i].item().to<bool>() ? True : False);
|
270
444
|
}
|
271
445
|
} else {
|
272
446
|
throw std::runtime_error("Unsupported type");
|
@@ -288,6 +462,7 @@ void Init_ext()
|
|
288
462
|
});
|
289
463
|
|
290
464
|
Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
|
465
|
+
.add_handler<torch::Error>(handle_error)
|
291
466
|
.define_constructor(Constructor<torch::TensorOptions>())
|
292
467
|
.define_method(
|
293
468
|
"dtype",
|
@@ -311,13 +486,8 @@ void Init_ext()
|
|
311
486
|
.define_method(
|
312
487
|
"device",
|
313
488
|
*[](torch::TensorOptions& self, std::string device) {
|
314
|
-
|
315
|
-
|
316
|
-
torch::Device d(device);
|
317
|
-
return self.device(d);
|
318
|
-
} catch (const c10::Error& error) {
|
319
|
-
throw std::runtime_error(error.what_without_backtrace());
|
320
|
-
}
|
489
|
+
torch::Device d(device);
|
490
|
+
return self.device(d);
|
321
491
|
})
|
322
492
|
.define_method(
|
323
493
|
"requires_grad",
|
@@ -398,6 +568,7 @@ void Init_ext()
|
|
398
568
|
});
|
399
569
|
|
400
570
|
Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
|
571
|
+
.add_handler<torch::Error>(handle_error)
|
401
572
|
.define_method(
|
402
573
|
"grad",
|
403
574
|
*[](Parameter& self) {
|
@@ -407,6 +578,7 @@ void Init_ext()
|
|
407
578
|
|
408
579
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
409
580
|
.define_constructor(Constructor<torch::Device, std::string>())
|
581
|
+
.add_handler<torch::Error>(handle_error)
|
410
582
|
.define_method("index", &torch::Device::index)
|
411
583
|
.define_method("index?", &torch::Device::has_index)
|
412
584
|
.define_method(
|
@@ -418,6 +590,7 @@ void Init_ext()
|
|
418
590
|
});
|
419
591
|
|
420
592
|
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
|
593
|
+
.add_handler<torch::Error>(handle_error)
|
421
594
|
.define_singleton_method("available?", &torch::cuda::is_available)
|
422
595
|
.define_singleton_method("device_count", &torch::cuda::device_count);
|
423
596
|
}
|