torch-rb 0.11.2 → 0.12.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +5 -4
- data/codegen/function.rb +8 -0
- data/codegen/generate_functions.rb +30 -5
- data/codegen/native_functions.yaml +2067 -652
- data/ext/torch/ruby_arg_parser.cpp +88 -3
- data/ext/torch/ruby_arg_parser.h +39 -3
- data/ext/torch/utils.h +5 -1
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
@@ -11,6 +11,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
|
11
11
|
{"double", ParameterType::DOUBLE},
|
12
12
|
{"complex", ParameterType::COMPLEX},
|
13
13
|
{"TensorList", ParameterType::TENSOR_LIST},
|
14
|
+
{"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
|
14
15
|
{"IntArrayRef", ParameterType::INT_LIST},
|
15
16
|
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
|
16
17
|
{"Generator", ParameterType::GENERATOR},
|
@@ -22,9 +23,14 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
|
22
23
|
{"MemoryFormat", ParameterType::MEMORY_FORMAT},
|
23
24
|
{"QScheme", ParameterType::QSCHEME},
|
24
25
|
{"Device", ParameterType::DEVICE},
|
26
|
+
{"Stream", ParameterType::STREAM},
|
25
27
|
{"std::string", ParameterType::STRING},
|
28
|
+
{"c10::string_view", ParameterType::STRING},
|
29
|
+
{"SymInt", ParameterType::SYM_INT},
|
26
30
|
{"Dimname", ParameterType::DIMNAME},
|
31
|
+
{"SymIntArrayRef", ParameterType::SYM_INT_LIST},
|
27
32
|
{"DimnameList", ParameterType::DIMNAME_LIST},
|
33
|
+
{"ScalarList", ParameterType::SCALAR_LIST}
|
28
34
|
};
|
29
35
|
|
30
36
|
static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names = {
|
@@ -116,6 +122,72 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
|
|
116
122
|
return true;
|
117
123
|
}
|
118
124
|
|
125
|
+
// Reports whether `obj` can be bound to an IntArrayRef parameter.
//
// Arrays are judged by their first element only: an empty array is accepted,
// otherwise the first entry must pass THPUtils_checkIndex. A non-array is
// accepted only when the parameter declares a size (broadcast_size > 0, e.g.
// IntArrayRef[2]) and `obj` passes THPUtils_checkLong.
//
// Not implemented here: rejecting SymInt nodes among the later elements, and
// accepting scalar tensors as ints while the JIT tracer is active.
static bool is_int_list(VALUE obj, int broadcast_size) {
  if (!RB_TYPE_P(obj, T_ARRAY)) {
    // if a size is specified we also allow passing a single int
    return broadcast_size > 0 && THPUtils_checkLong(obj);
  }

  if (RARRAY_LEN(obj) == 0) {
    return true;
  }

  // NB: later elements are deliberately NOT checked for being ints, as that
  // is BC-breaking for FX
  VALUE first = rb_ary_entry(obj, 0);
  return THPUtils_checkIndex(first);
}
|
164
|
+
|
165
|
+
// Reports whether `obj` is acceptable for a SymInt parameter. Only the
// THPUtils_checkIndex test is applied; no separate symbolic-int case exists.
static bool is_int_or_symint(VALUE obj) {
  const bool index_like = THPUtils_checkIndex(obj);
  return index_like;
}
|
168
|
+
|
169
|
+
// Reports whether `obj` can be bound to a SymIntArrayRef parameter.
//
// An array is accepted when it is empty or when its first element satisfies
// is_int_or_symint; the remaining elements are not inspected. A non-array is
// accepted only when the parameter declares a size (broadcast_size > 0, e.g.
// IntArrayRef[2]) and `obj` passes THPUtils_checkLong.
//
// Not implemented here: accepting scalar tensors as ints while the JIT tracer
// is active.
static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
  if (!RB_TYPE_P(obj, T_ARRAY)) {
    // a sized parameter also allows passing a single int
    return broadcast_size > 0 && THPUtils_checkLong(obj);
  }

  if (RARRAY_LEN(obj) == 0) {
    return true;
  }

  VALUE head = rb_ary_entry(obj, 0);
  return is_int_or_symint(head);
}
|
190
|
+
|
119
191
|
// argnum is needed for raising the TypeError, it's used in the error message.
|
120
192
|
auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
121
193
|
{
|
@@ -182,6 +254,8 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
182
254
|
case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
|
183
255
|
case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
|
184
256
|
case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
|
257
|
+
case ParameterType::SYM_INT: return is_int_or_symint(obj);
|
258
|
+
case ParameterType::SYM_INT_LIST: return is_int_or_symint_list(obj, size);
|
185
259
|
default: throw std::runtime_error("unknown parameter type");
|
186
260
|
}
|
187
261
|
}
|
@@ -191,6 +265,7 @@ std::string FunctionParameter::type_name() const {
|
|
191
265
|
case ParameterType::TENSOR: return "Tensor";
|
192
266
|
case ParameterType::SCALAR: return "Number";
|
193
267
|
case ParameterType::INT64: return "int";
|
268
|
+
case ParameterType::SYM_INT: return "SymInt";
|
194
269
|
case ParameterType::DOUBLE: return "float";
|
195
270
|
case ParameterType::COMPLEX: return "complex";
|
196
271
|
case ParameterType::TENSOR_LIST: return "array of Tensors";
|
@@ -208,6 +283,8 @@ std::string FunctionParameter::type_name() const {
|
|
208
283
|
case ParameterType::STRING: return "str";
|
209
284
|
case ParameterType::DIMNAME: return "name";
|
210
285
|
case ParameterType::DIMNAME_LIST: return "array of names";
|
286
|
+
case ParameterType::SCALAR_LIST: return "array of Scalars";
|
287
|
+
case ParameterType::SYM_INT_LIST: return "array of SymInts";
|
211
288
|
default: throw std::runtime_error("unknown parameter type");
|
212
289
|
}
|
213
290
|
}
|
@@ -558,8 +635,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
|
|
558
635
|
|
559
636
|
// if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
|
560
637
|
// allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
|
561
|
-
|
638
|
+
int int_list_overload = false;
|
639
|
+
if (max_pos_args == 1 &&
|
640
|
+
(params[0].type_ == ParameterType::INT_LIST ||
|
641
|
+
params[0].type_ == ParameterType::SYM_INT_LIST)) {
|
562
642
|
allow_varargs_intlist = true;
|
643
|
+
if (params[0].type_ == ParameterType::INT_LIST) {
|
644
|
+
int_list_overload = true;
|
645
|
+
}
|
563
646
|
}
|
564
647
|
|
565
648
|
if (nargs > max_pos_args && !allow_varargs_intlist) {
|
@@ -614,8 +697,10 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
|
|
614
697
|
// XXX: the Variable check is necessary because sizes become tensors when
|
615
698
|
// tracer is enabled. This behavior easily leads to ambiguities, and we
|
616
699
|
// should avoid having complex signatures that make use of it...
|
617
|
-
} else if (
|
618
|
-
|
700
|
+
} else if (
|
701
|
+
allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
|
702
|
+
((int_list_overload ? is_int_list(args, param.size)
|
703
|
+
: is_int_or_symint_list(args, param.size)))) {
|
619
704
|
// take all positional arguments as this parameter
|
620
705
|
// e.g. permute(1, 2, 3) -> permute((1, 2, 3))
|
621
706
|
dst[i++] = args;
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -11,9 +11,9 @@
|
|
11
11
|
#include "utils.h"
|
12
12
|
|
13
13
|
enum class ParameterType {
|
14
|
-
TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
|
15
|
-
BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STRING,
|
16
|
-
DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST
|
14
|
+
TENSOR, SCALAR, INT64, SYM_INT, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
|
15
|
+
BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING,
|
16
|
+
DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST, SYM_INT_LIST
|
17
17
|
};
|
18
18
|
|
19
19
|
struct FunctionParameter {
|
@@ -84,7 +84,9 @@ struct RubyArgs {
|
|
84
84
|
template<int N>
|
85
85
|
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
86
86
|
inline std::vector<int64_t> intlist(int i);
|
87
|
+
inline std::vector<c10::SymInt> symintlist(int i);
|
87
88
|
inline c10::OptionalArray<int64_t> intlistOptional(int i);
|
89
|
+
inline c10::OptionalArray<c10::SymInt> symintlistOptional(int i);
|
88
90
|
inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
|
89
91
|
inline c10::optional<at::Generator> generator(int i);
|
90
92
|
inline at::Storage storage(int i);
|
@@ -93,6 +95,7 @@ struct RubyArgs {
|
|
93
95
|
inline c10::optional<at::ScalarType> scalartypeOptional(int i);
|
94
96
|
inline c10::optional<at::Scalar> scalarOptional(int i);
|
95
97
|
inline c10::optional<int64_t> toInt64Optional(int i);
|
98
|
+
inline c10::optional<c10::SymInt> toSymIntOptional(int i);
|
96
99
|
inline c10::optional<bool> toBoolOptional(int i);
|
97
100
|
inline c10::optional<double> toDoubleOptional(int i);
|
98
101
|
inline c10::OptionalArray<double> doublelistOptional(int i);
|
@@ -116,6 +119,7 @@ struct RubyArgs {
|
|
116
119
|
inline c10::optional<c10::string_view> stringViewOptional(int i);
|
117
120
|
// inline PyObject* pyobject(int i);
|
118
121
|
inline int64_t toInt64(int i);
|
122
|
+
inline c10::SymInt toSymInt(int i);
|
119
123
|
// inline int64_t toInt64WithDefault(int i, int64_t default_int);
|
120
124
|
inline double toDouble(int i);
|
121
125
|
// inline double toDoubleWithDefault(int i, double default_double);
|
@@ -171,6 +175,19 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
|
171
175
|
return intlistWithDefault(i, signature.params[i].default_intlist);
|
172
176
|
}
|
173
177
|
|
178
|
+
// Returns argument `i` as a list of c10::SymInt.
//
// A nil argument yields the parameter's default int list; any other value is
// first read through intlist(i). In both cases each int64_t is wrapped in a
// c10::SymInt.
inline std::vector<c10::SymInt> RubyArgs::symintlist(int i) {
  const auto wrap = [](int64_t value) {
    return c10::SymInt(value);
  };

  if (NIL_P(args[i])) {
    return c10::fmap(signature.params[i].default_intlist, wrap);
  }

  // TODO improve
  return c10::fmap(intlist(i), wrap);
}
|
190
|
+
|
174
191
|
inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
|
175
192
|
if (NIL_P(args[i])) return default_intlist;
|
176
193
|
VALUE arg = args[i];
|
@@ -199,6 +216,11 @@ inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
|
|
199
216
|
return intlist(i);
|
200
217
|
}
|
201
218
|
|
219
|
+
// Optional variant of symintlist: a nil argument produces a
// default-constructed OptionalArray, anything else is converted by
// symintlist(i).
inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
  if (!NIL_P(args[i])) {
    return symintlist(i);
  }
  return {};
}
|
223
|
+
|
202
224
|
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
|
203
225
|
if (NIL_P(args[i])) return c10::nullopt;
|
204
226
|
throw std::runtime_error("generator not supported yet");
|
@@ -270,6 +292,11 @@ inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
|
|
270
292
|
return toInt64(i);
|
271
293
|
}
|
272
294
|
|
295
|
+
// Optional variant of toSymInt: a nil argument maps to c10::nullopt, anything
// else is converted by toSymInt(i).
inline c10::optional<c10::SymInt> RubyArgs::toSymIntOptional(int i) {
  if (!NIL_P(args[i])) {
    return toSymInt(i);
  }
  return c10::nullopt;
}
|
299
|
+
|
273
300
|
inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
|
274
301
|
if (NIL_P(args[i])) return c10::nullopt;
|
275
302
|
return toBool(i);
|
@@ -376,6 +403,15 @@ inline int64_t RubyArgs::toInt64(int i) {
|
|
376
403
|
return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
|
377
404
|
}
|
378
405
|
|
406
|
+
// Returns argument `i` as a c10::SymInt.
//
// A nil argument falls back to the parameter's default_int; any other value
// is read through toInt64(i) and then wrapped.
inline c10::SymInt RubyArgs::toSymInt(int i) {
  // TODO improve
  const int64_t raw = NIL_P(args[i]) ? signature.params[i].default_int
                                     : toInt64(i);
  return c10::SymInt(raw);
}
|
414
|
+
|
379
415
|
inline double RubyArgs::toDouble(int i) {
|
380
416
|
if (NIL_P(args[i])) return signature.params[i].default_double;
|
381
417
|
return Rice::detail::From_Ruby<double>().convert(args[i]);
|
data/ext/torch/utils.h
CHANGED
@@ -6,7 +6,7 @@
|
|
6
6
|
#include <rice/stl.hpp>
|
7
7
|
|
8
8
|
static_assert(
|
9
|
-
TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 12,
|
9
|
+
TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 13,
|
10
10
|
"Incompatible LibTorch version"
|
11
11
|
);
|
12
12
|
|
@@ -36,6 +36,10 @@ inline bool THPUtils_checkIndex(VALUE obj) {
|
|
36
36
|
return FIXNUM_P(obj);
|
37
37
|
}
|
38
38
|
|
39
|
+
// Reports whether `obj` passes Ruby's FIXNUM_P test.
inline bool THPUtils_checkLong(VALUE obj) {
  const bool is_fixnum = FIXNUM_P(obj);
  return is_fixnum;
}
|
42
|
+
|
39
43
|
// Reports whether `obj` is usable as a scalar: it passes FIXNUM_P, passes
// RB_FLOAT_TYPE_P, or has Ruby type T_COMPLEX. The tests run in that order
// and stop at the first match.
inline bool THPUtils_checkScalar(VALUE obj) {
  if (FIXNUM_P(obj)) {
    return true;
  }
  if (RB_FLOAT_TYPE_P(obj)) {
    return true;
  }
  return RB_TYPE_P(obj, T_COMPLEX);
}
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.11.2
|
4
|
+
version: 0.12.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-11-05 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|