torch-rb 0.11.2 → 0.12.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/README.md +5 -4
- data/codegen/function.rb +8 -0
- data/codegen/generate_functions.rb +30 -5
- data/codegen/native_functions.yaml +2067 -652
- data/ext/torch/ruby_arg_parser.cpp +88 -3
- data/ext/torch/ruby_arg_parser.h +39 -3
- data/ext/torch/utils.h +5 -1
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
@@ -11,6 +11,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
|
11
11
|
{"double", ParameterType::DOUBLE},
|
12
12
|
{"complex", ParameterType::COMPLEX},
|
13
13
|
{"TensorList", ParameterType::TENSOR_LIST},
|
14
|
+
{"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
|
14
15
|
{"IntArrayRef", ParameterType::INT_LIST},
|
15
16
|
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
|
16
17
|
{"Generator", ParameterType::GENERATOR},
|
@@ -22,9 +23,14 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
|
22
23
|
{"MemoryFormat", ParameterType::MEMORY_FORMAT},
|
23
24
|
{"QScheme", ParameterType::QSCHEME},
|
24
25
|
{"Device", ParameterType::DEVICE},
|
26
|
+
{"Stream", ParameterType::STREAM},
|
25
27
|
{"std::string", ParameterType::STRING},
|
28
|
+
{"c10::string_view", ParameterType::STRING},
|
29
|
+
{"SymInt", ParameterType::SYM_INT},
|
26
30
|
{"Dimname", ParameterType::DIMNAME},
|
31
|
+
{"SymIntArrayRef", ParameterType::SYM_INT_LIST},
|
27
32
|
{"DimnameList", ParameterType::DIMNAME_LIST},
|
33
|
+
{"ScalarList", ParameterType::SCALAR_LIST}
|
28
34
|
};
|
29
35
|
|
30
36
|
static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names = {
|
@@ -116,6 +122,72 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
|
|
116
122
|
return true;
|
117
123
|
}
|
118
124
|
|
125
|
+
// Decides whether obj can bind to an IntArrayRef parameter.
//
// - A Ruby Array matches when it is empty or its first element passes
//   THPUtils_checkIndex. Only the first element is inspected here; later
//   elements are deliberately not checked by this function.
// - Anything that is not an Array matches only when the parameter declares a
//   fixed size (broadcast_size > 0, e.g. IntArrayRef[2]) and obj passes
//   THPUtils_checkLong, i.e. a single integer that is broadcast to that size.
static bool is_int_list(VALUE obj, int broadcast_size) {
  if (!RB_TYPE_P(obj, T_ARRAY)) {
    return broadcast_size > 0 && THPUtils_checkLong(obj);
  }

  if (RARRAY_LEN(obj) == 0) {
    // An empty Array is a valid (empty) int list.
    return true;
  }

  VALUE first = rb_ary_entry(obj, 0);
  return THPUtils_checkIndex(first);
}
|
164
|
+
|
165
|
+
// Decides whether obj can bind to a SymInt parameter. This binding currently
// treats a SymInt argument exactly like a plain integer index, so the test is
// delegated to THPUtils_checkIndex.
static bool is_int_or_symint(VALUE obj) {
  const bool accepted = THPUtils_checkIndex(obj);
  return accepted;
}
|
168
|
+
|
169
|
+
// Decides whether obj can bind to a SymIntArrayRef parameter.
//
// - A Ruby Array matches when it is empty or its first element satisfies
//   is_int_or_symint (only the first element is inspected here).
// - A non-Array matches only when the parameter has a declared size
//   (broadcast_size > 0) and obj passes THPUtils_checkLong, so a single
//   integer can stand in for the whole list.
static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
  if (RB_TYPE_P(obj, T_ARRAY)) {
    // Short-circuit keeps rb_ary_entry from being called on an empty Array.
    return RARRAY_LEN(obj) == 0 || is_int_or_symint(rb_ary_entry(obj, 0));
  }
  return broadcast_size > 0 && THPUtils_checkLong(obj);
}
|
190
|
+
|
119
191
|
// argnum is needed for raising the TypeError, it's used in the error message.
|
120
192
|
auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
121
193
|
{
|
@@ -182,6 +254,8 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
182
254
|
case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
|
183
255
|
case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
|
184
256
|
case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
|
257
|
+
case ParameterType::SYM_INT: return is_int_or_symint(obj);
|
258
|
+
case ParameterType::SYM_INT_LIST: return is_int_or_symint_list(obj, size);
|
185
259
|
default: throw std::runtime_error("unknown parameter type");
|
186
260
|
}
|
187
261
|
}
|
@@ -191,6 +265,7 @@ std::string FunctionParameter::type_name() const {
|
|
191
265
|
case ParameterType::TENSOR: return "Tensor";
|
192
266
|
case ParameterType::SCALAR: return "Number";
|
193
267
|
case ParameterType::INT64: return "int";
|
268
|
+
case ParameterType::SYM_INT: return "SymInt";
|
194
269
|
case ParameterType::DOUBLE: return "float";
|
195
270
|
case ParameterType::COMPLEX: return "complex";
|
196
271
|
case ParameterType::TENSOR_LIST: return "array of Tensors";
|
@@ -208,6 +283,8 @@ std::string FunctionParameter::type_name() const {
|
|
208
283
|
case ParameterType::STRING: return "str";
|
209
284
|
case ParameterType::DIMNAME: return "name";
|
210
285
|
case ParameterType::DIMNAME_LIST: return "array of names";
|
286
|
+
case ParameterType::SCALAR_LIST: return "array of Scalars";
|
287
|
+
case ParameterType::SYM_INT_LIST: return "array of SymInts";
|
211
288
|
default: throw std::runtime_error("unknown parameter type");
|
212
289
|
}
|
213
290
|
}
|
@@ -558,8 +635,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
|
|
558
635
|
|
559
636
|
// if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
|
560
637
|
// allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
|
561
|
-
|
638
|
+
int int_list_overload = false;
|
639
|
+
if (max_pos_args == 1 &&
|
640
|
+
(params[0].type_ == ParameterType::INT_LIST ||
|
641
|
+
params[0].type_ == ParameterType::SYM_INT_LIST)) {
|
562
642
|
allow_varargs_intlist = true;
|
643
|
+
if (params[0].type_ == ParameterType::INT_LIST) {
|
644
|
+
int_list_overload = true;
|
645
|
+
}
|
563
646
|
}
|
564
647
|
|
565
648
|
if (nargs > max_pos_args && !allow_varargs_intlist) {
|
@@ -614,8 +697,10 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
|
|
614
697
|
// XXX: the Variable check is necessary because sizes become tensors when
|
615
698
|
// tracer is enabled. This behavior easily leads to ambiguities, and we
|
616
699
|
// should avoid having complex signatures that make use of it...
|
617
|
-
} else if (
|
618
|
-
|
700
|
+
} else if (
|
701
|
+
allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
|
702
|
+
((int_list_overload ? is_int_list(args, param.size)
|
703
|
+
: is_int_or_symint_list(args, param.size)))) {
|
619
704
|
// take all positional arguments as this parameter
|
620
705
|
// e.g. permute(1, 2, 3) -> permute((1, 2, 3))
|
621
706
|
dst[i++] = args;
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -11,9 +11,9 @@
|
|
11
11
|
#include "utils.h"
|
12
12
|
|
13
13
|
// Category of a parameter in a parsed function signature. Schema type strings
// (e.g. "SymIntArrayRef", "ScalarList") are mapped to these values by the
// type_map table in the parser source, and FunctionParameter::check /
// FunctionParameter::type_name switch on them.
enum class ParameterType {
|
14
|
-
TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
|
15
|
-
BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STRING,
|
16
|
-
DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST
|
14
|
+
TENSOR, SCALAR, INT64, SYM_INT, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
|
15
|
+
BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING,
|
16
|
+
DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST, SYM_INT_LIST
|
17
17
|
};
|
18
18
|
|
19
19
|
struct FunctionParameter {
|
@@ -84,7 +84,9 @@ struct RubyArgs {
|
|
84
84
|
template<int N>
|
85
85
|
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
86
86
|
inline std::vector<int64_t> intlist(int i);
|
87
|
+
inline std::vector<c10::SymInt> symintlist(int i);
|
87
88
|
inline c10::OptionalArray<int64_t> intlistOptional(int i);
|
89
|
+
inline c10::OptionalArray<c10::SymInt> symintlistOptional(int i);
|
88
90
|
inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
|
89
91
|
inline c10::optional<at::Generator> generator(int i);
|
90
92
|
inline at::Storage storage(int i);
|
@@ -93,6 +95,7 @@ struct RubyArgs {
|
|
93
95
|
inline c10::optional<at::ScalarType> scalartypeOptional(int i);
|
94
96
|
inline c10::optional<at::Scalar> scalarOptional(int i);
|
95
97
|
inline c10::optional<int64_t> toInt64Optional(int i);
|
98
|
+
inline c10::optional<c10::SymInt> toSymIntOptional(int i);
|
96
99
|
inline c10::optional<bool> toBoolOptional(int i);
|
97
100
|
inline c10::optional<double> toDoubleOptional(int i);
|
98
101
|
inline c10::OptionalArray<double> doublelistOptional(int i);
|
@@ -116,6 +119,7 @@ struct RubyArgs {
|
|
116
119
|
inline c10::optional<c10::string_view> stringViewOptional(int i);
|
117
120
|
// inline PyObject* pyobject(int i);
|
118
121
|
inline int64_t toInt64(int i);
|
122
|
+
inline c10::SymInt toSymInt(int i);
|
119
123
|
// inline int64_t toInt64WithDefault(int i, int64_t default_int);
|
120
124
|
inline double toDouble(int i);
|
121
125
|
// inline double toDoubleWithDefault(int i, double default_double);
|
@@ -171,6 +175,19 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
|
171
175
|
return intlistWithDefault(i, signature.params[i].default_intlist);
|
172
176
|
}
|
173
177
|
|
178
|
+
inline std::vector<c10::SymInt> RubyArgs::symintlist(int i) {
|
179
|
+
if (NIL_P(args[i])) {
|
180
|
+
return c10::fmap(signature.params[i].default_intlist, [](int64_t di) {
|
181
|
+
return c10::SymInt(di);
|
182
|
+
});
|
183
|
+
}
|
184
|
+
|
185
|
+
// TODO improve
|
186
|
+
return c10::fmap(intlist(i), [](int64_t di) {
|
187
|
+
return c10::SymInt(di);
|
188
|
+
});
|
189
|
+
}
|
190
|
+
|
174
191
|
inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
|
175
192
|
if (NIL_P(args[i])) return default_intlist;
|
176
193
|
VALUE arg = args[i];
|
@@ -199,6 +216,11 @@ inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
|
|
199
216
|
return intlist(i);
|
200
217
|
}
|
201
218
|
|
219
|
+
// Optional form of symintlist(): a nil argument yields a default-constructed
// (empty) OptionalArray; any other value is converted with symintlist(i).
inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
  if (!NIL_P(args[i])) {
    return symintlist(i);
  }
  return {};
}
|
223
|
+
|
202
224
|
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
|
203
225
|
if (NIL_P(args[i])) return c10::nullopt;
|
204
226
|
throw std::runtime_error("generator not supported yet");
|
@@ -270,6 +292,11 @@ inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
|
|
270
292
|
return toInt64(i);
|
271
293
|
}
|
272
294
|
|
295
|
+
// Optional form of toSymInt(): nil maps to c10::nullopt, anything else is
// converted with toSymInt(i).
inline c10::optional<c10::SymInt> RubyArgs::toSymIntOptional(int i) {
  if (!NIL_P(args[i])) {
    return toSymInt(i);
  }
  return c10::nullopt;
}
|
299
|
+
|
273
300
|
inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
|
274
301
|
if (NIL_P(args[i])) return c10::nullopt;
|
275
302
|
return toBool(i);
|
@@ -376,6 +403,15 @@ inline int64_t RubyArgs::toInt64(int i) {
|
|
376
403
|
return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
|
377
404
|
}
|
378
405
|
|
406
|
+
// Converts argument i to a concrete SymInt.
// A non-nil value is read through toInt64() and wrapped; a nil value falls
// back to the parameter's declared default_int. Symbolic values are not
// produced from Ruby input yet.
inline c10::SymInt RubyArgs::toSymInt(int i) {
  if (!NIL_P(args[i])) {
    return c10::SymInt(toInt64(i));
  }
  return c10::SymInt(signature.params[i].default_int);
}
|
414
|
+
|
379
415
|
inline double RubyArgs::toDouble(int i) {
|
380
416
|
if (NIL_P(args[i])) return signature.params[i].default_double;
|
381
417
|
return Rice::detail::From_Ruby<double>().convert(args[i]);
|
data/ext/torch/utils.h
CHANGED
@@ -6,7 +6,7 @@
|
|
6
6
|
#include <rice/stl.hpp>
|
7
7
|
|
8
8
|
// Compile-time guard: the build stops with "Incompatible LibTorch version"
// when the LibTorch headers report a major/minor version other than the one
// checked below.
static_assert(
|
9
|
-
TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR ==
|
9
|
+
TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 13,
|
10
10
|
"Incompatible LibTorch version"
|
11
11
|
);
|
12
12
|
|
@@ -36,6 +36,10 @@ inline bool THPUtils_checkIndex(VALUE obj) {
|
|
36
36
|
return FIXNUM_P(obj);
|
37
37
|
}
|
38
38
|
|
39
|
+
// Reports whether obj is a Fixnum Integer. Any other object, including a
// Bignum, true/false and nil, is rejected.
inline bool THPUtils_checkLong(VALUE obj) {
  const bool is_fixnum = FIXNUM_P(obj);
  return is_fixnum;
}
|
42
|
+
|
39
43
|
// Reports whether obj can be passed where a Scalar is expected:
// a Fixnum Integer, a Float, or a Complex.
inline bool THPUtils_checkScalar(VALUE obj) {
  if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
    return true;
  }
  return RB_TYPE_P(obj, T_COMPLEX);
}
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.11.2
|
4
|
+
version: 0.12.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-11-05 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|