torch-rb 0.6.0 → 0.8.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +21 -0
- data/README.md +23 -41
- data/codegen/function.rb +2 -0
- data/codegen/generate_functions.rb +43 -6
- data/codegen/native_functions.yaml +2007 -1327
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +22 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +23 -12
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +111 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/convnd.rb +2 -0
- data/lib/torch/nn/functional_attention.rb +241 -0
- data/lib/torch/nn/module.rb +2 -0
- data/lib/torch/nn/module_list.rb +49 -0
- data/lib/torch/nn/multihead_attention.rb +123 -0
- data/lib/torch/nn/transformer.rb +92 -0
- data/lib/torch/nn/transformer_decoder.rb +25 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
- data/lib/torch/nn/transformer_encoder.rb +25 -0
- data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
- data/lib/torch/nn/utils.rb +16 -0
- data/lib/torch/tensor.rb +2 -0
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +11 -0
- metadata +20 -5
data/ext/torch/templates.h
CHANGED
@@ -4,8 +4,7 @@
|
|
4
4
|
#undef isfinite
|
5
5
|
#endif
|
6
6
|
|
7
|
-
#include <rice/
|
8
|
-
#include <rice/Object.hpp>
|
7
|
+
#include <rice/rice.hpp>
|
9
8
|
|
10
9
|
using namespace Rice;
|
11
10
|
|
@@ -22,6 +21,10 @@ using torch::IntArrayRef;
|
|
22
21
|
using torch::ArrayRef;
|
23
22
|
using torch::TensorList;
|
24
23
|
using torch::Storage;
|
24
|
+
using ScalarList = ArrayRef<Scalar>;
|
25
|
+
|
26
|
+
using torch::nn::init::FanModeType;
|
27
|
+
using torch::nn::init::NonlinearityType;
|
25
28
|
|
26
29
|
#define HANDLE_TH_ERRORS \
|
27
30
|
try {
|
@@ -38,37 +41,42 @@ using torch::Storage;
|
|
38
41
|
#define RETURN_NIL \
|
39
42
|
return Qnil;
|
40
43
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
{
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
44
|
+
class OptionalTensor {
|
45
|
+
torch::Tensor value;
|
46
|
+
public:
|
47
|
+
OptionalTensor(Object o) {
|
48
|
+
if (o.is_nil()) {
|
49
|
+
value = {};
|
50
|
+
} else {
|
51
|
+
value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
|
52
|
+
}
|
53
|
+
}
|
54
|
+
OptionalTensor(torch::Tensor o) {
|
55
|
+
value = o;
|
56
|
+
}
|
57
|
+
operator torch::Tensor() const {
|
58
|
+
return value;
|
59
|
+
}
|
60
|
+
};
|
52
61
|
|
53
|
-
|
54
|
-
inline
|
55
|
-
std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
62
|
+
namespace Rice::detail
|
56
63
|
{
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
}
|
64
|
+
template<>
|
65
|
+
struct Type<FanModeType>
|
66
|
+
{
|
67
|
+
static bool verify()
|
68
|
+
{
|
69
|
+
return true;
|
70
|
+
}
|
71
|
+
};
|
64
72
|
|
65
|
-
|
66
|
-
|
73
|
+
template<>
|
74
|
+
class From_Ruby<FanModeType>
|
75
|
+
{
|
67
76
|
public:
|
68
|
-
FanModeType(
|
69
|
-
|
70
|
-
|
71
|
-
operator torch::nn::init::FanModeType() {
|
77
|
+
FanModeType convert(VALUE x)
|
78
|
+
{
|
79
|
+
auto s = String(x).str();
|
72
80
|
if (s == "fan_in") {
|
73
81
|
return torch::kFanIn;
|
74
82
|
} else if (s == "fan_out") {
|
@@ -77,22 +85,24 @@ class FanModeType {
|
|
77
85
|
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
78
86
|
}
|
79
87
|
}
|
80
|
-
};
|
81
|
-
|
82
|
-
template<>
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
+
};
|
89
|
+
|
90
|
+
template<>
|
91
|
+
struct Type<NonlinearityType>
|
92
|
+
{
|
93
|
+
static bool verify()
|
94
|
+
{
|
95
|
+
return true;
|
96
|
+
}
|
97
|
+
};
|
88
98
|
|
89
|
-
|
90
|
-
|
99
|
+
template<>
|
100
|
+
class From_Ruby<NonlinearityType>
|
101
|
+
{
|
91
102
|
public:
|
92
|
-
NonlinearityType(
|
93
|
-
|
94
|
-
|
95
|
-
operator torch::nn::init::NonlinearityType() {
|
103
|
+
NonlinearityType convert(VALUE x)
|
104
|
+
{
|
105
|
+
auto s = String(x).str();
|
96
106
|
if (s == "linear") {
|
97
107
|
return torch::kLinear;
|
98
108
|
} else if (s == "conv1d") {
|
@@ -119,102 +129,70 @@ class NonlinearityType {
|
|
119
129
|
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
120
130
|
}
|
121
131
|
}
|
122
|
-
};
|
132
|
+
};
|
133
|
+
|
134
|
+
template<>
|
135
|
+
struct Type<OptionalTensor>
|
136
|
+
{
|
137
|
+
static bool verify()
|
138
|
+
{
|
139
|
+
return true;
|
140
|
+
}
|
141
|
+
};
|
123
142
|
|
124
|
-
template<>
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
143
|
+
template<>
|
144
|
+
class From_Ruby<OptionalTensor>
|
145
|
+
{
|
146
|
+
public:
|
147
|
+
OptionalTensor convert(VALUE x)
|
148
|
+
{
|
149
|
+
return OptionalTensor(x);
|
150
|
+
}
|
151
|
+
};
|
152
|
+
|
153
|
+
template<>
|
154
|
+
struct Type<Scalar>
|
155
|
+
{
|
156
|
+
static bool verify()
|
157
|
+
{
|
158
|
+
return true;
|
159
|
+
}
|
160
|
+
};
|
130
161
|
|
131
|
-
|
132
|
-
|
162
|
+
template<>
|
163
|
+
class From_Ruby<Scalar>
|
164
|
+
{
|
133
165
|
public:
|
134
|
-
|
135
|
-
|
136
|
-
|
166
|
+
Scalar convert(VALUE x)
|
167
|
+
{
|
168
|
+
if (FIXNUM_P(x)) {
|
169
|
+
return torch::Scalar(From_Ruby<int64_t>().convert(x));
|
137
170
|
} else {
|
138
|
-
|
171
|
+
return torch::Scalar(From_Ruby<double>().convert(x));
|
139
172
|
}
|
140
173
|
}
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
174
|
+
};
|
175
|
+
|
176
|
+
template<typename T>
|
177
|
+
struct Type<torch::optional<T>>
|
178
|
+
{
|
179
|
+
static bool verify()
|
180
|
+
{
|
181
|
+
return true;
|
146
182
|
}
|
147
|
-
};
|
148
|
-
|
149
|
-
template<>
|
150
|
-
inline
|
151
|
-
Scalar from_ruby<Scalar>(Object x)
|
152
|
-
{
|
153
|
-
if (x.rb_type() == T_FIXNUM) {
|
154
|
-
return torch::Scalar(from_ruby<int64_t>(x));
|
155
|
-
} else {
|
156
|
-
return torch::Scalar(from_ruby<double>(x));
|
157
|
-
}
|
158
|
-
}
|
159
|
-
|
160
|
-
template<>
|
161
|
-
inline
|
162
|
-
OptionalTensor from_ruby<OptionalTensor>(Object x)
|
163
|
-
{
|
164
|
-
return OptionalTensor(x);
|
165
|
-
}
|
166
|
-
|
167
|
-
template<>
|
168
|
-
inline
|
169
|
-
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
170
|
-
{
|
171
|
-
if (x.is_nil()) {
|
172
|
-
return torch::nullopt;
|
173
|
-
} else {
|
174
|
-
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
175
|
-
}
|
176
|
-
}
|
177
|
-
|
178
|
-
template<>
|
179
|
-
inline
|
180
|
-
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
181
|
-
{
|
182
|
-
if (x.is_nil()) {
|
183
|
-
return torch::nullopt;
|
184
|
-
} else {
|
185
|
-
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
186
|
-
}
|
187
|
-
}
|
183
|
+
};
|
188
184
|
|
189
|
-
template
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
}
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
|
203
|
-
{
|
204
|
-
if (x.is_nil()) {
|
205
|
-
return torch::nullopt;
|
206
|
-
} else {
|
207
|
-
return torch::optional<bool>{from_ruby<bool>(x)};
|
208
|
-
}
|
209
|
-
}
|
210
|
-
|
211
|
-
template<>
|
212
|
-
inline
|
213
|
-
torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
214
|
-
{
|
215
|
-
if (x.is_nil()) {
|
216
|
-
return torch::nullopt;
|
217
|
-
} else {
|
218
|
-
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
|
219
|
-
}
|
185
|
+
template<typename T>
|
186
|
+
class From_Ruby<torch::optional<T>>
|
187
|
+
{
|
188
|
+
public:
|
189
|
+
torch::optional<T> convert(VALUE x)
|
190
|
+
{
|
191
|
+
if (NIL_P(x)) {
|
192
|
+
return torch::nullopt;
|
193
|
+
} else {
|
194
|
+
return torch::optional<T>{From_Ruby<T>().convert(x)};
|
195
|
+
}
|
196
|
+
}
|
197
|
+
};
|
220
198
|
}
|
data/ext/torch/tensor.cpp
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/Module.hpp>
|
3
|
+
#include <rice/rice.hpp>
|
5
4
|
|
6
5
|
#include "tensor_functions.h"
|
7
6
|
#include "ruby_arg_parser.h"
|
@@ -11,6 +10,39 @@
|
|
11
10
|
using namespace Rice;
|
12
11
|
using torch::indexing::TensorIndex;
|
13
12
|
|
13
|
+
namespace Rice::detail
|
14
|
+
{
|
15
|
+
template<typename T>
|
16
|
+
struct Type<c10::complex<T>>
|
17
|
+
{
|
18
|
+
static bool verify()
|
19
|
+
{
|
20
|
+
return true;
|
21
|
+
}
|
22
|
+
};
|
23
|
+
|
24
|
+
template<typename T>
|
25
|
+
class To_Ruby<c10::complex<T>>
|
26
|
+
{
|
27
|
+
public:
|
28
|
+
VALUE convert(c10::complex<T> const& x)
|
29
|
+
{
|
30
|
+
return rb_dbl_complex_new(x.real(), x.imag());
|
31
|
+
}
|
32
|
+
};
|
33
|
+
}
|
34
|
+
|
35
|
+
template<typename T>
|
36
|
+
Array flat_data(Tensor& tensor) {
|
37
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
38
|
+
|
39
|
+
Array a;
|
40
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
41
|
+
a.push(view[i].item().to<T>());
|
42
|
+
}
|
43
|
+
return a;
|
44
|
+
}
|
45
|
+
|
14
46
|
Class rb_cTensor;
|
15
47
|
|
16
48
|
std::vector<TensorIndex> index_vector(Array a) {
|
@@ -23,19 +55,19 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
23
55
|
obj = a[i];
|
24
56
|
|
25
57
|
if (obj.is_instance_of(rb_cInteger)) {
|
26
|
-
indices.push_back(
|
58
|
+
indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
|
27
59
|
} else if (obj.is_instance_of(rb_cRange)) {
|
28
60
|
torch::optional<int64_t> start_index = torch::nullopt;
|
29
61
|
torch::optional<int64_t> stop_index = torch::nullopt;
|
30
62
|
|
31
63
|
Object begin = obj.call("begin");
|
32
64
|
if (!begin.is_nil()) {
|
33
|
-
start_index =
|
65
|
+
start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
|
34
66
|
}
|
35
67
|
|
36
68
|
Object end = obj.call("end");
|
37
69
|
if (!end.is_nil()) {
|
38
|
-
stop_index =
|
70
|
+
stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
|
39
71
|
}
|
40
72
|
|
41
73
|
Object exclude_end = obj.call("exclude_end?");
|
@@ -49,11 +81,11 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
49
81
|
|
50
82
|
indices.push_back(torch::indexing::Slice(start_index, stop_index));
|
51
83
|
} else if (obj.is_instance_of(rb_cTensor)) {
|
52
|
-
indices.push_back(
|
84
|
+
indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
|
53
85
|
} else if (obj.is_nil()) {
|
54
86
|
indices.push_back(torch::indexing::None);
|
55
87
|
} else if (obj == True || obj == False) {
|
56
|
-
indices.push_back(
|
88
|
+
indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
|
57
89
|
} else {
|
58
90
|
throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
|
59
91
|
}
|
@@ -68,7 +100,7 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
68
100
|
static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
|
69
101
|
{
|
70
102
|
HANDLE_TH_ERRORS
|
71
|
-
Tensor& self =
|
103
|
+
Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
|
72
104
|
static RubyArgParser parser({
|
73
105
|
"_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
|
74
106
|
});
|
@@ -84,8 +116,8 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
|
|
84
116
|
END_HANDLE_TH_ERRORS
|
85
117
|
}
|
86
118
|
|
87
|
-
void init_tensor(Rice::Module& m) {
|
88
|
-
rb_cTensor =
|
119
|
+
void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
|
120
|
+
rb_cTensor = c;
|
89
121
|
rb_cTensor.add_handler<torch::Error>(handle_error);
|
90
122
|
add_tensor_functions(rb_cTensor);
|
91
123
|
THPVariableClass = rb_cTensor.value();
|
@@ -102,18 +134,18 @@ void init_tensor(Rice::Module& m) {
|
|
102
134
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
103
135
|
.define_method(
|
104
136
|
"_size",
|
105
|
-
|
137
|
+
[](Tensor& self, int64_t dim) {
|
106
138
|
return self.size(dim);
|
107
139
|
})
|
108
140
|
.define_method(
|
109
141
|
"_stride",
|
110
|
-
|
142
|
+
[](Tensor& self, int64_t dim) {
|
111
143
|
return self.stride(dim);
|
112
144
|
})
|
113
145
|
// in C++ for performance
|
114
146
|
.define_method(
|
115
147
|
"shape",
|
116
|
-
|
148
|
+
[](Tensor& self) {
|
117
149
|
Array a;
|
118
150
|
for (auto &size : self.sizes()) {
|
119
151
|
a.push(size);
|
@@ -122,7 +154,7 @@ void init_tensor(Rice::Module& m) {
|
|
122
154
|
})
|
123
155
|
.define_method(
|
124
156
|
"_strides",
|
125
|
-
|
157
|
+
[](Tensor& self) {
|
126
158
|
Array a;
|
127
159
|
for (auto &stride : self.strides()) {
|
128
160
|
a.push(stride);
|
@@ -131,65 +163,65 @@ void init_tensor(Rice::Module& m) {
|
|
131
163
|
})
|
132
164
|
.define_method(
|
133
165
|
"_index",
|
134
|
-
|
166
|
+
[](Tensor& self, Array indices) {
|
135
167
|
auto vec = index_vector(indices);
|
136
168
|
return self.index(vec);
|
137
169
|
})
|
138
170
|
.define_method(
|
139
171
|
"_index_put_custom",
|
140
|
-
|
172
|
+
[](Tensor& self, Array indices, torch::Tensor& value) {
|
141
173
|
auto vec = index_vector(indices);
|
142
174
|
return self.index_put_(vec, value);
|
143
175
|
})
|
144
176
|
.define_method(
|
145
177
|
"contiguous?",
|
146
|
-
|
178
|
+
[](Tensor& self) {
|
147
179
|
return self.is_contiguous();
|
148
180
|
})
|
149
181
|
.define_method(
|
150
182
|
"_requires_grad!",
|
151
|
-
|
183
|
+
[](Tensor& self, bool requires_grad) {
|
152
184
|
return self.set_requires_grad(requires_grad);
|
153
185
|
})
|
154
186
|
.define_method(
|
155
187
|
"grad",
|
156
|
-
|
188
|
+
[](Tensor& self) {
|
157
189
|
auto grad = self.grad();
|
158
|
-
return grad.defined() ?
|
190
|
+
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
159
191
|
})
|
160
192
|
.define_method(
|
161
193
|
"grad=",
|
162
|
-
|
194
|
+
[](Tensor& self, torch::Tensor& grad) {
|
163
195
|
self.mutable_grad() = grad;
|
164
196
|
})
|
165
197
|
.define_method(
|
166
198
|
"_dtype",
|
167
|
-
|
199
|
+
[](Tensor& self) {
|
168
200
|
return (int) at::typeMetaToScalarType(self.dtype());
|
169
201
|
})
|
170
202
|
.define_method(
|
171
203
|
"_type",
|
172
|
-
|
204
|
+
[](Tensor& self, int dtype) {
|
173
205
|
return self.toType((torch::ScalarType) dtype);
|
174
206
|
})
|
175
207
|
.define_method(
|
176
208
|
"_layout",
|
177
|
-
|
209
|
+
[](Tensor& self) {
|
178
210
|
std::stringstream s;
|
179
211
|
s << self.layout();
|
180
212
|
return s.str();
|
181
213
|
})
|
182
214
|
.define_method(
|
183
215
|
"device",
|
184
|
-
|
216
|
+
[](Tensor& self) {
|
185
217
|
std::stringstream s;
|
186
218
|
s << self.device();
|
187
219
|
return s.str();
|
188
220
|
})
|
189
221
|
.define_method(
|
190
222
|
"_data_str",
|
191
|
-
|
192
|
-
|
223
|
+
[](Tensor& self) {
|
224
|
+
auto tensor = self;
|
193
225
|
|
194
226
|
// move to CPU to get data
|
195
227
|
if (tensor.device().type() != torch::kCPU) {
|
@@ -207,14 +239,14 @@ void init_tensor(Rice::Module& m) {
|
|
207
239
|
// for TorchVision
|
208
240
|
.define_method(
|
209
241
|
"_data_ptr",
|
210
|
-
|
242
|
+
[](Tensor& self) {
|
211
243
|
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
212
244
|
})
|
213
245
|
// TODO figure out a better way to do this
|
214
246
|
.define_method(
|
215
247
|
"_flat_data",
|
216
|
-
|
217
|
-
|
248
|
+
[](Tensor& self) {
|
249
|
+
auto tensor = self;
|
218
250
|
|
219
251
|
// move to CPU to get data
|
220
252
|
if (tensor.device().type() != torch::kCPU) {
|
@@ -222,66 +254,47 @@ void init_tensor(Rice::Module& m) {
|
|
222
254
|
tensor = tensor.to(device);
|
223
255
|
}
|
224
256
|
|
225
|
-
Array a;
|
226
257
|
auto dtype = tensor.dtype();
|
227
|
-
|
228
|
-
Tensor view = tensor.reshape({tensor.numel()});
|
229
|
-
|
230
|
-
// TODO DRY if someone knows C++
|
231
258
|
if (dtype == torch::kByte) {
|
232
|
-
|
233
|
-
a.push(view[i].item().to<uint8_t>());
|
234
|
-
}
|
259
|
+
return flat_data<uint8_t>(tensor);
|
235
260
|
} else if (dtype == torch::kChar) {
|
236
|
-
|
237
|
-
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
238
|
-
}
|
261
|
+
return flat_data<int8_t>(tensor);
|
239
262
|
} else if (dtype == torch::kShort) {
|
240
|
-
|
241
|
-
a.push(view[i].item().to<int16_t>());
|
242
|
-
}
|
263
|
+
return flat_data<int16_t>(tensor);
|
243
264
|
} else if (dtype == torch::kInt) {
|
244
|
-
|
245
|
-
a.push(view[i].item().to<int32_t>());
|
246
|
-
}
|
265
|
+
return flat_data<int32_t>(tensor);
|
247
266
|
} else if (dtype == torch::kLong) {
|
248
|
-
|
249
|
-
a.push(view[i].item().to<int64_t>());
|
250
|
-
}
|
267
|
+
return flat_data<int64_t>(tensor);
|
251
268
|
} else if (dtype == torch::kFloat) {
|
252
|
-
|
253
|
-
a.push(view[i].item().to<float>());
|
254
|
-
}
|
269
|
+
return flat_data<float>(tensor);
|
255
270
|
} else if (dtype == torch::kDouble) {
|
256
|
-
|
257
|
-
a.push(view[i].item().to<double>());
|
258
|
-
}
|
271
|
+
return flat_data<double>(tensor);
|
259
272
|
} else if (dtype == torch::kBool) {
|
260
|
-
|
261
|
-
|
262
|
-
|
273
|
+
return flat_data<bool>(tensor);
|
274
|
+
} else if (dtype == torch::kComplexFloat) {
|
275
|
+
return flat_data<c10::complex<float>>(tensor);
|
276
|
+
} else if (dtype == torch::kComplexDouble) {
|
277
|
+
return flat_data<c10::complex<double>>(tensor);
|
263
278
|
} else {
|
264
279
|
throw std::runtime_error("Unsupported type");
|
265
280
|
}
|
266
|
-
return a;
|
267
281
|
})
|
268
282
|
.define_method(
|
269
283
|
"_to",
|
270
|
-
|
284
|
+
[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
271
285
|
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
272
286
|
});
|
273
287
|
|
274
|
-
|
288
|
+
rb_cTensorOptions
|
275
289
|
.add_handler<torch::Error>(handle_error)
|
276
|
-
.define_constructor(Rice::Constructor<torch::TensorOptions>())
|
277
290
|
.define_method(
|
278
291
|
"dtype",
|
279
|
-
|
292
|
+
[](torch::TensorOptions& self, int dtype) {
|
280
293
|
return self.dtype((torch::ScalarType) dtype);
|
281
294
|
})
|
282
295
|
.define_method(
|
283
296
|
"layout",
|
284
|
-
|
297
|
+
[](torch::TensorOptions& self, const std::string& layout) {
|
285
298
|
torch::Layout l;
|
286
299
|
if (layout == "strided") {
|
287
300
|
l = torch::kStrided;
|
@@ -295,13 +308,13 @@ void init_tensor(Rice::Module& m) {
|
|
295
308
|
})
|
296
309
|
.define_method(
|
297
310
|
"device",
|
298
|
-
|
311
|
+
[](torch::TensorOptions& self, const std::string& device) {
|
299
312
|
torch::Device d(device);
|
300
313
|
return self.device(d);
|
301
314
|
})
|
302
315
|
.define_method(
|
303
316
|
"requires_grad",
|
304
|
-
|
317
|
+
[](torch::TensorOptions& self, bool requires_grad) {
|
305
318
|
return self.requires_grad(requires_grad);
|
306
319
|
});
|
307
320
|
}
|