torch-rb 0.5.0 → 0.7.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +26 -0
- data/README.md +13 -4
- data/codegen/generate_functions.rb +13 -14
- data/codegen/native_functions.yaml +2355 -1396
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +28 -0
- data/ext/torch/ext.cpp +26 -613
- data/ext/torch/extconf.rb +1 -4
- data/ext/torch/ivalue.cpp +132 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +3 -3
- data/ext/torch/ruby_arg_parser.h +37 -16
- data/ext/torch/templates.h +110 -133
- data/ext/torch/tensor.cpp +320 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +95 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -2
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +19 -17
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +107 -21
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +9 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +14 -91
data/ext/torch/templates.h
CHANGED
@@ -4,8 +4,7 @@
 #undef isfinite
 #endif
 
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>
 
 using namespace Rice;
 
@@ -23,6 +22,9 @@ using torch::ArrayRef;
 using torch::TensorList;
 using torch::Storage;
 
+using torch::nn::init::FanModeType;
+using torch::nn::init::NonlinearityType;
+
 #define HANDLE_TH_ERRORS \
   try {
 
@@ -38,37 +40,42 @@ using torch::Storage;
 #define RETURN_NIL \
   return Qnil;
 
+class OptionalTensor {
+  torch::Tensor value;
+  public:
+    OptionalTensor(Object o) {
+      if (o.is_nil()) {
+        value = {};
+      } else {
+        value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
+      }
+    }
+    OptionalTensor(torch::Tensor o) {
+      value = o;
+    }
+    operator torch::Tensor() const {
+      return value;
+    }
+};
 
-inline
-std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
+namespace Rice::detail
 {
+  template<>
+  struct Type<FanModeType>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
+  template<>
+  class From_Ruby<FanModeType>
+  {
   public:
-    operator torch::nn::init::FanModeType() {
+    FanModeType convert(VALUE x)
+    {
+      auto s = String(x).str();
       if (s == "fan_in") {
         return torch::kFanIn;
       } else if (s == "fan_out") {
@@ -77,22 +84,24 @@ class FanModeType {
         throw std::runtime_error("Unsupported nonlinearity type: " + s);
       }
     }
-};
-
-template<>
+  };
+
+  template<>
+  struct Type<NonlinearityType>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
+  template<>
+  class From_Ruby<NonlinearityType>
+  {
   public:
-    operator torch::nn::init::NonlinearityType() {
+    NonlinearityType convert(VALUE x)
+    {
+      auto s = String(x).str();
       if (s == "linear") {
         return torch::kLinear;
       } else if (s == "conv1d") {
@@ -119,102 +128,70 @@ class NonlinearityType {
         throw std::runtime_error("Unsupported nonlinearity type: " + s);
       }
     }
-};
+  };
+
+  template<>
+  struct Type<OptionalTensor>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-template<>
+  template<>
+  class From_Ruby<OptionalTensor>
+  {
+  public:
+    OptionalTensor convert(VALUE x)
+    {
+      return OptionalTensor(x);
+    }
+  };
+
+  template<>
+  struct Type<Scalar>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
+  template<>
+  class From_Ruby<Scalar>
+  {
   public:
+    Scalar convert(VALUE x)
+    {
+      if (FIXNUM_P(x)) {
+        return torch::Scalar(From_Ruby<int64_t>().convert(x));
       } else {
+        return torch::Scalar(From_Ruby<double>().convert(x));
       }
     }
+  };
+
+  template<typename T>
+  struct Type<torch::optional<T>>
+  {
+    static bool verify()
+    {
+      return true;
     }
-};
-
-template<>
-inline
-Scalar from_ruby<Scalar>(Object x)
-{
-  if (x.rb_type() == T_FIXNUM) {
-    return torch::Scalar(from_ruby<int64_t>(x));
-  } else {
-    return torch::Scalar(from_ruby<double>(x));
-  }
-}
-
-template<>
-inline
-OptionalTensor from_ruby<OptionalTensor>(Object x)
-{
-  return OptionalTensor(x);
-}
-
-template<>
-inline
-torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<int64_t>{from_ruby<int64_t>(x)};
-  }
-}
+  };
 
-torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<bool>{from_ruby<bool>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<Scalar>{from_ruby<Scalar>(x)};
-  }
+  template<typename T>
+  class From_Ruby<torch::optional<T>>
+  {
+  public:
+    torch::optional<T> convert(VALUE x)
+    {
+      if (NIL_P(x)) {
+        return torch::nullopt;
+      } else {
+        return torch::optional<T>{From_Ruby<T>().convert(x)};
+      }
+    }
+  };
 }
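The change above is the Rice 4 migration: in 0.5.0 each conversion was a free-function specialization of from_ruby<T>(Object), while in 0.7.0 conversions are class specializations of Rice::detail::From_Ruby<T> with a convert(VALUE) member, plus a Type<T> trait whose verify() marks the type as convertible. The stand-alone sketch below mirrors that shape with stand-in names (RubyValue, a local From_Ruby/Type, a FanMode enum); it illustrates the pattern under those assumptions and is not code from the gem or the real Rice API.

// Minimal, dependency-free sketch of the converter-object pattern the new templates.h uses.
// RubyValue, Type, From_Ruby, and FanMode are illustrative stand-ins, not Rice's actual API.
#include <iostream>
#include <stdexcept>
#include <string>

using RubyValue = std::string;  // placeholder for a Ruby VALUE

// Rice 4 style: conversions live in class template specializations with a convert()
// member, and a Type<T> trait whose verify() confirms the conversion is registered.
template<typename T> struct Type;
template<typename T> class From_Ruby;

enum class FanMode { FanIn, FanOut };

template<>
struct Type<FanMode>
{
  static bool verify() { return true; }
};

template<>
class From_Ruby<FanMode>
{
public:
  FanMode convert(const RubyValue& x)
  {
    if (x == "fan_in") return FanMode::FanIn;
    if (x == "fan_out") return FanMode::FanOut;
    throw std::runtime_error("Unsupported fan mode: " + x);
  }
};

int main()
{
  // Rice 3 style was a free-function specialization: from_ruby<FanMode>(obj).
  // Rice 4 style instantiates a converter object instead:
  FanMode mode = From_Ruby<FanMode>().convert("fan_out");
  std::cout << (mode == FanMode::FanOut ? "fan_out" : "fan_in") << std::endl;
  return 0;
}

Here From_Ruby<FanMode>().convert("fan_out") plays the role that from_ruby<FanModeType>(obj) played in 0.5.0: the conversion is found by instantiating a converter class rather than by overload resolution on a free-function template.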
data/ext/torch/tensor.cpp
ADDED
@@ -0,0 +1,320 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "tensor_functions.h"
+#include "ruby_arg_parser.h"
+#include "templates.h"
+#include "utils.h"
+
+using namespace Rice;
+using torch::indexing::TensorIndex;
+
+namespace Rice::detail
+{
+  template<typename T>
+  struct Type<c10::complex<T>>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
+
+  template<typename T>
+  class To_Ruby<c10::complex<T>>
+  {
+  public:
+    VALUE convert(c10::complex<T> const& x)
+    {
+      return rb_dbl_complex_new(x.real(), x.imag());
+    }
+  };
+}
+
+template<typename T>
+Array flat_data(Tensor& tensor) {
+  Tensor view = tensor.reshape({tensor.numel()});
+
+  Array a;
+  for (int i = 0; i < tensor.numel(); i++) {
+    a.push(view[i].item().to<T>());
+  }
+  return a;
+}
+
+Class rb_cTensor;
+
+std::vector<TensorIndex> index_vector(Array a) {
+  Object obj;
+
+  std::vector<TensorIndex> indices;
+  indices.reserve(a.size());
+
+  for (long i = 0; i < a.size(); i++) {
+    obj = a[i];
+
+    if (obj.is_instance_of(rb_cInteger)) {
+      indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
+    } else if (obj.is_instance_of(rb_cRange)) {
+      torch::optional<int64_t> start_index = torch::nullopt;
+      torch::optional<int64_t> stop_index = torch::nullopt;
+
+      Object begin = obj.call("begin");
+      if (!begin.is_nil()) {
+        start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
+      }
+
+      Object end = obj.call("end");
+      if (!end.is_nil()) {
+        stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
+      }
+
+      Object exclude_end = obj.call("exclude_end?");
+      if (stop_index.has_value() && !exclude_end) {
+        if (stop_index.value() == -1) {
+          stop_index = torch::nullopt;
+        } else {
+          stop_index = stop_index.value() + 1;
+        }
+      }
+
+      indices.push_back(torch::indexing::Slice(start_index, stop_index));
+    } else if (obj.is_instance_of(rb_cTensor)) {
+      indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
+    } else if (obj.is_nil()) {
+      indices.push_back(torch::indexing::None);
+    } else if (obj == True || obj == False) {
+      indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
+    } else {
+      throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+    }
+  }
+  return indices;
+}
+
+// hack (removes inputs argument)
+// https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
+// TODO add support for inputs argument
+// _backward
+static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
+{
+  HANDLE_TH_ERRORS
+  Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
+  static RubyArgParser parser({
+    "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
+  });
+  ParsedArgs<4> parsed_args;
+  auto _r = parser.parse(self_, argc, argv, parsed_args);
+  // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+  auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
+    // in future, release GVL
+    self._backward(inputs, gradient, retain_graph, create_graph);
+  };
+  dispatch__backward(self, {}, _r.optionalTensor(0), _r.toBoolOptional(1), _r.toBool(2));
+  RETURN_NIL
+  END_HANDLE_TH_ERRORS
+}
+
+void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
+  rb_cTensor = c;
+  rb_cTensor.add_handler<torch::Error>(handle_error);
+  add_tensor_functions(rb_cTensor);
+  THPVariableClass = rb_cTensor.value();
+
+  rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
+
+  rb_cTensor
+    .define_method("cuda?", &torch::Tensor::is_cuda)
+    .define_method("sparse?", &torch::Tensor::is_sparse)
+    .define_method("quantized?", &torch::Tensor::is_quantized)
+    .define_method("dim", &torch::Tensor::dim)
+    .define_method("numel", &torch::Tensor::numel)
+    .define_method("element_size", &torch::Tensor::element_size)
+    .define_method("requires_grad", &torch::Tensor::requires_grad)
+    .define_method(
+      "_size",
+      [](Tensor& self, int64_t dim) {
+        return self.size(dim);
+      })
+    .define_method(
+      "_stride",
+      [](Tensor& self, int64_t dim) {
+        return self.stride(dim);
+      })
+    // in C++ for performance
+    .define_method(
+      "shape",
+      [](Tensor& self) {
+        Array a;
+        for (auto &size : self.sizes()) {
+          a.push(size);
+        }
+        return a;
+      })
+    .define_method(
+      "_strides",
+      [](Tensor& self) {
+        Array a;
+        for (auto &stride : self.strides()) {
+          a.push(stride);
+        }
+        return a;
+      })
+    .define_method(
+      "_index",
+      [](Tensor& self, Array indices) {
+        auto vec = index_vector(indices);
+        return self.index(vec);
+      })
+    .define_method(
+      "_index_put_custom",
+      [](Tensor& self, Array indices, torch::Tensor& value) {
+        auto vec = index_vector(indices);
+        return self.index_put_(vec, value);
+      })
+    .define_method(
+      "contiguous?",
+      [](Tensor& self) {
+        return self.is_contiguous();
+      })
+    .define_method(
+      "_requires_grad!",
+      [](Tensor& self, bool requires_grad) {
+        return self.set_requires_grad(requires_grad);
+      })
+    .define_method(
+      "grad",
+      [](Tensor& self) {
+        auto grad = self.grad();
+        return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
+      })
+    .define_method(
+      "grad=",
+      [](Tensor& self, torch::Tensor& grad) {
+        self.mutable_grad() = grad;
+      })
+    .define_method(
+      "_dtype",
+      [](Tensor& self) {
+        return (int) at::typeMetaToScalarType(self.dtype());
+      })
+    .define_method(
+      "_type",
+      [](Tensor& self, int dtype) {
+        return self.toType((torch::ScalarType) dtype);
+      })
+    .define_method(
+      "_layout",
+      [](Tensor& self) {
+        std::stringstream s;
+        s << self.layout();
+        return s.str();
+      })
+    .define_method(
+      "device",
+      [](Tensor& self) {
+        std::stringstream s;
+        s << self.device();
+        return s.str();
+      })
+    .define_method(
+      "_data_str",
+      [](Tensor& self) {
+        auto tensor = self;
+
+        // move to CPU to get data
+        if (tensor.device().type() != torch::kCPU) {
+          torch::Device device("cpu");
+          tensor = tensor.to(device);
+        }
+
+        if (!tensor.is_contiguous()) {
+          tensor = tensor.contiguous();
+        }
+
+        auto data_ptr = (const char *) tensor.data_ptr();
+        return std::string(data_ptr, tensor.numel() * tensor.element_size());
+      })
+    // for TorchVision
+    .define_method(
+      "_data_ptr",
+      [](Tensor& self) {
+        return reinterpret_cast<uintptr_t>(self.data_ptr());
+      })
+    // TODO figure out a better way to do this
+    .define_method(
+      "_flat_data",
+      [](Tensor& self) {
+        auto tensor = self;
+
+        // move to CPU to get data
+        if (tensor.device().type() != torch::kCPU) {
+          torch::Device device("cpu");
+          tensor = tensor.to(device);
+        }
+
+        auto dtype = tensor.dtype();
+        if (dtype == torch::kByte) {
+          return flat_data<uint8_t>(tensor);
+        } else if (dtype == torch::kChar) {
+          return flat_data<int8_t>(tensor);
+        } else if (dtype == torch::kShort) {
+          return flat_data<int16_t>(tensor);
+        } else if (dtype == torch::kInt) {
+          return flat_data<int32_t>(tensor);
+        } else if (dtype == torch::kLong) {
+          return flat_data<int64_t>(tensor);
+        } else if (dtype == torch::kFloat) {
+          return flat_data<float>(tensor);
+        } else if (dtype == torch::kDouble) {
+          return flat_data<double>(tensor);
+        } else if (dtype == torch::kBool) {
+          return flat_data<bool>(tensor);
+        } else if (dtype == torch::kComplexFloat) {
+          return flat_data<c10::complex<float>>(tensor);
+        } else if (dtype == torch::kComplexDouble) {
+          return flat_data<c10::complex<double>>(tensor);
+        } else {
+          throw std::runtime_error("Unsupported type");
+        }
+      })
+    .define_method(
+      "_to",
+      [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+        return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
+      });
+
+  rb_cTensorOptions
+    .add_handler<torch::Error>(handle_error)
+    .define_method(
+      "dtype",
+      [](torch::TensorOptions& self, int dtype) {
+        return self.dtype((torch::ScalarType) dtype);
+      })
+    .define_method(
+      "layout",
+      [](torch::TensorOptions& self, const std::string& layout) {
+        torch::Layout l;
+        if (layout == "strided") {
+          l = torch::kStrided;
+        } else if (layout == "sparse") {
+          l = torch::kSparse;
+          throw std::runtime_error("Sparse layout not supported yet");
+        } else {
+          throw std::runtime_error("Unsupported layout: " + layout);
+        }
+        return self.layout(l);
+      })
+    .define_method(
+      "device",
+      [](torch::TensorOptions& self, const std::string& device) {
+        torch::Device d(device);
+        return self.device(d);
+      })
+    .define_method(
+      "requires_grad",
+      [](torch::TensorOptions& self, bool requires_grad) {
+        return self.requires_grad(requires_grad);
+      });
+}
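The subtlest part of the new index_vector helper above is mapping Ruby Range semantics onto the half-open bounds a torch::indexing::Slice expects. The stand-alone sketch below isolates just that bound arithmetic, using std::optional in place of torch::optional and a plain struct in place of Slice; the names are illustrative and this is an approximation of the logic, not code from the gem.

// Sketch of the Range-to-slice bound mapping performed in index_vector.
// SliceBounds and range_to_slice are hypothetical stand-ins for illustration only.
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

struct SliceBounds {
  std::optional<int64_t> start;
  std::optional<int64_t> stop;  // nullopt means "through the end"
};

SliceBounds range_to_slice(std::optional<int64_t> begin, std::optional<int64_t> end, bool exclude_end) {
  SliceBounds s{begin, end};
  // Ruby ranges are inclusive unless exclude_end? is true, while slices are
  // half-open, so an inclusive end is bumped by one; an inclusive end of -1
  // means "through the last element", which a slice expresses as an open end.
  if (s.stop.has_value() && !exclude_end) {
    if (*s.stop == -1) {
      s.stop = std::nullopt;
    } else {
      s.stop = *s.stop + 1;
    }
  }
  return s;
}

int main() {
  auto a = range_to_slice(1, 3, false);   // 1..3   -> start 1, stop 4
  auto b = range_to_slice(1, 3, true);    // 1...3  -> start 1, stop 3
  auto c = range_to_slice(0, -1, false);  // 0..-1  -> start 0, open end
  std::cout << *a.start << " " << *a.stop << "\n";
  std::cout << *b.start << " " << *b.stop << "\n";
  std::cout << *c.start << " " << (c.stop ? std::to_string(*c.stop) : "end") << "\n";
  return 0;
}

Under this mapping, 1..3 becomes start 1 / stop 4, 1...3 stays start 1 / stop 3, and 0..-1 becomes start 0 with an open end, which is how an expression like tensor[0..-1] can select through the last element in the Ruby API.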