torch-rb 0.11.2 → 0.12.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -11,6 +11,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
11
11
  {"double", ParameterType::DOUBLE},
12
12
  {"complex", ParameterType::COMPLEX},
13
13
  {"TensorList", ParameterType::TENSOR_LIST},
14
+ {"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
14
15
  {"IntArrayRef", ParameterType::INT_LIST},
15
16
  {"ArrayRef<double>", ParameterType::FLOAT_LIST},
16
17
  {"Generator", ParameterType::GENERATOR},
@@ -22,9 +23,14 @@ static std::unordered_map<std::string, ParameterType> type_map = {
22
23
  {"MemoryFormat", ParameterType::MEMORY_FORMAT},
23
24
  {"QScheme", ParameterType::QSCHEME},
24
25
  {"Device", ParameterType::DEVICE},
26
+ {"Stream", ParameterType::STREAM},
25
27
  {"std::string", ParameterType::STRING},
28
+ {"c10::string_view", ParameterType::STRING},
29
+ {"SymInt", ParameterType::SYM_INT},
26
30
  {"Dimname", ParameterType::DIMNAME},
31
+ {"SymIntArrayRef", ParameterType::SYM_INT_LIST},
27
32
  {"DimnameList", ParameterType::DIMNAME_LIST},
33
+ {"ScalarList", ParameterType::SCALAR_LIST}
28
34
  };
29
35
 
30
36
  static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names = {
@@ -116,6 +122,72 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
116
122
  return true;
117
123
  }
118
124
 
125
+ static bool is_int_list(VALUE obj, int broadcast_size) {
126
+ if (RB_TYPE_P(obj, T_ARRAY)) {
127
+ auto len = RARRAY_LEN(obj);
128
+ if (len == 0) {
129
+ return true;
130
+ }
131
+
132
+ auto item = rb_ary_entry(obj, 0);
133
+ bool int_first = false;
134
+ if (THPUtils_checkIndex(item)) {
135
+ // we still have to check that the rest of items are NOT symint nodes
136
+ int_first = true;
137
+ }
138
+
139
+ // Make sure none of the later arguments are SymInt
140
+ // NB: do NOT check that the later arguments are ints, as this is
141
+ // BC-breaking for FX
142
+ // for (int i = 1; i < len; i++) {
143
+ // if (torch::is_symint_node(
144
+ // py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
145
+ // return false;
146
+ // }
147
+ // }
148
+
149
+ if (int_first) {
150
+ return true;
151
+ }
152
+
153
+ // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
154
+ // in an intlist argument. Even float or complex scalar tensors.
155
+ // return (
156
+ // jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
157
+ // THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
158
+ return false;
159
+ }
160
+ // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
161
+ // int
162
+ return broadcast_size > 0 && THPUtils_checkLong(obj);
163
+ }
164
+
165
+ static bool is_int_or_symint(VALUE obj) {
166
+ return THPUtils_checkIndex(obj);
167
+ }
168
+
169
+ static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
170
+ if (RB_TYPE_P(obj, T_ARRAY)) {
171
+ if (RARRAY_LEN(obj) == 0) {
172
+ return true;
173
+ }
174
+ auto item = rb_ary_entry(obj, 0);
175
+
176
+ if (is_int_or_symint(item)) {
177
+ return true;
178
+ }
179
+ // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
180
+ // in an intlist argument. Even float or complex scalar tensors.
181
+ // return (
182
+ // jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
183
+ // THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
184
+ return false;
185
+ }
186
+ // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
187
+ // int
188
+ return broadcast_size > 0 && THPUtils_checkLong(obj);
189
+ }
190
+
119
191
  // argnum is needed for raising the TypeError, it's used in the error message.
120
192
  auto FunctionParameter::check(VALUE obj, int argnum) -> bool
121
193
  {
@@ -182,6 +254,8 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
182
254
  case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
183
255
  case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
184
256
  case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
257
+ case ParameterType::SYM_INT: return is_int_or_symint(obj);
258
+ case ParameterType::SYM_INT_LIST: return is_int_or_symint_list(obj, size);
185
259
  default: throw std::runtime_error("unknown parameter type");
186
260
  }
187
261
  }
@@ -191,6 +265,7 @@ std::string FunctionParameter::type_name() const {
191
265
  case ParameterType::TENSOR: return "Tensor";
192
266
  case ParameterType::SCALAR: return "Number";
193
267
  case ParameterType::INT64: return "int";
268
+ case ParameterType::SYM_INT: return "SymInt";
194
269
  case ParameterType::DOUBLE: return "float";
195
270
  case ParameterType::COMPLEX: return "complex";
196
271
  case ParameterType::TENSOR_LIST: return "array of Tensors";
@@ -208,6 +283,8 @@ std::string FunctionParameter::type_name() const {
208
283
  case ParameterType::STRING: return "str";
209
284
  case ParameterType::DIMNAME: return "name";
210
285
  case ParameterType::DIMNAME_LIST: return "array of names";
286
+ case ParameterType::SCALAR_LIST: return "array of Scalars";
287
+ case ParameterType::SYM_INT_LIST: return "array of SymInts";
211
288
  default: throw std::runtime_error("unknown parameter type");
212
289
  }
213
290
  }
@@ -558,8 +635,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
558
635
 
559
636
  // if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
560
637
  // allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
561
- if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
638
+ int int_list_overload = false;
639
+ if (max_pos_args == 1 &&
640
+ (params[0].type_ == ParameterType::INT_LIST ||
641
+ params[0].type_ == ParameterType::SYM_INT_LIST)) {
562
642
  allow_varargs_intlist = true;
643
+ if (params[0].type_ == ParameterType::INT_LIST) {
644
+ int_list_overload = true;
645
+ }
563
646
  }
564
647
 
565
648
  if (nargs > max_pos_args && !allow_varargs_intlist) {
@@ -614,8 +697,10 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
614
697
  // XXX: the Variable check is necessary because sizes become tensors when
615
698
  // tracer is enabled. This behavior easily leads to ambiguities, and we
616
699
  // should avoid having complex signatures that make use of it...
617
- } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
618
- THPUtils_checkIndex(obj)) {
700
+ } else if (
701
+ allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
702
+ ((int_list_overload ? is_int_list(args, param.size)
703
+ : is_int_or_symint_list(args, param.size)))) {
619
704
  // take all positional arguments as this parameter
620
705
  // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
621
706
  dst[i++] = args;
@@ -11,9 +11,9 @@
11
11
  #include "utils.h"
12
12
 
13
13
  // Kinds of argument a FunctionParameter can describe. FunctionParameter::check
  // switches on this value to test a Ruby object, and FunctionParameter::type_name()
  // maps it to a human-readable name ("int", "SymInt", "array of Scalars", ...).
  // This release adds SYM_INT, STREAM, SCALAR_LIST and SYM_INT_LIST.
  enum class ParameterType {
14
- TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
15
- BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STRING,
16
- DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST
14
+ TENSOR, SCALAR, INT64, SYM_INT, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
15
+ BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING,
16
+ DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST, SYM_INT_LIST
17
17
  };
18
18
 
19
19
  struct FunctionParameter {
@@ -84,7 +84,9 @@ struct RubyArgs {
84
84
  template<int N>
85
85
  inline std::array<at::Tensor, N> tensorlist_n(int i);
86
86
  inline std::vector<int64_t> intlist(int i);
87
+ inline std::vector<c10::SymInt> symintlist(int i);
87
88
  inline c10::OptionalArray<int64_t> intlistOptional(int i);
89
+ inline c10::OptionalArray<c10::SymInt> symintlistOptional(int i);
88
90
  inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
89
91
  inline c10::optional<at::Generator> generator(int i);
90
92
  inline at::Storage storage(int i);
@@ -93,6 +95,7 @@ struct RubyArgs {
93
95
  inline c10::optional<at::ScalarType> scalartypeOptional(int i);
94
96
  inline c10::optional<at::Scalar> scalarOptional(int i);
95
97
  inline c10::optional<int64_t> toInt64Optional(int i);
98
+ inline c10::optional<c10::SymInt> toSymIntOptional(int i);
96
99
  inline c10::optional<bool> toBoolOptional(int i);
97
100
  inline c10::optional<double> toDoubleOptional(int i);
98
101
  inline c10::OptionalArray<double> doublelistOptional(int i);
@@ -116,6 +119,7 @@ struct RubyArgs {
116
119
  inline c10::optional<c10::string_view> stringViewOptional(int i);
117
120
  // inline PyObject* pyobject(int i);
118
121
  inline int64_t toInt64(int i);
122
+ inline c10::SymInt toSymInt(int i);
119
123
  // inline int64_t toInt64WithDefault(int i, int64_t default_int);
120
124
  inline double toDouble(int i);
121
125
  // inline double toDoubleWithDefault(int i, double default_double);
@@ -171,6 +175,19 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
171
175
  return intlistWithDefault(i, signature.params[i].default_intlist);
172
176
  }
173
177
 
178
+ inline std::vector<c10::SymInt> RubyArgs::symintlist(int i) {
179
+ if (NIL_P(args[i])) {
180
+ return c10::fmap(signature.params[i].default_intlist, [](int64_t di) {
181
+ return c10::SymInt(di);
182
+ });
183
+ }
184
+
185
+ // TODO improve
186
+ return c10::fmap(intlist(i), [](int64_t di) {
187
+ return c10::SymInt(di);
188
+ });
189
+ }
190
+
174
191
  inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
175
192
  if (NIL_P(args[i])) return default_intlist;
176
193
  VALUE arg = args[i];
@@ -199,6 +216,11 @@ inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
199
216
  return intlist(i);
200
217
  }
201
218
 
219
+ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
220
+ if (NIL_P(args[i])) return {};
221
+ return symintlist(i);
222
+ }
223
+
202
224
  inline c10::optional<at::Generator> RubyArgs::generator(int i) {
203
225
  if (NIL_P(args[i])) return c10::nullopt;
204
226
  throw std::runtime_error("generator not supported yet");
@@ -270,6 +292,11 @@ inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
270
292
  return toInt64(i);
271
293
  }
272
294
 
295
+ inline c10::optional<c10::SymInt> RubyArgs::toSymIntOptional(int i) {
296
+ if (NIL_P(args[i])) return c10::nullopt;
297
+ return toSymInt(i);
298
+ }
299
+
273
300
  inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
274
301
  if (NIL_P(args[i])) return c10::nullopt;
275
302
  return toBool(i);
@@ -376,6 +403,15 @@ inline int64_t RubyArgs::toInt64(int i) {
376
403
  return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
377
404
  }
378
405
 
406
+ inline c10::SymInt RubyArgs::toSymInt(int i) {
407
+ if (NIL_P(args[i])) {
408
+ return c10::SymInt(signature.params[i].default_int);
409
+ }
410
+
411
+ // TODO improve
412
+ return c10::SymInt(toInt64(i));
413
+ }
414
+
379
415
  inline double RubyArgs::toDouble(int i) {
380
416
  if (NIL_P(args[i])) return signature.params[i].default_double;
381
417
  return Rice::detail::From_Ruby<double>().convert(args[i]);
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  // Compile-time guard: with the added line below, this extension only builds
  // against LibTorch 1.13 (major 1, minor 13). Any other version stops the
  // build with "Incompatible LibTorch version".
  static_assert(
9
- TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 12,
9
+ TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 13,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -36,6 +36,10 @@ inline bool THPUtils_checkIndex(VALUE obj) {
36
36
  return FIXNUM_P(obj);
37
37
  }
38
38
 
39
+ inline bool THPUtils_checkLong(VALUE obj) {
40
+ return FIXNUM_P(obj);
41
+ }
42
+
39
43
  inline bool THPUtils_checkScalar(VALUE obj) {
40
44
  return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
41
45
  }
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.11.2"
2
+ VERSION = "0.12.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.11.2
4
+ version: 0.12.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-09-25 00:00:00.000000000 Z
11
+ date: 2022-11-05 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice