torch-rb 0.5.2 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +14 -4
- data/codegen/function.rb +2 -0
- data/codegen/generate_functions.rb +48 -10
- data/codegen/native_functions.yaml +3318 -1679
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +28 -0
- data/ext/torch/ext.cpp +34 -613
- data/ext/torch/extconf.rb +1 -4
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/ivalue.cpp +132 -0
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +3 -3
- data/ext/torch/ruby_arg_parser.h +44 -17
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +111 -133
- data/ext/torch/tensor.cpp +320 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +95 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -2
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +14 -10
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +107 -21
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +4 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +21 -91
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "special_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// Defines the Special module underneath `m`, installs the torch::Error
// translator (handle_error) on it, and registers the generated special
// functions via add_special_functions.
void init_special(Rice::Module& m) {
  auto special_module = Rice::define_module_under(m, "Special");

  // torch::Error thrown by any bound function is routed through handle_error.
  special_module.add_handler<torch::Error>(handle_error);

  add_special_functions(special_module);
}
|
data/ext/torch/templates.h
CHANGED
@@ -4,8 +4,7 @@
|
|
4
4
|
#undef isfinite
|
5
5
|
#endif
|
6
6
|
|
7
|
-
#include <rice/
|
8
|
-
#include <rice/Object.hpp>
|
7
|
+
#include <rice/rice.hpp>
|
9
8
|
|
10
9
|
using namespace Rice;
|
11
10
|
|
@@ -22,6 +21,10 @@ using torch::IntArrayRef;
|
|
22
21
|
using torch::ArrayRef;
|
23
22
|
using torch::TensorList;
|
24
23
|
using torch::Storage;
|
24
|
+
using ScalarList = ArrayRef<Scalar>;
|
25
|
+
|
26
|
+
using torch::nn::init::FanModeType;
|
27
|
+
using torch::nn::init::NonlinearityType;
|
25
28
|
|
26
29
|
#define HANDLE_TH_ERRORS \
|
27
30
|
try {
|
@@ -38,37 +41,42 @@ using torch::Storage;
|
|
38
41
|
#define RETURN_NIL \
|
39
42
|
return Qnil;
|
40
43
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
{
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
44
|
+
class OptionalTensor {
|
45
|
+
torch::Tensor value;
|
46
|
+
public:
|
47
|
+
OptionalTensor(Object o) {
|
48
|
+
if (o.is_nil()) {
|
49
|
+
value = {};
|
50
|
+
} else {
|
51
|
+
value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
|
52
|
+
}
|
53
|
+
}
|
54
|
+
OptionalTensor(torch::Tensor o) {
|
55
|
+
value = o;
|
56
|
+
}
|
57
|
+
operator torch::Tensor() const {
|
58
|
+
return value;
|
59
|
+
}
|
60
|
+
};
|
52
61
|
|
53
|
-
|
54
|
-
inline
|
55
|
-
std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
62
|
+
namespace Rice::detail
|
56
63
|
{
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
}
|
64
|
+
template<>
|
65
|
+
struct Type<FanModeType>
|
66
|
+
{
|
67
|
+
static bool verify()
|
68
|
+
{
|
69
|
+
return true;
|
70
|
+
}
|
71
|
+
};
|
64
72
|
|
65
|
-
|
66
|
-
|
73
|
+
template<>
|
74
|
+
class From_Ruby<FanModeType>
|
75
|
+
{
|
67
76
|
public:
|
68
|
-
FanModeType(
|
69
|
-
|
70
|
-
|
71
|
-
operator torch::nn::init::FanModeType() {
|
77
|
+
FanModeType convert(VALUE x)
|
78
|
+
{
|
79
|
+
auto s = String(x).str();
|
72
80
|
if (s == "fan_in") {
|
73
81
|
return torch::kFanIn;
|
74
82
|
} else if (s == "fan_out") {
|
@@ -77,22 +85,24 @@ class FanModeType {
|
|
77
85
|
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
78
86
|
}
|
79
87
|
}
|
80
|
-
};
|
81
|
-
|
82
|
-
template<>
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
+
};
|
89
|
+
|
90
|
+
template<>
|
91
|
+
struct Type<NonlinearityType>
|
92
|
+
{
|
93
|
+
static bool verify()
|
94
|
+
{
|
95
|
+
return true;
|
96
|
+
}
|
97
|
+
};
|
88
98
|
|
89
|
-
|
90
|
-
|
99
|
+
template<>
|
100
|
+
class From_Ruby<NonlinearityType>
|
101
|
+
{
|
91
102
|
public:
|
92
|
-
NonlinearityType(
|
93
|
-
|
94
|
-
|
95
|
-
operator torch::nn::init::NonlinearityType() {
|
103
|
+
NonlinearityType convert(VALUE x)
|
104
|
+
{
|
105
|
+
auto s = String(x).str();
|
96
106
|
if (s == "linear") {
|
97
107
|
return torch::kLinear;
|
98
108
|
} else if (s == "conv1d") {
|
@@ -119,102 +129,70 @@ class NonlinearityType {
|
|
119
129
|
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
120
130
|
}
|
121
131
|
}
|
122
|
-
};
|
132
|
+
};
|
133
|
+
|
134
|
+
template<>
|
135
|
+
struct Type<OptionalTensor>
|
136
|
+
{
|
137
|
+
static bool verify()
|
138
|
+
{
|
139
|
+
return true;
|
140
|
+
}
|
141
|
+
};
|
123
142
|
|
124
|
-
template<>
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
143
|
+
template<>
|
144
|
+
class From_Ruby<OptionalTensor>
|
145
|
+
{
|
146
|
+
public:
|
147
|
+
OptionalTensor convert(VALUE x)
|
148
|
+
{
|
149
|
+
return OptionalTensor(x);
|
150
|
+
}
|
151
|
+
};
|
152
|
+
|
153
|
+
template<>
|
154
|
+
struct Type<Scalar>
|
155
|
+
{
|
156
|
+
static bool verify()
|
157
|
+
{
|
158
|
+
return true;
|
159
|
+
}
|
160
|
+
};
|
130
161
|
|
131
|
-
|
132
|
-
|
162
|
+
template<>
|
163
|
+
class From_Ruby<Scalar>
|
164
|
+
{
|
133
165
|
public:
|
134
|
-
|
135
|
-
|
136
|
-
|
166
|
+
Scalar convert(VALUE x)
|
167
|
+
{
|
168
|
+
if (FIXNUM_P(x)) {
|
169
|
+
return torch::Scalar(From_Ruby<int64_t>().convert(x));
|
137
170
|
} else {
|
138
|
-
|
171
|
+
return torch::Scalar(From_Ruby<double>().convert(x));
|
139
172
|
}
|
140
173
|
}
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
174
|
+
};
|
175
|
+
|
176
|
+
template<typename T>
|
177
|
+
struct Type<torch::optional<T>>
|
178
|
+
{
|
179
|
+
static bool verify()
|
180
|
+
{
|
181
|
+
return true;
|
146
182
|
}
|
147
|
-
};
|
148
|
-
|
149
|
-
template<>
|
150
|
-
inline
|
151
|
-
Scalar from_ruby<Scalar>(Object x)
|
152
|
-
{
|
153
|
-
if (x.rb_type() == T_FIXNUM) {
|
154
|
-
return torch::Scalar(from_ruby<int64_t>(x));
|
155
|
-
} else {
|
156
|
-
return torch::Scalar(from_ruby<double>(x));
|
157
|
-
}
|
158
|
-
}
|
159
|
-
|
160
|
-
template<>
|
161
|
-
inline
|
162
|
-
OptionalTensor from_ruby<OptionalTensor>(Object x)
|
163
|
-
{
|
164
|
-
return OptionalTensor(x);
|
165
|
-
}
|
166
|
-
|
167
|
-
template<>
|
168
|
-
inline
|
169
|
-
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
170
|
-
{
|
171
|
-
if (x.is_nil()) {
|
172
|
-
return torch::nullopt;
|
173
|
-
} else {
|
174
|
-
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
175
|
-
}
|
176
|
-
}
|
177
|
-
|
178
|
-
template<>
|
179
|
-
inline
|
180
|
-
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
181
|
-
{
|
182
|
-
if (x.is_nil()) {
|
183
|
-
return torch::nullopt;
|
184
|
-
} else {
|
185
|
-
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
186
|
-
}
|
187
|
-
}
|
183
|
+
};
|
188
184
|
|
189
|
-
template
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
}
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
|
203
|
-
{
|
204
|
-
if (x.is_nil()) {
|
205
|
-
return torch::nullopt;
|
206
|
-
} else {
|
207
|
-
return torch::optional<bool>{from_ruby<bool>(x)};
|
208
|
-
}
|
209
|
-
}
|
210
|
-
|
211
|
-
template<>
|
212
|
-
inline
|
213
|
-
torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
214
|
-
{
|
215
|
-
if (x.is_nil()) {
|
216
|
-
return torch::nullopt;
|
217
|
-
} else {
|
218
|
-
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
|
219
|
-
}
|
185
|
+
template<typename T>
|
186
|
+
class From_Ruby<torch::optional<T>>
|
187
|
+
{
|
188
|
+
public:
|
189
|
+
torch::optional<T> convert(VALUE x)
|
190
|
+
{
|
191
|
+
if (NIL_P(x)) {
|
192
|
+
return torch::nullopt;
|
193
|
+
} else {
|
194
|
+
return torch::optional<T>{From_Ruby<T>().convert(x)};
|
195
|
+
}
|
196
|
+
}
|
197
|
+
};
|
220
198
|
}
|
@@ -0,0 +1,320 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "tensor_functions.h"
|
6
|
+
#include "ruby_arg_parser.h"
|
7
|
+
#include "templates.h"
|
8
|
+
#include "utils.h"
|
9
|
+
|
10
|
+
using namespace Rice;
|
11
|
+
using torch::indexing::TensorIndex;
|
12
|
+
|
13
|
+
namespace Rice::detail
|
14
|
+
{
|
15
|
+
template<typename T>
|
16
|
+
struct Type<c10::complex<T>>
|
17
|
+
{
|
18
|
+
static bool verify()
|
19
|
+
{
|
20
|
+
return true;
|
21
|
+
}
|
22
|
+
};
|
23
|
+
|
24
|
+
template<typename T>
|
25
|
+
class To_Ruby<c10::complex<T>>
|
26
|
+
{
|
27
|
+
public:
|
28
|
+
VALUE convert(c10::complex<T> const& x)
|
29
|
+
{
|
30
|
+
return rb_dbl_complex_new(x.real(), x.imag());
|
31
|
+
}
|
32
|
+
};
|
33
|
+
}
|
34
|
+
|
35
|
+
template<typename T>
|
36
|
+
Array flat_data(Tensor& tensor) {
|
37
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
38
|
+
|
39
|
+
Array a;
|
40
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
41
|
+
a.push(view[i].item().to<T>());
|
42
|
+
}
|
43
|
+
return a;
|
44
|
+
}
|
45
|
+
|
46
|
+
// Ruby Tensor class. Assigned by init_tensor; index_vector checks against it
// to recognize Tensor objects used as indices.
Class rb_cTensor;
|
47
|
+
|
48
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
49
|
+
Object obj;
|
50
|
+
|
51
|
+
std::vector<TensorIndex> indices;
|
52
|
+
indices.reserve(a.size());
|
53
|
+
|
54
|
+
for (long i = 0; i < a.size(); i++) {
|
55
|
+
obj = a[i];
|
56
|
+
|
57
|
+
if (obj.is_instance_of(rb_cInteger)) {
|
58
|
+
indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
|
59
|
+
} else if (obj.is_instance_of(rb_cRange)) {
|
60
|
+
torch::optional<int64_t> start_index = torch::nullopt;
|
61
|
+
torch::optional<int64_t> stop_index = torch::nullopt;
|
62
|
+
|
63
|
+
Object begin = obj.call("begin");
|
64
|
+
if (!begin.is_nil()) {
|
65
|
+
start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
|
66
|
+
}
|
67
|
+
|
68
|
+
Object end = obj.call("end");
|
69
|
+
if (!end.is_nil()) {
|
70
|
+
stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
|
71
|
+
}
|
72
|
+
|
73
|
+
Object exclude_end = obj.call("exclude_end?");
|
74
|
+
if (stop_index.has_value() && !exclude_end) {
|
75
|
+
if (stop_index.value() == -1) {
|
76
|
+
stop_index = torch::nullopt;
|
77
|
+
} else {
|
78
|
+
stop_index = stop_index.value() + 1;
|
79
|
+
}
|
80
|
+
}
|
81
|
+
|
82
|
+
indices.push_back(torch::indexing::Slice(start_index, stop_index));
|
83
|
+
} else if (obj.is_instance_of(rb_cTensor)) {
|
84
|
+
indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
|
85
|
+
} else if (obj.is_nil()) {
|
86
|
+
indices.push_back(torch::indexing::None);
|
87
|
+
} else if (obj == True || obj == False) {
|
88
|
+
indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
|
89
|
+
} else {
|
90
|
+
throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
|
91
|
+
}
|
92
|
+
}
|
93
|
+
return indices;
|
94
|
+
}
|
95
|
+
|
96
|
+
// hack (removes inputs argument)
|
97
|
+
// https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
|
98
|
+
// TODO add support for inputs argument
|
99
|
+
// _backward
|
100
|
+
// Tensor#backward, bound directly through the Ruby C API (argc/argv form).
//
// Ruby-visible arguments: gradient (Tensor or nil), retain_graph (bool or nil)
// and create_graph (bool, default false). The native _backward also takes an
// `inputs` list. That argument is not exposed (see the upstream commit below)
// and an empty list is always passed.
//   https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
// TODO add support for inputs argument
// Returns nil.
static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
{
  HANDLE_TH_ERRORS
  Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);

  static RubyArgParser parser({
    "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
  });
  ParsedArgs<4> arg_buffer;
  auto args = parser.parse(self_, argc, argv, arg_buffer);

  // Native signature:
  // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
  auto call_backward = [](const Tensor& tensor,
                          TensorList inputs,
                          const OptionalTensor& gradient,
                          c10::optional<bool> retain_graph,
                          bool create_graph) -> void {
    // in future, release GVL
    tensor._backward(inputs, gradient, retain_graph, create_graph);
  };
  call_backward(self, {}, args.optionalTensor(0), args.toBoolOptional(1), args.toBool(2));

  RETURN_NIL
  END_HANDLE_TH_ERRORS
}
|
118
|
+
|
119
|
+
void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
|
120
|
+
rb_cTensor = c;
|
121
|
+
rb_cTensor.add_handler<torch::Error>(handle_error);
|
122
|
+
add_tensor_functions(rb_cTensor);
|
123
|
+
THPVariableClass = rb_cTensor.value();
|
124
|
+
|
125
|
+
rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
|
126
|
+
|
127
|
+
rb_cTensor
|
128
|
+
.define_method("cuda?", &torch::Tensor::is_cuda)
|
129
|
+
.define_method("sparse?", &torch::Tensor::is_sparse)
|
130
|
+
.define_method("quantized?", &torch::Tensor::is_quantized)
|
131
|
+
.define_method("dim", &torch::Tensor::dim)
|
132
|
+
.define_method("numel", &torch::Tensor::numel)
|
133
|
+
.define_method("element_size", &torch::Tensor::element_size)
|
134
|
+
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
135
|
+
.define_method(
|
136
|
+
"_size",
|
137
|
+
[](Tensor& self, int64_t dim) {
|
138
|
+
return self.size(dim);
|
139
|
+
})
|
140
|
+
.define_method(
|
141
|
+
"_stride",
|
142
|
+
[](Tensor& self, int64_t dim) {
|
143
|
+
return self.stride(dim);
|
144
|
+
})
|
145
|
+
// in C++ for performance
|
146
|
+
.define_method(
|
147
|
+
"shape",
|
148
|
+
[](Tensor& self) {
|
149
|
+
Array a;
|
150
|
+
for (auto &size : self.sizes()) {
|
151
|
+
a.push(size);
|
152
|
+
}
|
153
|
+
return a;
|
154
|
+
})
|
155
|
+
.define_method(
|
156
|
+
"_strides",
|
157
|
+
[](Tensor& self) {
|
158
|
+
Array a;
|
159
|
+
for (auto &stride : self.strides()) {
|
160
|
+
a.push(stride);
|
161
|
+
}
|
162
|
+
return a;
|
163
|
+
})
|
164
|
+
.define_method(
|
165
|
+
"_index",
|
166
|
+
[](Tensor& self, Array indices) {
|
167
|
+
auto vec = index_vector(indices);
|
168
|
+
return self.index(vec);
|
169
|
+
})
|
170
|
+
.define_method(
|
171
|
+
"_index_put_custom",
|
172
|
+
[](Tensor& self, Array indices, torch::Tensor& value) {
|
173
|
+
auto vec = index_vector(indices);
|
174
|
+
return self.index_put_(vec, value);
|
175
|
+
})
|
176
|
+
.define_method(
|
177
|
+
"contiguous?",
|
178
|
+
[](Tensor& self) {
|
179
|
+
return self.is_contiguous();
|
180
|
+
})
|
181
|
+
.define_method(
|
182
|
+
"_requires_grad!",
|
183
|
+
[](Tensor& self, bool requires_grad) {
|
184
|
+
return self.set_requires_grad(requires_grad);
|
185
|
+
})
|
186
|
+
.define_method(
|
187
|
+
"grad",
|
188
|
+
[](Tensor& self) {
|
189
|
+
auto grad = self.grad();
|
190
|
+
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
191
|
+
})
|
192
|
+
.define_method(
|
193
|
+
"grad=",
|
194
|
+
[](Tensor& self, torch::Tensor& grad) {
|
195
|
+
self.mutable_grad() = grad;
|
196
|
+
})
|
197
|
+
.define_method(
|
198
|
+
"_dtype",
|
199
|
+
[](Tensor& self) {
|
200
|
+
return (int) at::typeMetaToScalarType(self.dtype());
|
201
|
+
})
|
202
|
+
.define_method(
|
203
|
+
"_type",
|
204
|
+
[](Tensor& self, int dtype) {
|
205
|
+
return self.toType((torch::ScalarType) dtype);
|
206
|
+
})
|
207
|
+
.define_method(
|
208
|
+
"_layout",
|
209
|
+
[](Tensor& self) {
|
210
|
+
std::stringstream s;
|
211
|
+
s << self.layout();
|
212
|
+
return s.str();
|
213
|
+
})
|
214
|
+
.define_method(
|
215
|
+
"device",
|
216
|
+
[](Tensor& self) {
|
217
|
+
std::stringstream s;
|
218
|
+
s << self.device();
|
219
|
+
return s.str();
|
220
|
+
})
|
221
|
+
.define_method(
|
222
|
+
"_data_str",
|
223
|
+
[](Tensor& self) {
|
224
|
+
auto tensor = self;
|
225
|
+
|
226
|
+
// move to CPU to get data
|
227
|
+
if (tensor.device().type() != torch::kCPU) {
|
228
|
+
torch::Device device("cpu");
|
229
|
+
tensor = tensor.to(device);
|
230
|
+
}
|
231
|
+
|
232
|
+
if (!tensor.is_contiguous()) {
|
233
|
+
tensor = tensor.contiguous();
|
234
|
+
}
|
235
|
+
|
236
|
+
auto data_ptr = (const char *) tensor.data_ptr();
|
237
|
+
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
238
|
+
})
|
239
|
+
// for TorchVision
|
240
|
+
.define_method(
|
241
|
+
"_data_ptr",
|
242
|
+
[](Tensor& self) {
|
243
|
+
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
244
|
+
})
|
245
|
+
// TODO figure out a better way to do this
|
246
|
+
.define_method(
|
247
|
+
"_flat_data",
|
248
|
+
[](Tensor& self) {
|
249
|
+
auto tensor = self;
|
250
|
+
|
251
|
+
// move to CPU to get data
|
252
|
+
if (tensor.device().type() != torch::kCPU) {
|
253
|
+
torch::Device device("cpu");
|
254
|
+
tensor = tensor.to(device);
|
255
|
+
}
|
256
|
+
|
257
|
+
auto dtype = tensor.dtype();
|
258
|
+
if (dtype == torch::kByte) {
|
259
|
+
return flat_data<uint8_t>(tensor);
|
260
|
+
} else if (dtype == torch::kChar) {
|
261
|
+
return flat_data<int8_t>(tensor);
|
262
|
+
} else if (dtype == torch::kShort) {
|
263
|
+
return flat_data<int16_t>(tensor);
|
264
|
+
} else if (dtype == torch::kInt) {
|
265
|
+
return flat_data<int32_t>(tensor);
|
266
|
+
} else if (dtype == torch::kLong) {
|
267
|
+
return flat_data<int64_t>(tensor);
|
268
|
+
} else if (dtype == torch::kFloat) {
|
269
|
+
return flat_data<float>(tensor);
|
270
|
+
} else if (dtype == torch::kDouble) {
|
271
|
+
return flat_data<double>(tensor);
|
272
|
+
} else if (dtype == torch::kBool) {
|
273
|
+
return flat_data<bool>(tensor);
|
274
|
+
} else if (dtype == torch::kComplexFloat) {
|
275
|
+
return flat_data<c10::complex<float>>(tensor);
|
276
|
+
} else if (dtype == torch::kComplexDouble) {
|
277
|
+
return flat_data<c10::complex<double>>(tensor);
|
278
|
+
} else {
|
279
|
+
throw std::runtime_error("Unsupported type");
|
280
|
+
}
|
281
|
+
})
|
282
|
+
.define_method(
|
283
|
+
"_to",
|
284
|
+
[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
285
|
+
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
286
|
+
});
|
287
|
+
|
288
|
+
rb_cTensorOptions
|
289
|
+
.add_handler<torch::Error>(handle_error)
|
290
|
+
.define_method(
|
291
|
+
"dtype",
|
292
|
+
[](torch::TensorOptions& self, int dtype) {
|
293
|
+
return self.dtype((torch::ScalarType) dtype);
|
294
|
+
})
|
295
|
+
.define_method(
|
296
|
+
"layout",
|
297
|
+
[](torch::TensorOptions& self, const std::string& layout) {
|
298
|
+
torch::Layout l;
|
299
|
+
if (layout == "strided") {
|
300
|
+
l = torch::kStrided;
|
301
|
+
} else if (layout == "sparse") {
|
302
|
+
l = torch::kSparse;
|
303
|
+
throw std::runtime_error("Sparse layout not supported yet");
|
304
|
+
} else {
|
305
|
+
throw std::runtime_error("Unsupported layout: " + layout);
|
306
|
+
}
|
307
|
+
return self.layout(l);
|
308
|
+
})
|
309
|
+
.define_method(
|
310
|
+
"device",
|
311
|
+
[](torch::TensorOptions& self, const std::string& device) {
|
312
|
+
torch::Device d(device);
|
313
|
+
return self.device(d);
|
314
|
+
})
|
315
|
+
.define_method(
|
316
|
+
"requires_grad",
|
317
|
+
[](torch::TensorOptions& self, bool requires_grad) {
|
318
|
+
return self.requires_grad(requires_grad);
|
319
|
+
});
|
320
|
+
}
|