torch-rb 0.11.2 → 0.12.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
11
11
  {"double", ParameterType::DOUBLE},
12
12
  {"complex", ParameterType::COMPLEX},
13
13
  {"TensorList", ParameterType::TENSOR_LIST},
14
+ {"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
14
15
  {"IntArrayRef", ParameterType::INT_LIST},
15
16
  {"ArrayRef<double>", ParameterType::FLOAT_LIST},
16
17
  {"Generator", ParameterType::GENERATOR},
@@ -22,9 +23,14 @@ static std::unordered_map<std::string, ParameterType> type_map = {
22
23
  {"MemoryFormat", ParameterType::MEMORY_FORMAT},
23
24
  {"QScheme", ParameterType::QSCHEME},
24
25
  {"Device", ParameterType::DEVICE},
26
+ {"Stream", ParameterType::STREAM},
25
27
  {"std::string", ParameterType::STRING},
28
+ {"c10::string_view", ParameterType::STRING},
29
+ {"SymInt", ParameterType::SYM_INT},
26
30
  {"Dimname", ParameterType::DIMNAME},
31
+ {"SymIntArrayRef", ParameterType::SYM_INT_LIST},
27
32
  {"DimnameList", ParameterType::DIMNAME_LIST},
33
+ {"ScalarList", ParameterType::SCALAR_LIST}
28
34
  };
29
35
 
30
36
  static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names = {
@@ -116,6 +122,72 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
116
122
  return true;
117
123
  }
118
124
 
125
+ static bool is_int_list(VALUE obj, int broadcast_size) {
126
+ if (RB_TYPE_P(obj, T_ARRAY)) {
127
+ auto len = RARRAY_LEN(obj);
128
+ if (len == 0) {
129
+ return true;
130
+ }
131
+
132
+ auto item = rb_ary_entry(obj, 0);
133
+ bool int_first = false;
134
+ if (THPUtils_checkIndex(item)) {
135
+ // we still have to check that the rest of items are NOT symint nodes
136
+ int_first = true;
137
+ }
138
+
139
+ // Make sure none of the later arguments are SymInt
140
+ // NB: do NOT check that the later arguments are ints, as this is
141
+ // BC-breaking for FX
142
+ // for (int i = 1; i < len; i++) {
143
+ // if (torch::is_symint_node(
144
+ // py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
145
+ // return false;
146
+ // }
147
+ // }
148
+
149
+ if (int_first) {
150
+ return true;
151
+ }
152
+
153
+ // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
154
+ // in an intlist argument. Even float or complex scalar tensors.
155
+ // return (
156
+ // jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
157
+ // THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
158
+ return false;
159
+ }
160
+ // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
161
+ // int
162
+ return broadcast_size > 0 && THPUtils_checkLong(obj);
163
+ }
164
+
165
+ static bool is_int_or_symint(VALUE obj) {
166
+ return THPUtils_checkIndex(obj);
167
+ }
168
+
169
+ static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
170
+ if (RB_TYPE_P(obj, T_ARRAY)) {
171
+ if (RARRAY_LEN(obj) == 0) {
172
+ return true;
173
+ }
174
+ auto item = rb_ary_entry(obj, 0);
175
+
176
+ if (is_int_or_symint(item)) {
177
+ return true;
178
+ }
179
+ // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
180
+ // in an intlist argument. Even float or complex scalar tensors.
181
+ // return (
182
+ // jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
183
+ // THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
184
+ return false;
185
+ }
186
+ // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
187
+ // int
188
+ return broadcast_size > 0 && THPUtils_checkLong(obj);
189
+ }
190
+
119
191
  // argnum is needed for raising the TypeError, it's used in the error message.
120
192
  auto FunctionParameter::check(VALUE obj, int argnum) -> bool
121
193
  {
@@ -182,6 +254,8 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
182
254
  case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
183
255
  case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
184
256
  case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
257
+ case ParameterType::SYM_INT: return is_int_or_symint(obj);
258
+ case ParameterType::SYM_INT_LIST: return is_int_or_symint_list(obj, size);
185
259
  default: throw std::runtime_error("unknown parameter type");
186
260
  }
187
261
  }
@@ -191,6 +265,7 @@ std::string FunctionParameter::type_name() const {
191
265
  case ParameterType::TENSOR: return "Tensor";
192
266
  case ParameterType::SCALAR: return "Number";
193
267
  case ParameterType::INT64: return "int";
268
+ case ParameterType::SYM_INT: return "SymInt";
194
269
  case ParameterType::DOUBLE: return "float";
195
270
  case ParameterType::COMPLEX: return "complex";
196
271
  case ParameterType::TENSOR_LIST: return "array of Tensors";
@@ -208,6 +283,8 @@ std::string FunctionParameter::type_name() const {
208
283
  case ParameterType::STRING: return "str";
209
284
  case ParameterType::DIMNAME: return "name";
210
285
  case ParameterType::DIMNAME_LIST: return "array of names";
286
+ case ParameterType::SCALAR_LIST: return "array of Scalars";
287
+ case ParameterType::SYM_INT_LIST: return "array of SymInts";
211
288
  default: throw std::runtime_error("unknown parameter type");
212
289
  }
213
290
  }
@@ -558,8 +635,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
558
635
 
559
636
  // if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
560
637
  // allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
561
- if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
638
+ int int_list_overload = false;
639
+ if (max_pos_args == 1 &&
640
+ (params[0].type_ == ParameterType::INT_LIST ||
641
+ params[0].type_ == ParameterType::SYM_INT_LIST)) {
562
642
  allow_varargs_intlist = true;
643
+ if (params[0].type_ == ParameterType::INT_LIST) {
644
+ int_list_overload = true;
645
+ }
563
646
  }
564
647
 
565
648
  if (nargs > max_pos_args && !allow_varargs_intlist) {
@@ -614,8 +697,10 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
614
697
  // XXX: the Variable check is necessary because sizes become tensors when
615
698
  // tracer is enabled. This behavior easily leads to ambiguities, and we
616
699
  // should avoid having complex signatures that make use of it...
617
- } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
618
- THPUtils_checkIndex(obj)) {
700
+ } else if (
701
+ allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
702
+ ((int_list_overload ? is_int_list(args, param.size)
703
+ : is_int_or_symint_list(args, param.size)))) {
619
704
  // take all positional arguments as this parameter
620
705
  // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
621
706
  dst[i++] = args;
@@ -11,9 +11,9 @@
11
11
  #include "utils.h"
12
12
 
13
13
  enum class ParameterType {
14
- TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
15
- BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STRING,
16
- DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST
14
+ TENSOR, SCALAR, INT64, SYM_INT, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
15
+ BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING,
16
+ DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST, SYM_INT_LIST
17
17
  };
18
18
 
19
19
  struct FunctionParameter {
@@ -84,7 +84,9 @@ struct RubyArgs {
84
84
  template<int N>
85
85
  inline std::array<at::Tensor, N> tensorlist_n(int i);
86
86
  inline std::vector<int64_t> intlist(int i);
87
+ inline std::vector<c10::SymInt> symintlist(int i);
87
88
  inline c10::OptionalArray<int64_t> intlistOptional(int i);
89
+ inline c10::OptionalArray<c10::SymInt> symintlistOptional(int i);
88
90
  inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
89
91
  inline c10::optional<at::Generator> generator(int i);
90
92
  inline at::Storage storage(int i);
@@ -93,6 +95,7 @@ struct RubyArgs {
93
95
  inline c10::optional<at::ScalarType> scalartypeOptional(int i);
94
96
  inline c10::optional<at::Scalar> scalarOptional(int i);
95
97
  inline c10::optional<int64_t> toInt64Optional(int i);
98
+ inline c10::optional<c10::SymInt> toSymIntOptional(int i);
96
99
  inline c10::optional<bool> toBoolOptional(int i);
97
100
  inline c10::optional<double> toDoubleOptional(int i);
98
101
  inline c10::OptionalArray<double> doublelistOptional(int i);
@@ -116,6 +119,7 @@ struct RubyArgs {
116
119
  inline c10::optional<c10::string_view> stringViewOptional(int i);
117
120
  // inline PyObject* pyobject(int i);
118
121
  inline int64_t toInt64(int i);
122
+ inline c10::SymInt toSymInt(int i);
119
123
  // inline int64_t toInt64WithDefault(int i, int64_t default_int);
120
124
  inline double toDouble(int i);
121
125
  // inline double toDoubleWithDefault(int i, double default_double);
@@ -171,6 +175,19 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
171
175
  return intlistWithDefault(i, signature.params[i].default_intlist);
172
176
  }
173
177
 
178
+ inline std::vector<c10::SymInt> RubyArgs::symintlist(int i) {
179
+ if (NIL_P(args[i])) {
180
+ return c10::fmap(signature.params[i].default_intlist, [](int64_t di) {
181
+ return c10::SymInt(di);
182
+ });
183
+ }
184
+
185
+ // TODO improve
186
+ return c10::fmap(intlist(i), [](int64_t di) {
187
+ return c10::SymInt(di);
188
+ });
189
+ }
190
+
174
191
  inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
175
192
  if (NIL_P(args[i])) return default_intlist;
176
193
  VALUE arg = args[i];
@@ -199,6 +216,11 @@ inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
199
216
  return intlist(i);
200
217
  }
201
218
 
219
+ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
220
+ if (NIL_P(args[i])) return {};
221
+ return symintlist(i);
222
+ }
223
+
202
224
  inline c10::optional<at::Generator> RubyArgs::generator(int i) {
203
225
  if (NIL_P(args[i])) return c10::nullopt;
204
226
  throw std::runtime_error("generator not supported yet");
@@ -270,6 +292,11 @@ inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
270
292
  return toInt64(i);
271
293
  }
272
294
 
295
+ inline c10::optional<c10::SymInt> RubyArgs::toSymIntOptional(int i) {
296
+ if (NIL_P(args[i])) return c10::nullopt;
297
+ return toSymInt(i);
298
+ }
299
+
273
300
  inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
274
301
  if (NIL_P(args[i])) return c10::nullopt;
275
302
  return toBool(i);
@@ -376,6 +403,15 @@ inline int64_t RubyArgs::toInt64(int i) {
376
403
  return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
377
404
  }
378
405
 
406
+ inline c10::SymInt RubyArgs::toSymInt(int i) {
407
+ if (NIL_P(args[i])) {
408
+ return c10::SymInt(signature.params[i].default_int);
409
+ }
410
+
411
+ // TODO improve
412
+ return c10::SymInt(toInt64(i));
413
+ }
414
+
379
415
  inline double RubyArgs::toDouble(int i) {
380
416
  if (NIL_P(args[i])) return signature.params[i].default_double;
381
417
  return Rice::detail::From_Ruby<double>().convert(args[i]);
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 12,
9
+ TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 13,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -36,6 +36,10 @@ inline bool THPUtils_checkIndex(VALUE obj) {
36
36
  return FIXNUM_P(obj);
37
37
  }
38
38
 
39
+ inline bool THPUtils_checkLong(VALUE obj) {
40
+ return FIXNUM_P(obj);
41
+ }
42
+
39
43
  inline bool THPUtils_checkScalar(VALUE obj) {
40
44
  return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
41
45
  }
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.11.2"
2
+ VERSION = "0.12.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.11.2
4
+ version: 0.12.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-09-25 00:00:00.000000000 Z
11
+ date: 2022-11-05 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice