torch-rb 0.6.0 → 0.8.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +21 -0
- data/README.md +23 -41
- data/codegen/function.rb +2 -0
- data/codegen/generate_functions.rb +43 -6
- data/codegen/native_functions.yaml +2007 -1327
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +22 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +23 -12
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +111 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/convnd.rb +2 -0
- data/lib/torch/nn/functional_attention.rb +241 -0
- data/lib/torch/nn/module.rb +2 -0
- data/lib/torch/nn/module_list.rb +49 -0
- data/lib/torch/nn/multihead_attention.rb +123 -0
- data/lib/torch/nn/transformer.rb +92 -0
- data/lib/torch/nn/transformer_decoder.rb +25 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
- data/lib/torch/nn/transformer_encoder.rb +25 -0
- data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
- data/lib/torch/nn/utils.rb +16 -0
- data/lib/torch/tensor.rb +2 -0
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +11 -0
- metadata +20 -5
data/ext/torch/templates.h
CHANGED
@@ -4,8 +4,7 @@
|
|
4
4
|
#undef isfinite
|
5
5
|
#endif
|
6
6
|
|
7
|
-
#include <rice/
|
8
|
-
#include <rice/Object.hpp>
|
7
|
+
#include <rice/rice.hpp>
|
9
8
|
|
10
9
|
using namespace Rice;
|
11
10
|
|
@@ -22,6 +21,10 @@ using torch::IntArrayRef;
|
|
22
21
|
using torch::ArrayRef;
|
23
22
|
using torch::TensorList;
|
24
23
|
using torch::Storage;
|
24
|
+
using ScalarList = ArrayRef<Scalar>;
|
25
|
+
|
26
|
+
using torch::nn::init::FanModeType;
|
27
|
+
using torch::nn::init::NonlinearityType;
|
25
28
|
|
26
29
|
#define HANDLE_TH_ERRORS \
|
27
30
|
try {
|
@@ -38,37 +41,42 @@ using torch::Storage;
|
|
38
41
|
#define RETURN_NIL \
|
39
42
|
return Qnil;
|
40
43
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
{
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
44
|
+
class OptionalTensor {
|
45
|
+
torch::Tensor value;
|
46
|
+
public:
|
47
|
+
OptionalTensor(Object o) {
|
48
|
+
if (o.is_nil()) {
|
49
|
+
value = {};
|
50
|
+
} else {
|
51
|
+
value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
|
52
|
+
}
|
53
|
+
}
|
54
|
+
OptionalTensor(torch::Tensor o) {
|
55
|
+
value = o;
|
56
|
+
}
|
57
|
+
operator torch::Tensor() const {
|
58
|
+
return value;
|
59
|
+
}
|
60
|
+
};
|
52
61
|
|
53
|
-
|
54
|
-
inline
|
55
|
-
std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
62
|
+
namespace Rice::detail
|
56
63
|
{
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
}
|
64
|
+
template<>
|
65
|
+
struct Type<FanModeType>
|
66
|
+
{
|
67
|
+
static bool verify()
|
68
|
+
{
|
69
|
+
return true;
|
70
|
+
}
|
71
|
+
};
|
64
72
|
|
65
|
-
|
66
|
-
|
73
|
+
template<>
|
74
|
+
class From_Ruby<FanModeType>
|
75
|
+
{
|
67
76
|
public:
|
68
|
-
FanModeType(
|
69
|
-
|
70
|
-
|
71
|
-
operator torch::nn::init::FanModeType() {
|
77
|
+
FanModeType convert(VALUE x)
|
78
|
+
{
|
79
|
+
auto s = String(x).str();
|
72
80
|
if (s == "fan_in") {
|
73
81
|
return torch::kFanIn;
|
74
82
|
} else if (s == "fan_out") {
|
@@ -77,22 +85,24 @@ class FanModeType {
|
|
77
85
|
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
78
86
|
}
|
79
87
|
}
|
80
|
-
};
|
81
|
-
|
82
|
-
template<>
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
+
};
|
89
|
+
|
90
|
+
template<>
|
91
|
+
struct Type<NonlinearityType>
|
92
|
+
{
|
93
|
+
static bool verify()
|
94
|
+
{
|
95
|
+
return true;
|
96
|
+
}
|
97
|
+
};
|
88
98
|
|
89
|
-
|
90
|
-
|
99
|
+
template<>
|
100
|
+
class From_Ruby<NonlinearityType>
|
101
|
+
{
|
91
102
|
public:
|
92
|
-
NonlinearityType(
|
93
|
-
|
94
|
-
|
95
|
-
operator torch::nn::init::NonlinearityType() {
|
103
|
+
NonlinearityType convert(VALUE x)
|
104
|
+
{
|
105
|
+
auto s = String(x).str();
|
96
106
|
if (s == "linear") {
|
97
107
|
return torch::kLinear;
|
98
108
|
} else if (s == "conv1d") {
|
@@ -119,102 +129,70 @@ class NonlinearityType {
|
|
119
129
|
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
120
130
|
}
|
121
131
|
}
|
122
|
-
};
|
132
|
+
};
|
133
|
+
|
134
|
+
template<>
|
135
|
+
struct Type<OptionalTensor>
|
136
|
+
{
|
137
|
+
static bool verify()
|
138
|
+
{
|
139
|
+
return true;
|
140
|
+
}
|
141
|
+
};
|
123
142
|
|
124
|
-
template<>
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
143
|
+
template<>
|
144
|
+
class From_Ruby<OptionalTensor>
|
145
|
+
{
|
146
|
+
public:
|
147
|
+
OptionalTensor convert(VALUE x)
|
148
|
+
{
|
149
|
+
return OptionalTensor(x);
|
150
|
+
}
|
151
|
+
};
|
152
|
+
|
153
|
+
template<>
|
154
|
+
struct Type<Scalar>
|
155
|
+
{
|
156
|
+
static bool verify()
|
157
|
+
{
|
158
|
+
return true;
|
159
|
+
}
|
160
|
+
};
|
130
161
|
|
131
|
-
|
132
|
-
|
162
|
+
template<>
|
163
|
+
class From_Ruby<Scalar>
|
164
|
+
{
|
133
165
|
public:
|
134
|
-
|
135
|
-
|
136
|
-
|
166
|
+
Scalar convert(VALUE x)
|
167
|
+
{
|
168
|
+
if (FIXNUM_P(x)) {
|
169
|
+
return torch::Scalar(From_Ruby<int64_t>().convert(x));
|
137
170
|
} else {
|
138
|
-
|
171
|
+
return torch::Scalar(From_Ruby<double>().convert(x));
|
139
172
|
}
|
140
173
|
}
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
174
|
+
};
|
175
|
+
|
176
|
+
template<typename T>
|
177
|
+
struct Type<torch::optional<T>>
|
178
|
+
{
|
179
|
+
static bool verify()
|
180
|
+
{
|
181
|
+
return true;
|
146
182
|
}
|
147
|
-
};
|
148
|
-
|
149
|
-
template<>
|
150
|
-
inline
|
151
|
-
Scalar from_ruby<Scalar>(Object x)
|
152
|
-
{
|
153
|
-
if (x.rb_type() == T_FIXNUM) {
|
154
|
-
return torch::Scalar(from_ruby<int64_t>(x));
|
155
|
-
} else {
|
156
|
-
return torch::Scalar(from_ruby<double>(x));
|
157
|
-
}
|
158
|
-
}
|
159
|
-
|
160
|
-
template<>
|
161
|
-
inline
|
162
|
-
OptionalTensor from_ruby<OptionalTensor>(Object x)
|
163
|
-
{
|
164
|
-
return OptionalTensor(x);
|
165
|
-
}
|
166
|
-
|
167
|
-
template<>
|
168
|
-
inline
|
169
|
-
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
170
|
-
{
|
171
|
-
if (x.is_nil()) {
|
172
|
-
return torch::nullopt;
|
173
|
-
} else {
|
174
|
-
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
175
|
-
}
|
176
|
-
}
|
177
|
-
|
178
|
-
template<>
|
179
|
-
inline
|
180
|
-
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
181
|
-
{
|
182
|
-
if (x.is_nil()) {
|
183
|
-
return torch::nullopt;
|
184
|
-
} else {
|
185
|
-
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
186
|
-
}
|
187
|
-
}
|
183
|
+
};
|
188
184
|
|
189
|
-
template
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
}
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
|
203
|
-
{
|
204
|
-
if (x.is_nil()) {
|
205
|
-
return torch::nullopt;
|
206
|
-
} else {
|
207
|
-
return torch::optional<bool>{from_ruby<bool>(x)};
|
208
|
-
}
|
209
|
-
}
|
210
|
-
|
211
|
-
template<>
|
212
|
-
inline
|
213
|
-
torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
214
|
-
{
|
215
|
-
if (x.is_nil()) {
|
216
|
-
return torch::nullopt;
|
217
|
-
} else {
|
218
|
-
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
|
219
|
-
}
|
185
|
+
template<typename T>
|
186
|
+
class From_Ruby<torch::optional<T>>
|
187
|
+
{
|
188
|
+
public:
|
189
|
+
torch::optional<T> convert(VALUE x)
|
190
|
+
{
|
191
|
+
if (NIL_P(x)) {
|
192
|
+
return torch::nullopt;
|
193
|
+
} else {
|
194
|
+
return torch::optional<T>{From_Ruby<T>().convert(x)};
|
195
|
+
}
|
196
|
+
}
|
197
|
+
};
|
220
198
|
}
|
data/ext/torch/tensor.cpp
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/Module.hpp>
|
3
|
+
#include <rice/rice.hpp>
|
5
4
|
|
6
5
|
#include "tensor_functions.h"
|
7
6
|
#include "ruby_arg_parser.h"
|
@@ -11,6 +10,39 @@
|
|
11
10
|
using namespace Rice;
|
12
11
|
using torch::indexing::TensorIndex;
|
13
12
|
|
13
|
+
namespace Rice::detail
|
14
|
+
{
|
15
|
+
template<typename T>
|
16
|
+
struct Type<c10::complex<T>>
|
17
|
+
{
|
18
|
+
static bool verify()
|
19
|
+
{
|
20
|
+
return true;
|
21
|
+
}
|
22
|
+
};
|
23
|
+
|
24
|
+
template<typename T>
|
25
|
+
class To_Ruby<c10::complex<T>>
|
26
|
+
{
|
27
|
+
public:
|
28
|
+
VALUE convert(c10::complex<T> const& x)
|
29
|
+
{
|
30
|
+
return rb_dbl_complex_new(x.real(), x.imag());
|
31
|
+
}
|
32
|
+
};
|
33
|
+
}
|
34
|
+
|
35
|
+
template<typename T>
|
36
|
+
Array flat_data(Tensor& tensor) {
|
37
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
38
|
+
|
39
|
+
Array a;
|
40
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
41
|
+
a.push(view[i].item().to<T>());
|
42
|
+
}
|
43
|
+
return a;
|
44
|
+
}
|
45
|
+
|
14
46
|
Class rb_cTensor;
|
15
47
|
|
16
48
|
std::vector<TensorIndex> index_vector(Array a) {
|
@@ -23,19 +55,19 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
23
55
|
obj = a[i];
|
24
56
|
|
25
57
|
if (obj.is_instance_of(rb_cInteger)) {
|
26
|
-
indices.push_back(
|
58
|
+
indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
|
27
59
|
} else if (obj.is_instance_of(rb_cRange)) {
|
28
60
|
torch::optional<int64_t> start_index = torch::nullopt;
|
29
61
|
torch::optional<int64_t> stop_index = torch::nullopt;
|
30
62
|
|
31
63
|
Object begin = obj.call("begin");
|
32
64
|
if (!begin.is_nil()) {
|
33
|
-
start_index =
|
65
|
+
start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
|
34
66
|
}
|
35
67
|
|
36
68
|
Object end = obj.call("end");
|
37
69
|
if (!end.is_nil()) {
|
38
|
-
stop_index =
|
70
|
+
stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
|
39
71
|
}
|
40
72
|
|
41
73
|
Object exclude_end = obj.call("exclude_end?");
|
@@ -49,11 +81,11 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
49
81
|
|
50
82
|
indices.push_back(torch::indexing::Slice(start_index, stop_index));
|
51
83
|
} else if (obj.is_instance_of(rb_cTensor)) {
|
52
|
-
indices.push_back(
|
84
|
+
indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
|
53
85
|
} else if (obj.is_nil()) {
|
54
86
|
indices.push_back(torch::indexing::None);
|
55
87
|
} else if (obj == True || obj == False) {
|
56
|
-
indices.push_back(
|
88
|
+
indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
|
57
89
|
} else {
|
58
90
|
throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
|
59
91
|
}
|
@@ -68,7 +100,7 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
68
100
|
static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
|
69
101
|
{
|
70
102
|
HANDLE_TH_ERRORS
|
71
|
-
Tensor& self =
|
103
|
+
Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
|
72
104
|
static RubyArgParser parser({
|
73
105
|
"_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
|
74
106
|
});
|
@@ -84,8 +116,8 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
|
|
84
116
|
END_HANDLE_TH_ERRORS
|
85
117
|
}
|
86
118
|
|
87
|
-
void init_tensor(Rice::Module& m) {
|
88
|
-
rb_cTensor =
|
119
|
+
void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
|
120
|
+
rb_cTensor = c;
|
89
121
|
rb_cTensor.add_handler<torch::Error>(handle_error);
|
90
122
|
add_tensor_functions(rb_cTensor);
|
91
123
|
THPVariableClass = rb_cTensor.value();
|
@@ -102,18 +134,18 @@ void init_tensor(Rice::Module& m) {
|
|
102
134
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
103
135
|
.define_method(
|
104
136
|
"_size",
|
105
|
-
|
137
|
+
[](Tensor& self, int64_t dim) {
|
106
138
|
return self.size(dim);
|
107
139
|
})
|
108
140
|
.define_method(
|
109
141
|
"_stride",
|
110
|
-
|
142
|
+
[](Tensor& self, int64_t dim) {
|
111
143
|
return self.stride(dim);
|
112
144
|
})
|
113
145
|
// in C++ for performance
|
114
146
|
.define_method(
|
115
147
|
"shape",
|
116
|
-
|
148
|
+
[](Tensor& self) {
|
117
149
|
Array a;
|
118
150
|
for (auto &size : self.sizes()) {
|
119
151
|
a.push(size);
|
@@ -122,7 +154,7 @@ void init_tensor(Rice::Module& m) {
|
|
122
154
|
})
|
123
155
|
.define_method(
|
124
156
|
"_strides",
|
125
|
-
|
157
|
+
[](Tensor& self) {
|
126
158
|
Array a;
|
127
159
|
for (auto &stride : self.strides()) {
|
128
160
|
a.push(stride);
|
@@ -131,65 +163,65 @@ void init_tensor(Rice::Module& m) {
|
|
131
163
|
})
|
132
164
|
.define_method(
|
133
165
|
"_index",
|
134
|
-
|
166
|
+
[](Tensor& self, Array indices) {
|
135
167
|
auto vec = index_vector(indices);
|
136
168
|
return self.index(vec);
|
137
169
|
})
|
138
170
|
.define_method(
|
139
171
|
"_index_put_custom",
|
140
|
-
|
172
|
+
[](Tensor& self, Array indices, torch::Tensor& value) {
|
141
173
|
auto vec = index_vector(indices);
|
142
174
|
return self.index_put_(vec, value);
|
143
175
|
})
|
144
176
|
.define_method(
|
145
177
|
"contiguous?",
|
146
|
-
|
178
|
+
[](Tensor& self) {
|
147
179
|
return self.is_contiguous();
|
148
180
|
})
|
149
181
|
.define_method(
|
150
182
|
"_requires_grad!",
|
151
|
-
|
183
|
+
[](Tensor& self, bool requires_grad) {
|
152
184
|
return self.set_requires_grad(requires_grad);
|
153
185
|
})
|
154
186
|
.define_method(
|
155
187
|
"grad",
|
156
|
-
|
188
|
+
[](Tensor& self) {
|
157
189
|
auto grad = self.grad();
|
158
|
-
return grad.defined() ?
|
190
|
+
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
159
191
|
})
|
160
192
|
.define_method(
|
161
193
|
"grad=",
|
162
|
-
|
194
|
+
[](Tensor& self, torch::Tensor& grad) {
|
163
195
|
self.mutable_grad() = grad;
|
164
196
|
})
|
165
197
|
.define_method(
|
166
198
|
"_dtype",
|
167
|
-
|
199
|
+
[](Tensor& self) {
|
168
200
|
return (int) at::typeMetaToScalarType(self.dtype());
|
169
201
|
})
|
170
202
|
.define_method(
|
171
203
|
"_type",
|
172
|
-
|
204
|
+
[](Tensor& self, int dtype) {
|
173
205
|
return self.toType((torch::ScalarType) dtype);
|
174
206
|
})
|
175
207
|
.define_method(
|
176
208
|
"_layout",
|
177
|
-
|
209
|
+
[](Tensor& self) {
|
178
210
|
std::stringstream s;
|
179
211
|
s << self.layout();
|
180
212
|
return s.str();
|
181
213
|
})
|
182
214
|
.define_method(
|
183
215
|
"device",
|
184
|
-
|
216
|
+
[](Tensor& self) {
|
185
217
|
std::stringstream s;
|
186
218
|
s << self.device();
|
187
219
|
return s.str();
|
188
220
|
})
|
189
221
|
.define_method(
|
190
222
|
"_data_str",
|
191
|
-
|
192
|
-
|
223
|
+
[](Tensor& self) {
|
224
|
+
auto tensor = self;
|
193
225
|
|
194
226
|
// move to CPU to get data
|
195
227
|
if (tensor.device().type() != torch::kCPU) {
|
@@ -207,14 +239,14 @@ void init_tensor(Rice::Module& m) {
|
|
207
239
|
// for TorchVision
|
208
240
|
.define_method(
|
209
241
|
"_data_ptr",
|
210
|
-
|
242
|
+
[](Tensor& self) {
|
211
243
|
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
212
244
|
})
|
213
245
|
// TODO figure out a better way to do this
|
214
246
|
.define_method(
|
215
247
|
"_flat_data",
|
216
|
-
|
217
|
-
|
248
|
+
[](Tensor& self) {
|
249
|
+
auto tensor = self;
|
218
250
|
|
219
251
|
// move to CPU to get data
|
220
252
|
if (tensor.device().type() != torch::kCPU) {
|
@@ -222,66 +254,47 @@ void init_tensor(Rice::Module& m) {
|
|
222
254
|
tensor = tensor.to(device);
|
223
255
|
}
|
224
256
|
|
225
|
-
Array a;
|
226
257
|
auto dtype = tensor.dtype();
|
227
|
-
|
228
|
-
Tensor view = tensor.reshape({tensor.numel()});
|
229
|
-
|
230
|
-
// TODO DRY if someone knows C++
|
231
258
|
if (dtype == torch::kByte) {
|
232
|
-
|
233
|
-
a.push(view[i].item().to<uint8_t>());
|
234
|
-
}
|
259
|
+
return flat_data<uint8_t>(tensor);
|
235
260
|
} else if (dtype == torch::kChar) {
|
236
|
-
|
237
|
-
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
238
|
-
}
|
261
|
+
return flat_data<int8_t>(tensor);
|
239
262
|
} else if (dtype == torch::kShort) {
|
240
|
-
|
241
|
-
a.push(view[i].item().to<int16_t>());
|
242
|
-
}
|
263
|
+
return flat_data<int16_t>(tensor);
|
243
264
|
} else if (dtype == torch::kInt) {
|
244
|
-
|
245
|
-
a.push(view[i].item().to<int32_t>());
|
246
|
-
}
|
265
|
+
return flat_data<int32_t>(tensor);
|
247
266
|
} else if (dtype == torch::kLong) {
|
248
|
-
|
249
|
-
a.push(view[i].item().to<int64_t>());
|
250
|
-
}
|
267
|
+
return flat_data<int64_t>(tensor);
|
251
268
|
} else if (dtype == torch::kFloat) {
|
252
|
-
|
253
|
-
a.push(view[i].item().to<float>());
|
254
|
-
}
|
269
|
+
return flat_data<float>(tensor);
|
255
270
|
} else if (dtype == torch::kDouble) {
|
256
|
-
|
257
|
-
a.push(view[i].item().to<double>());
|
258
|
-
}
|
271
|
+
return flat_data<double>(tensor);
|
259
272
|
} else if (dtype == torch::kBool) {
|
260
|
-
|
261
|
-
|
262
|
-
|
273
|
+
return flat_data<bool>(tensor);
|
274
|
+
} else if (dtype == torch::kComplexFloat) {
|
275
|
+
return flat_data<c10::complex<float>>(tensor);
|
276
|
+
} else if (dtype == torch::kComplexDouble) {
|
277
|
+
return flat_data<c10::complex<double>>(tensor);
|
263
278
|
} else {
|
264
279
|
throw std::runtime_error("Unsupported type");
|
265
280
|
}
|
266
|
-
return a;
|
267
281
|
})
|
268
282
|
.define_method(
|
269
283
|
"_to",
|
270
|
-
|
284
|
+
[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
271
285
|
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
272
286
|
});
|
273
287
|
|
274
|
-
|
288
|
+
rb_cTensorOptions
|
275
289
|
.add_handler<torch::Error>(handle_error)
|
276
|
-
.define_constructor(Rice::Constructor<torch::TensorOptions>())
|
277
290
|
.define_method(
|
278
291
|
"dtype",
|
279
|
-
|
292
|
+
[](torch::TensorOptions& self, int dtype) {
|
280
293
|
return self.dtype((torch::ScalarType) dtype);
|
281
294
|
})
|
282
295
|
.define_method(
|
283
296
|
"layout",
|
284
|
-
|
297
|
+
[](torch::TensorOptions& self, const std::string& layout) {
|
285
298
|
torch::Layout l;
|
286
299
|
if (layout == "strided") {
|
287
300
|
l = torch::kStrided;
|
@@ -295,13 +308,13 @@ void init_tensor(Rice::Module& m) {
|
|
295
308
|
})
|
296
309
|
.define_method(
|
297
310
|
"device",
|
298
|
-
|
311
|
+
[](torch::TensorOptions& self, const std::string& device) {
|
299
312
|
torch::Device d(device);
|
300
313
|
return self.device(d);
|
301
314
|
})
|
302
315
|
.define_method(
|
303
316
|
"requires_grad",
|
304
|
-
|
317
|
+
[](torch::TensorOptions& self, bool requires_grad) {
|
305
318
|
return self.requires_grad(requires_grad);
|
306
319
|
});
|
307
320
|
}
|