torch-rb 0.6.0 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/codegen/generate_functions.rb +2 -2
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +14 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +16 -11
- data/ext/torch/templates.h +110 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +5 -0
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/version.rb +1 -1
- metadata +4 -4
data/ext/torch/ruby_arg_parser.cpp
CHANGED
@@ -137,7 +137,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
     return true;
   }
   if (THPVariable_Check(obj)) {
-    auto var = from_ruby<torch::Tensor>(obj);
+    auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
     return !var.requires_grad() && var.dim() == 0;
   }
   return false;
@@ -147,7 +147,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
     return true;
   }
   if (THPVariable_Check(obj)) {
-    auto var = from_ruby<torch::Tensor>(obj);
+    auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
     return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
   }
   return false;
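Both hunks above show the change that drives most of this release: the gem moves from Rice 2 to Rice 4 (the new single-header <rice/rice.hpp>), whose conversion API replaces the free function template from_ruby<T>(obj) with a converter object, Rice::detail::From_Ruby<T>().convert(obj). A minimal stand-alone sketch of the two call shapes (simplified stand-ins, not the real Rice API; Value plays the role of Ruby's VALUE):

#include <cstdint>
#include <iostream>
#include <variant>

using Value = std::variant<std::int64_t, double>;

// Rice 2 shape: a free function template, specialized per target type.
template<typename T>
T from_ruby(const Value& v) { return std::get<T>(v); }

// Rice 4 shape: a converter class template; callers instantiate it and
// invoke convert(), which is what the rewritten lines above do.
template<typename T>
struct From_Ruby {
  T convert(const Value& v) { return std::get<T>(v); }
};

int main() {
  Value v = std::int64_t{42};
  std::cout << from_ruby<std::int64_t>(v) << "\n";           // old call shape
  std::cout << From_Ruby<std::int64_t>().convert(v) << "\n"; // new call shape
}

One practical consequence: function templates cannot be partially specialized, but class templates can, which is what lets the templates.h hunks below collapse several per-type conversions into a single generic one.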
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -5,7 +5,7 @@
 #include <sstream>

 #include <torch/torch.h>
-#include <rice/
+#include <rice/rice.hpp>

 #include "templates.h"
 #include "utils.h"
@@ -121,7 +121,7 @@ struct RubyArgs {
 };

 inline at::Tensor RubyArgs::tensor(int i) {
-  return from_ruby<torch::Tensor>(args[i]);
+  return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
 }

 inline OptionalTensor RubyArgs::optionalTensor(int i) {
@@ -131,12 +131,12 @@ inline OptionalTensor RubyArgs::optionalTensor(int i) {

 inline at::Scalar RubyArgs::scalar(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_scalar;
-  return from_ruby<torch::Scalar>(args[i]);
+  return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
 }

 inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
   if (NIL_P(args[i])) return std::vector<at::Tensor>();
-  return from_ruby<std::vector<Tensor>>(args[i]);
+  return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
 }

 template<int N>
@@ -151,7 +151,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
   }
   for (int idx = 0; idx < size; idx++) {
     VALUE obj = rb_ary_entry(arg, idx);
-    res[idx] = from_ruby<Tensor>(obj);
+    res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
   }
   return res;
 }
@@ -170,7 +170,7 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
   for (idx = 0; idx < size; idx++) {
     VALUE obj = rb_ary_entry(arg, idx);
     if (FIXNUM_P(obj)) {
-      res[idx] = from_ruby<int64_t>(obj);
+      res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
     } else {
       rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
         signature.name.c_str(), signature.params[i].name.c_str(),
@@ -210,8 +210,13 @@ inline ScalarType RubyArgs::scalartype(int i) {
     {ID2SYM(rb_intern("double")), ScalarType::Double},
     {ID2SYM(rb_intern("float64")), ScalarType::Double},
     {ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
+    {ID2SYM(rb_intern("complex32")), ScalarType::ComplexHalf},
     {ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
+    {ID2SYM(rb_intern("cfloat")), ScalarType::ComplexFloat},
+    {ID2SYM(rb_intern("complex64")), ScalarType::ComplexFloat},
     {ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
+    {ID2SYM(rb_intern("cdouble")), ScalarType::ComplexDouble},
+    {ID2SYM(rb_intern("complex128")), ScalarType::ComplexDouble},
     {ID2SYM(rb_intern("bool")), ScalarType::Bool},
     {ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
     {ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
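The five added entries give each complex ScalarType the spellings PyTorch and NumPy users expect (:complex32, :cfloat/:complex64, :cdouble/:complex128) alongside the existing :complex_half/:complex_float/:complex_double symbols. A rough sketch of the idea (assumption: plain string keys stand in for the interned Ruby symbols):

#include <iostream>
#include <string>
#include <unordered_map>

enum class ScalarType { ComplexHalf, ComplexFloat, ComplexDouble };

int main() {
  // Several spellings resolve to one ScalarType, so the older
  // :complex_float style keeps working alongside the PyTorch names.
  const std::unordered_map<std::string, ScalarType> dtypes = {
    {"complex_half", ScalarType::ComplexHalf},   {"complex32", ScalarType::ComplexHalf},
    {"complex_float", ScalarType::ComplexFloat}, {"cfloat", ScalarType::ComplexFloat},
    {"complex64", ScalarType::ComplexFloat},     {"complex_double", ScalarType::ComplexDouble},
    {"cdouble", ScalarType::ComplexDouble},      {"complex128", ScalarType::ComplexDouble},
  };
  std::cout << std::boolalpha
            << (dtypes.at("complex64") == dtypes.at("cfloat")) << "\n"    // true
            << (dtypes.at("complex128") == dtypes.at("cdouble")) << "\n"; // true
}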
@@ -260,7 +265,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
   for (idx = 0; idx < size; idx++) {
     VALUE obj = rb_ary_entry(arg, idx);
     if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
-      res[idx] = from_ruby<double>(obj);
+      res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
     } else {
       rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
         signature.name.c_str(), signature.params[i].name.c_str(),
@@ -303,22 +308,22 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
 }

 inline std::string RubyArgs::string(int i) {
-  return from_ruby<std::string>(args[i]);
+  return Rice::detail::From_Ruby<std::string>().convert(args[i]);
 }

 inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
   if (!args[i]) return c10::nullopt;
-  return from_ruby<std::string>(args[i]);
+  return Rice::detail::From_Ruby<std::string>().convert(args[i]);
 }

 inline int64_t RubyArgs::toInt64(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_int;
-  return from_ruby<int64_t>(args[i]);
+  return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
 }

 inline double RubyArgs::toDouble(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_double;
-  return from_ruby<double>(args[i]);
+  return Rice::detail::From_Ruby<double>().convert(args[i]);
 }

 inline bool RubyArgs::toBool(int i) {
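These accessors all share the same two-step shape, which the Rice 4 migration leaves intact: a nil argument falls back to the default recorded in the parsed signature, and anything else goes through a From_Ruby converter. A simplified sketch (assumption: std::optional stands in for a possibly-nil VALUE):

#include <iostream>
#include <optional>

struct Param { long default_int = 0; };  // stand-in for FunctionParameter

// Mirrors RubyArgs::toInt64: NIL_P(args[i]) -> signature default,
// otherwise run the conversion.
long to_int64(std::optional<long> arg, const Param& p) {
  if (!arg) return p.default_int;
  return *arg;
}

int main() {
  Param p{5};
  std::cout << to_int64(std::nullopt, p) << "\n"; // 5 (default used)
  std::cout << to_int64(7, p) << "\n";            // 7 (converted value)
}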
data/ext/torch/templates.h
CHANGED
@@ -4,8 +4,7 @@
 #undef isfinite
 #endif

-#include <rice/
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>

 using namespace Rice;

@@ -23,6 +22,9 @@ using torch::ArrayRef;
 using torch::TensorList;
 using torch::Storage;

+using torch::nn::init::FanModeType;
+using torch::nn::init::NonlinearityType;
+
 #define HANDLE_TH_ERRORS \
 try {

@@ -38,37 +40,42 @@ using torch::Storage;
 #define RETURN_NIL \
   return Qnil;

-
-
-
-{
-
-
-
-
-
-
-
+class OptionalTensor {
+  torch::Tensor value;
+  public:
+    OptionalTensor(Object o) {
+      if (o.is_nil()) {
+        value = {};
+      } else {
+        value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
+      }
+    }
+    OptionalTensor(torch::Tensor o) {
+      value = o;
+    }
+    operator torch::Tensor() const {
+      return value;
+    }
+};

-template<>
-inline
-std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
+namespace Rice::detail
 {
-
-
-
-
-
-
-}
+  template<>
+  struct Type<FanModeType>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };

-class FanModeType {
-  std::string s;
+  template<>
+  class From_Ruby<FanModeType>
+  {
   public:
-    FanModeType(Object o) {
-      s = String(o).str();
-    }
-    operator torch::nn::init::FanModeType() {
+    FanModeType convert(VALUE x)
+    {
+      auto s = String(x).str();
       if (s == "fan_in") {
         return torch::kFanIn;
       } else if (s == "fan_out") {
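The rewritten OptionalTensor above now converts eagerly in the constructor (nil becomes a default-constructed, undefined tensor) and unwraps implicitly, so generated bindings can accept "Tensor or nil" arguments without special cases. A stripped-down sketch of the wrapper (assumption: no Rice/ATen; Tensor is a stand-in):

#include <iostream>

struct Tensor { bool defined = false; };  // stand-in for torch::Tensor

class OptionalTensor {
  Tensor value;
public:
  OptionalTensor() : value{} {}              // nil -> undefined tensor
  OptionalTensor(Tensor t) : value{t} {}     // tensor passes through
  operator Tensor() const { return value; }  // implicit unwrap at call sites
};

int main() {
  Tensor a = OptionalTensor{};               // from nil
  Tensor b = OptionalTensor{Tensor{true}};   // from a real tensor
  std::cout << a.defined << " " << b.defined << "\n"; // 0 1
}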
@@ -77,22 +84,24 @@ class FanModeType {
       throw std::runtime_error("Unsupported nonlinearity type: " + s);
     }
   }
-};
-
-template<>
-inline
-FanModeType from_ruby<FanModeType>(Object x)
-{
-  return FanModeType(x);
-}
+  };
+
+  template<>
+  struct Type<NonlinearityType>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };

-class NonlinearityType {
-  std::string s;
+  template<>
+  class From_Ruby<NonlinearityType>
+  {
   public:
-    NonlinearityType(Object o) {
-      s = String(o).str();
-    }
-    operator torch::nn::init::NonlinearityType() {
+    NonlinearityType convert(VALUE x)
+    {
+      auto s = String(x).str();
       if (s == "linear") {
         return torch::kLinear;
       } else if (s == "conv1d") {
@@ -119,102 +128,70 @@ class NonlinearityType {
       throw std::runtime_error("Unsupported nonlinearity type: " + s);
     }
   }
-};
+  };
+
+  template<>
+  struct Type<OptionalTensor>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };

-template<>
-inline
-NonlinearityType from_ruby<NonlinearityType>(Object x)
-{
-  return NonlinearityType(x);
-}
+  template<>
+  class From_Ruby<OptionalTensor>
+  {
+  public:
+    OptionalTensor convert(VALUE x)
+    {
+      return OptionalTensor(x);
+    }
+  };
+
+  template<>
+  struct Type<Scalar>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };

-class OptionalTensor {
-  torch::Tensor value;
+  template<>
+  class From_Ruby<Scalar>
+  {
   public:
-    OptionalTensor(Object o) {
-      if (o.is_nil()) {
-        value = {};
+    Scalar convert(VALUE x)
+    {
+      if (FIXNUM_P(x)) {
+        return torch::Scalar(From_Ruby<int64_t>().convert(x));
       } else {
-        value = from_ruby<torch::Tensor>(o);
+        return torch::Scalar(From_Ruby<double>().convert(x));
       }
     }
-    OptionalTensor(torch::Tensor o) {
-      value = o;
-    }
-    operator torch::Tensor() const {
-      return value;
+  };
+
+  template<typename T>
+  struct Type<torch::optional<T>>
+  {
+    static bool verify()
+    {
+      return true;
     }
-};
-
-template<>
-inline
-Scalar from_ruby<Scalar>(Object x)
-{
-  if (x.rb_type() == T_FIXNUM) {
-    return torch::Scalar(from_ruby<int64_t>(x));
-  } else {
-    return torch::Scalar(from_ruby<double>(x));
-  }
-}
-
-template<>
-inline
-OptionalTensor from_ruby<OptionalTensor>(Object x)
-{
-  return OptionalTensor(x);
-}
-
-template<>
-inline
-torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<int64_t>{from_ruby<int64_t>(x)};
-  }
-}
+  };

-template
-
-
-
-
-
-
-
-
-}
-
-template<>
-inline
-torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<bool>{from_ruby<bool>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<Scalar>{from_ruby<Scalar>(x)};
-  }
+template<typename T>
+class From_Ruby<torch::optional<T>>
+{
+public:
+  torch::optional<T> convert(VALUE x)
+  {
+    if (NIL_P(x)) {
+      return torch::nullopt;
+    } else {
+      return torch::optional<T>{From_Ruby<T>().convert(x)};
+    }
+  }
+};
 }
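Worth noting in the hunk above: the hand-written from_ruby<torch::optional<...>> specializations (ScalarType, int64_t, bool, Scalar) collapse into one From_Ruby<torch::optional<T>> class template. That is only possible because Rice 4's converters are class templates, which support partial specialization where free function templates do not. A compilable sketch of the pattern (assumption: std::optional stands in; an empty optional models a nil VALUE):

#include <iostream>
#include <optional>

// Stand-in leaf converter: in the real code this is From_Ruby<T> for a
// concrete T such as int64_t or Scalar.
template<typename T>
struct From_Ruby {
  T convert(const T& x) { return x; }
};

// One partial specialization replaces every per-type optional overload:
// nil -> nullopt, anything else -> convert the wrapped value.
template<typename T>
struct From_Ruby<std::optional<T>> {
  std::optional<T> convert(const std::optional<T>& x) {
    if (!x) return std::nullopt;
    return std::optional<T>{From_Ruby<T>().convert(*x)};
  }
};

int main() {
  From_Ruby<std::optional<int>> conv;
  std::cout << conv.convert(std::optional<int>{7}).value() << "\n"; // 7
  std::cout << conv.convert(std::nullopt).has_value() << "\n";      // 0
}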
data/ext/torch/tensor.cpp
CHANGED
@@ -1,7 +1,6 @@
 #include <torch/torch.h>

-#include <rice/
-#include <rice/Module.hpp>
+#include <rice/rice.hpp>

 #include "tensor_functions.h"
 #include "ruby_arg_parser.h"
@@ -11,6 +10,39 @@
 using namespace Rice;
 using torch::indexing::TensorIndex;

+namespace Rice::detail
+{
+  template<typename T>
+  struct Type<c10::complex<T>>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
+
+  template<typename T>
+  class To_Ruby<c10::complex<T>>
+  {
+  public:
+    VALUE convert(c10::complex<T> const& x)
+    {
+      return rb_dbl_complex_new(x.real(), x.imag());
+    }
+  };
+}
+
+template<typename T>
+Array flat_data(Tensor& tensor) {
+  Tensor view = tensor.reshape({tensor.numel()});
+
+  Array a;
+  for (int i = 0; i < tensor.numel(); i++) {
+    a.push(view[i].item().to<T>());
+  }
+  return a;
+}
+
 Class rb_cTensor;

 std::vector<TensorIndex> index_vector(Array a) {
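Two additions above do the heavy lifting for complex dtype support: a To_Ruby specialization that turns a c10::complex<T> element into a Ruby Complex via rb_dbl_complex_new, and a flat_data<T> template that replaces the copy-pasted per-dtype loops in _flat_data further down. A minimal sketch of the same shapes using std::complex and std::vector as stand-ins:

#include <complex>
#include <iostream>
#include <vector>

struct RubyComplex { double re, im; };  // stand-in for a Ruby Complex VALUE

// Mirrors To_Ruby<c10::complex<T>>::convert, which calls
// rb_dbl_complex_new(x.real(), x.imag()) in the real extension.
template<typename T>
RubyComplex to_ruby(const std::complex<T>& x) {
  return {static_cast<double>(x.real()), static_cast<double>(x.imag())};
}

// Mirrors flat_data<T>: one loop body shared by every dtype branch.
template<typename T>
std::vector<T> flat_data(const std::vector<T>& view) {
  std::vector<T> out;
  for (const T& x : view) out.push_back(x);
  return out;
}

int main() {
  auto c = to_ruby(std::complex<float>{1.0f, -2.0f});
  std::cout << c.re << (c.im < 0 ? "" : "+") << c.im << "i\n"; // 1-2i
  auto v = flat_data(std::vector<int>{1, 2, 3});
  std::cout << v.size() << "\n"; // 3
}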
@@ -23,19 +55,19 @@ std::vector<TensorIndex> index_vector(Array a) {
     obj = a[i];

     if (obj.is_instance_of(rb_cInteger)) {
-      indices.push_back(from_ruby<int64_t>(obj));
+      indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
     } else if (obj.is_instance_of(rb_cRange)) {
       torch::optional<int64_t> start_index = torch::nullopt;
       torch::optional<int64_t> stop_index = torch::nullopt;

       Object begin = obj.call("begin");
       if (!begin.is_nil()) {
-        start_index = from_ruby<int64_t>(begin);
+        start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
       }

       Object end = obj.call("end");
       if (!end.is_nil()) {
-        stop_index = from_ruby<int64_t>(end);
+        stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
       }

       Object exclude_end = obj.call("exclude_end?");
@@ -49,11 +81,11 @@ std::vector<TensorIndex> index_vector(Array a) {

       indices.push_back(torch::indexing::Slice(start_index, stop_index));
     } else if (obj.is_instance_of(rb_cTensor)) {
-      indices.push_back(from_ruby<Tensor>(obj));
+      indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
     } else if (obj.is_nil()) {
       indices.push_back(torch::indexing::None);
     } else if (obj == True || obj == False) {
-      indices.push_back(from_ruby<bool>(obj));
+      indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
     } else {
       throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
     }
@@ -68,7 +100,7 @@ std::vector<TensorIndex> index_vector(Array a) {
 static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
 {
   HANDLE_TH_ERRORS
-  Tensor& self = from_ruby<Tensor&>(self_);
+  Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
   static RubyArgParser parser({
     "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
   });
@@ -84,8 +116,8 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
   END_HANDLE_TH_ERRORS
 }

-void init_tensor(Rice::Module& m) {
-  rb_cTensor =
+void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
+  rb_cTensor = c;
   rb_cTensor.add_handler<torch::Error>(handle_error);
   add_tensor_functions(rb_cTensor);
   THPVariableClass = rb_cTensor.value();
|
|
102
134
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
103
135
|
.define_method(
|
104
136
|
"_size",
|
105
|
-
|
137
|
+
[](Tensor& self, int64_t dim) {
|
106
138
|
return self.size(dim);
|
107
139
|
})
|
108
140
|
.define_method(
|
109
141
|
"_stride",
|
110
|
-
|
142
|
+
[](Tensor& self, int64_t dim) {
|
111
143
|
return self.stride(dim);
|
112
144
|
})
|
113
145
|
// in C++ for performance
|
114
146
|
.define_method(
|
115
147
|
"shape",
|
116
|
-
|
148
|
+
[](Tensor& self) {
|
117
149
|
Array a;
|
118
150
|
for (auto &size : self.sizes()) {
|
119
151
|
a.push(size);
|
@@ -122,7 +154,7 @@ void init_tensor(Rice::Module& m) {
|
|
122
154
|
})
|
123
155
|
.define_method(
|
124
156
|
"_strides",
|
125
|
-
|
157
|
+
[](Tensor& self) {
|
126
158
|
Array a;
|
127
159
|
for (auto &stride : self.strides()) {
|
128
160
|
a.push(stride);
|
@@ -131,65 +163,65 @@ void init_tensor(Rice::Module& m) {
|
|
131
163
|
})
|
132
164
|
.define_method(
|
133
165
|
"_index",
|
134
|
-
|
166
|
+
[](Tensor& self, Array indices) {
|
135
167
|
auto vec = index_vector(indices);
|
136
168
|
return self.index(vec);
|
137
169
|
})
|
138
170
|
.define_method(
|
139
171
|
"_index_put_custom",
|
140
|
-
|
172
|
+
[](Tensor& self, Array indices, torch::Tensor& value) {
|
141
173
|
auto vec = index_vector(indices);
|
142
174
|
return self.index_put_(vec, value);
|
143
175
|
})
|
144
176
|
.define_method(
|
145
177
|
"contiguous?",
|
146
|
-
|
178
|
+
[](Tensor& self) {
|
147
179
|
return self.is_contiguous();
|
148
180
|
})
|
149
181
|
.define_method(
|
150
182
|
"_requires_grad!",
|
151
|
-
|
183
|
+
[](Tensor& self, bool requires_grad) {
|
152
184
|
return self.set_requires_grad(requires_grad);
|
153
185
|
})
|
154
186
|
.define_method(
|
155
187
|
"grad",
|
156
|
-
|
188
|
+
[](Tensor& self) {
|
157
189
|
auto grad = self.grad();
|
158
|
-
return grad.defined() ?
|
190
|
+
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
159
191
|
})
|
160
192
|
.define_method(
|
161
193
|
"grad=",
|
162
|
-
|
194
|
+
[](Tensor& self, torch::Tensor& grad) {
|
163
195
|
self.mutable_grad() = grad;
|
164
196
|
})
|
165
197
|
.define_method(
|
166
198
|
"_dtype",
|
167
|
-
|
199
|
+
[](Tensor& self) {
|
168
200
|
return (int) at::typeMetaToScalarType(self.dtype());
|
169
201
|
})
|
170
202
|
.define_method(
|
171
203
|
"_type",
|
172
|
-
|
204
|
+
[](Tensor& self, int dtype) {
|
173
205
|
return self.toType((torch::ScalarType) dtype);
|
174
206
|
})
|
175
207
|
.define_method(
|
176
208
|
"_layout",
|
177
|
-
|
209
|
+
[](Tensor& self) {
|
178
210
|
std::stringstream s;
|
179
211
|
s << self.layout();
|
180
212
|
return s.str();
|
181
213
|
})
|
182
214
|
.define_method(
|
183
215
|
"device",
|
184
|
-
|
216
|
+
[](Tensor& self) {
|
185
217
|
std::stringstream s;
|
186
218
|
s << self.device();
|
187
219
|
return s.str();
|
188
220
|
})
|
189
221
|
.define_method(
|
190
222
|
"_data_str",
|
191
|
-
|
192
|
-
|
223
|
+
[](Tensor& self) {
|
224
|
+
auto tensor = self;
|
193
225
|
|
194
226
|
// move to CPU to get data
|
195
227
|
if (tensor.device().type() != torch::kCPU) {
|
@@ -207,14 +239,14 @@ void init_tensor(Rice::Module& m) {
|
|
207
239
|
// for TorchVision
|
208
240
|
.define_method(
|
209
241
|
"_data_ptr",
|
210
|
-
|
242
|
+
[](Tensor& self) {
|
211
243
|
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
212
244
|
})
|
213
245
|
// TODO figure out a better way to do this
|
214
246
|
.define_method(
|
215
247
|
"_flat_data",
|
216
|
-
|
217
|
-
|
248
|
+
[](Tensor& self) {
|
249
|
+
auto tensor = self;
|
218
250
|
|
219
251
|
// move to CPU to get data
|
220
252
|
if (tensor.device().type() != torch::kCPU) {
|
@@ -222,66 +254,47 @@ void init_tensor(Rice::Module& m) {
|
|
222
254
|
tensor = tensor.to(device);
|
223
255
|
}
|
224
256
|
|
225
|
-
Array a;
|
226
257
|
auto dtype = tensor.dtype();
|
227
|
-
|
228
|
-
Tensor view = tensor.reshape({tensor.numel()});
|
229
|
-
|
230
|
-
// TODO DRY if someone knows C++
|
231
258
|
if (dtype == torch::kByte) {
|
232
|
-
|
233
|
-
a.push(view[i].item().to<uint8_t>());
|
234
|
-
}
|
259
|
+
return flat_data<uint8_t>(tensor);
|
235
260
|
} else if (dtype == torch::kChar) {
|
236
|
-
|
237
|
-
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
238
|
-
}
|
261
|
+
return flat_data<int8_t>(tensor);
|
239
262
|
} else if (dtype == torch::kShort) {
|
240
|
-
|
241
|
-
a.push(view[i].item().to<int16_t>());
|
242
|
-
}
|
263
|
+
return flat_data<int16_t>(tensor);
|
243
264
|
} else if (dtype == torch::kInt) {
|
244
|
-
|
245
|
-
a.push(view[i].item().to<int32_t>());
|
246
|
-
}
|
265
|
+
return flat_data<int32_t>(tensor);
|
247
266
|
} else if (dtype == torch::kLong) {
|
248
|
-
|
249
|
-
a.push(view[i].item().to<int64_t>());
|
250
|
-
}
|
267
|
+
return flat_data<int64_t>(tensor);
|
251
268
|
} else if (dtype == torch::kFloat) {
|
252
|
-
|
253
|
-
a.push(view[i].item().to<float>());
|
254
|
-
}
|
269
|
+
return flat_data<float>(tensor);
|
255
270
|
} else if (dtype == torch::kDouble) {
|
256
|
-
|
257
|
-
a.push(view[i].item().to<double>());
|
258
|
-
}
|
271
|
+
return flat_data<double>(tensor);
|
259
272
|
} else if (dtype == torch::kBool) {
|
260
|
-
|
261
|
-
|
262
|
-
|
273
|
+
return flat_data<bool>(tensor);
|
274
|
+
} else if (dtype == torch::kComplexFloat) {
|
275
|
+
return flat_data<c10::complex<float>>(tensor);
|
276
|
+
} else if (dtype == torch::kComplexDouble) {
|
277
|
+
return flat_data<c10::complex<double>>(tensor);
|
263
278
|
} else {
|
264
279
|
throw std::runtime_error("Unsupported type");
|
265
280
|
}
|
266
|
-
return a;
|
267
281
|
})
|
268
282
|
.define_method(
|
269
283
|
"_to",
|
270
|
-
|
284
|
+
[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
271
285
|
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
272
286
|
});
|
273
287
|
|
274
|
-
|
288
|
+
rb_cTensorOptions
|
275
289
|
.add_handler<torch::Error>(handle_error)
|
276
|
-
.define_constructor(Rice::Constructor<torch::TensorOptions>())
|
277
290
|
.define_method(
|
278
291
|
"dtype",
|
279
|
-
|
292
|
+
[](torch::TensorOptions& self, int dtype) {
|
280
293
|
return self.dtype((torch::ScalarType) dtype);
|
281
294
|
})
|
282
295
|
.define_method(
|
283
296
|
"layout",
|
284
|
-
|
297
|
+
[](torch::TensorOptions& self, const std::string& layout) {
|
285
298
|
torch::Layout l;
|
286
299
|
if (layout == "strided") {
|
287
300
|
l = torch::kStrided;
|
@@ -295,13 +308,13 @@ void init_tensor(Rice::Module& m) {
|
|
295
308
|
})
|
296
309
|
.define_method(
|
297
310
|
"device",
|
298
|
-
|
311
|
+
[](torch::TensorOptions& self, const std::string& device) {
|
299
312
|
torch::Device d(device);
|
300
313
|
return self.device(d);
|
301
314
|
})
|
302
315
|
.define_method(
|
303
316
|
"requires_grad",
|
304
|
-
|
317
|
+
[](torch::TensorOptions& self, bool requires_grad) {
|
305
318
|
return self.requires_grad(requires_grad);
|
306
319
|
});
|
307
320
|
}
|