torch-rb 0.11.2 → 0.12.1

This diff compares publicly available package versions as released to a supported registry and is provided for informational purposes only.
data/ext/torch/ext.cpp CHANGED
@@ -12,6 +12,7 @@ void init_torch(Rice::Module& m);
 void init_backends(Rice::Module& m);
 void init_cuda(Rice::Module& m);
 void init_device(Rice::Module& m);
+void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
 void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
 void init_random(Rice::Module& m);
 
@@ -23,6 +24,7 @@ void Init_ext()
   // need to define certain classes up front to keep Rice happy
   auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
     .define_constructor(Rice::Constructor<torch::IValue>());
+  auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
   auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
   auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
     .define_constructor(Rice::Constructor<torch::TensorOptions>());
@@ -38,6 +40,7 @@ void Init_ext()
   init_backends(m);
   init_cuda(m);
   init_device(m);
+  init_generator(m, rb_cGenerator);
   init_ivalue(m, rb_cIValue);
   init_random(m);
 }
data/ext/torch/generator.cpp ADDED
@@ -0,0 +1,50 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "utils.h"
+
+void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) {
+  // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
+  rb_cGenerator
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_function(
+      "new",
+      []() {
+        // TODO support more devices
+        return torch::make_generator<torch::CPUGeneratorImpl>();
+      })
+    .define_method(
+      "device",
+      [](torch::Generator& self) {
+        return self.device();
+      })
+    .define_method(
+      "initial_seed",
+      [](torch::Generator& self) {
+        return self.current_seed();
+      })
+    .define_method(
+      "manual_seed",
+      [](torch::Generator& self, uint64_t seed) {
+        self.set_current_seed(seed);
+        return self;
+      })
+    .define_method(
+      "seed",
+      [](torch::Generator& self) {
+        return self.seed();
+      })
+    .define_method(
+      "state",
+      [](torch::Generator& self) {
+        return self.get_state();
+      })
+    .define_method(
+      "state=",
+      [](torch::Generator& self, const torch::Tensor& state) {
+        self.set_state(state);
+      });
+
+  THPGeneratorClass = rb_cGenerator.value();
+}
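The new `generator.cpp` gives Ruby a thin wrapper over `torch::Generator`. A minimal usage sketch of the methods defined above, assuming the Rice module here is the top-level `Torch` module so the class surfaces as `Torch::Generator`:

```ruby
require "torch"

gen = Torch::Generator.new   # CPU generator only for now (see the TODO above)
gen.manual_seed(42)          # set_current_seed; returns self, so calls chain
gen.initial_seed             # => 42 (current_seed under the hood)
gen.device                   # the generator's device (CPU here)

state = gen.state            # RNG state snapshot as a Tensor (get_state)
gen.seed                     # re-seed nondeterministically
gen.state = state            # restore the snapshot (set_state)
```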
data/ext/torch/ruby_arg_parser.cpp CHANGED
@@ -2,6 +2,7 @@
 
 #include "ruby_arg_parser.h"
 
+VALUE THPGeneratorClass = Qnil;
 VALUE THPVariableClass = Qnil;
 
 static std::unordered_map<std::string, ParameterType> type_map = {
@@ -11,6 +12,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
   {"double", ParameterType::DOUBLE},
   {"complex", ParameterType::COMPLEX},
   {"TensorList", ParameterType::TENSOR_LIST},
+  {"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
   {"IntArrayRef", ParameterType::INT_LIST},
   {"ArrayRef<double>", ParameterType::FLOAT_LIST},
   {"Generator", ParameterType::GENERATOR},
@@ -22,9 +24,14 @@ static std::unordered_map<std::string, ParameterType> type_map = {
   {"MemoryFormat", ParameterType::MEMORY_FORMAT},
   {"QScheme", ParameterType::QSCHEME},
   {"Device", ParameterType::DEVICE},
+  {"Stream", ParameterType::STREAM},
   {"std::string", ParameterType::STRING},
+  {"c10::string_view", ParameterType::STRING},
+  {"SymInt", ParameterType::SYM_INT},
   {"Dimname", ParameterType::DIMNAME},
+  {"SymIntArrayRef", ParameterType::SYM_INT_LIST},
   {"DimnameList", ParameterType::DIMNAME_LIST},
+  {"ScalarList", ParameterType::SCALAR_LIST}
 };
 
 static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names = {
@@ -116,6 +123,72 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
   return true;
 }
 
+static bool is_int_list(VALUE obj, int broadcast_size) {
+  if (RB_TYPE_P(obj, T_ARRAY)) {
+    auto len = RARRAY_LEN(obj);
+    if (len == 0) {
+      return true;
+    }
+
+    auto item = rb_ary_entry(obj, 0);
+    bool int_first = false;
+    if (THPUtils_checkIndex(item)) {
+      // we still have to check that the rest of items are NOT symint nodes
+      int_first = true;
+    }
+
+    // Make sure none of the later arguments are SymInt
+    // NB: do NOT check that the later arguments are ints, as this is
+    // BC-breaking for FX
+    // for (int i = 1; i < len; i++) {
+    //   if (torch::is_symint_node(
+    //           py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
+    //     return false;
+    //   }
+    // }
+
+    if (int_first) {
+      return true;
+    }
+
+    // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
+    // in an intlist argument. Even float or complex scalar tensors.
+    // return (
+    //     jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
+    //     THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
+    return false;
+  }
+  // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
+  // int
+  return broadcast_size > 0 && THPUtils_checkLong(obj);
+}
+
+static bool is_int_or_symint(VALUE obj) {
+  return THPUtils_checkIndex(obj);
+}
+
+static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
+  if (RB_TYPE_P(obj, T_ARRAY)) {
+    if (RARRAY_LEN(obj) == 0) {
+      return true;
+    }
+    auto item = rb_ary_entry(obj, 0);
+
+    if (is_int_or_symint(item)) {
+      return true;
+    }
+    // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
+    // in an intlist argument. Even float or complex scalar tensors.
+    // return (
+    //     jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
+    //     THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
+    return false;
+  }
+  // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
+  // int
+  return broadcast_size > 0 && THPUtils_checkLong(obj);
+}
+
 // argnum is needed for raising the TypeError, it's used in the error message.
 auto FunctionParameter::check(VALUE obj, int argnum) -> bool
 {
@@ -172,7 +245,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
       return size > 0 && FIXNUM_P(obj);
     }
     case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
-    case ParameterType::GENERATOR: return false; // return THPGenerator_Check(obj);
+    case ParameterType::GENERATOR: return THPGenerator_Check(obj);
     case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
     case ParameterType::STORAGE: return false; // return isStorage(obj);
     // case ParameterType::PYOBJECT: return true;
@@ -182,6 +255,8 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
     case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
     case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
     case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
+    case ParameterType::SYM_INT: return is_int_or_symint(obj);
+    case ParameterType::SYM_INT_LIST: return is_int_or_symint_list(obj, size);
     default: throw std::runtime_error("unknown parameter type");
   }
 }
@@ -191,6 +266,7 @@ std::string FunctionParameter::type_name() const {
     case ParameterType::TENSOR: return "Tensor";
     case ParameterType::SCALAR: return "Number";
     case ParameterType::INT64: return "int";
+    case ParameterType::SYM_INT: return "SymInt";
     case ParameterType::DOUBLE: return "float";
     case ParameterType::COMPLEX: return "complex";
     case ParameterType::TENSOR_LIST: return "array of Tensors";
@@ -208,6 +284,8 @@ std::string FunctionParameter::type_name() const {
     case ParameterType::STRING: return "str";
     case ParameterType::DIMNAME: return "name";
     case ParameterType::DIMNAME_LIST: return "array of names";
+    case ParameterType::SCALAR_LIST: return "array of Scalars";
+    case ParameterType::SYM_INT_LIST: return "array of SymInts";
     default: throw std::runtime_error("unknown parameter type");
   }
 }
@@ -558,8 +636,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
 
   // if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
   // allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
-  if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
+  int int_list_overload = false;
+  if (max_pos_args == 1 &&
+      (params[0].type_ == ParameterType::INT_LIST ||
+       params[0].type_ == ParameterType::SYM_INT_LIST)) {
     allow_varargs_intlist = true;
+    if (params[0].type_ == ParameterType::INT_LIST) {
+      int_list_overload = true;
+    }
   }
 
   if (nargs > max_pos_args && !allow_varargs_intlist) {
@@ -614,8 +698,10 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
     // XXX: the Variable check is necessary because sizes become tensors when
     // tracer is enabled. This behavior easily leads to ambiguities, and we
     // should avoid having complex signatures that make use of it...
-    } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
-        THPUtils_checkIndex(obj)) {
+    } else if (
+        allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
+        ((int_list_overload ? is_int_list(args, param.size)
+                            : is_int_or_symint_list(args, param.size)))) {
       // take all positional arguments as this parameter
       // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
       dst[i++] = args;
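The `FunctionSignature::parse` change above extends the var-args shorthand from `INT_LIST` to `SYM_INT_LIST` parameters and checks the whole argument list via `is_int_list` / `is_int_or_symint_list` rather than only the first value. From Ruby the visible behavior is the familiar shorthand for size-like arguments; a sketch (the tensor values are illustrative):

```ruby
x = Torch.ones([2, 3, 4])

# A single positional IntArrayRef/SymIntArrayRef parameter accepts both forms:
x.permute([2, 0, 1])  # explicit array
x.permute(2, 0, 1)    # var-args shorthand, collected by FunctionSignature::parse
```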
data/ext/torch/ruby_arg_parser.h CHANGED
@@ -11,9 +11,9 @@
 #include "utils.h"
 
 enum class ParameterType {
-  TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
-  BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STRING,
-  DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST
+  TENSOR, SCALAR, INT64, SYM_INT, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
+  BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STREAM, STRING,
+  DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST, SCALAR_LIST, SYM_INT_LIST
 };
 
 struct FunctionParameter {
@@ -84,7 +84,9 @@ struct RubyArgs {
   template<int N>
   inline std::array<at::Tensor, N> tensorlist_n(int i);
   inline std::vector<int64_t> intlist(int i);
+  inline std::vector<c10::SymInt> symintlist(int i);
   inline c10::OptionalArray<int64_t> intlistOptional(int i);
+  inline c10::OptionalArray<c10::SymInt> symintlistOptional(int i);
   inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
   inline c10::optional<at::Generator> generator(int i);
   inline at::Storage storage(int i);
@@ -93,6 +95,7 @@ struct RubyArgs {
   inline c10::optional<at::ScalarType> scalartypeOptional(int i);
   inline c10::optional<at::Scalar> scalarOptional(int i);
   inline c10::optional<int64_t> toInt64Optional(int i);
+  inline c10::optional<c10::SymInt> toSymIntOptional(int i);
   inline c10::optional<bool> toBoolOptional(int i);
   inline c10::optional<double> toDoubleOptional(int i);
   inline c10::OptionalArray<double> doublelistOptional(int i);
@@ -116,6 +119,7 @@ struct RubyArgs {
   inline c10::optional<c10::string_view> stringViewOptional(int i);
   // inline PyObject* pyobject(int i);
   inline int64_t toInt64(int i);
+  inline c10::SymInt toSymInt(int i);
   // inline int64_t toInt64WithDefault(int i, int64_t default_int);
   inline double toDouble(int i);
   // inline double toDoubleWithDefault(int i, double default_double);
@@ -171,6 +175,19 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
   return intlistWithDefault(i, signature.params[i].default_intlist);
 }
 
+inline std::vector<c10::SymInt> RubyArgs::symintlist(int i) {
+  if (NIL_P(args[i])) {
+    return c10::fmap(signature.params[i].default_intlist, [](int64_t di) {
+      return c10::SymInt(di);
+    });
+  }
+
+  // TODO improve
+  return c10::fmap(intlist(i), [](int64_t di) {
+    return c10::SymInt(di);
+  });
+}
+
 inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
   if (NIL_P(args[i])) return default_intlist;
   VALUE arg = args[i];
@@ -199,9 +216,14 @@ inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
   return intlist(i);
 }
 
+inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
+  if (NIL_P(args[i])) return {};
+  return symintlist(i);
+}
+
 inline c10::optional<at::Generator> RubyArgs::generator(int i) {
   if (NIL_P(args[i])) return c10::nullopt;
-  throw std::runtime_error("generator not supported yet");
+  return Rice::detail::From_Ruby<torch::Generator>().convert(args[i]);
 }
 
 inline at::Storage RubyArgs::storage(int i) {
@@ -270,6 +292,11 @@ inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
   return toInt64(i);
 }
 
+inline c10::optional<c10::SymInt> RubyArgs::toSymIntOptional(int i) {
+  if (NIL_P(args[i])) return c10::nullopt;
+  return toSymInt(i);
+}
+
 inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
   if (NIL_P(args[i])) return c10::nullopt;
   return toBool(i);
@@ -376,6 +403,15 @@ inline int64_t RubyArgs::toInt64(int i) {
   return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
 }
 
+inline c10::SymInt RubyArgs::toSymInt(int i) {
+  if (NIL_P(args[i])) {
+    return c10::SymInt(signature.params[i].default_int);
+  }
+
+  // TODO improve
+  return c10::SymInt(toInt64(i));
+}
+
 inline double RubyArgs::toDouble(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_double;
   return Rice::detail::From_Ruby<double>().convert(args[i]);
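`RubyArgs::generator` previously raised "generator not supported yet" for any non-nil value; it now converts the Ruby object through Rice, so generated functions with a `Generator` parameter can accept one. A hedged sketch — whether a particular function (here `Torch.randperm`) exposes the `generator:` keyword depends on the generated bindings for this release:

```ruby
gen = Torch::Generator.new

gen.manual_seed(123)
a = Torch.randperm(10, generator: gen)

gen.manual_seed(123)
b = Torch.randperm(10, generator: gen)

a.to_a == b.to_a  # => true; same seed, same permutation
```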
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
 #include <rice/stl.hpp>
 
 static_assert(
-  TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 12,
+  TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 13,
   "Incompatible LibTorch version"
 );
 
@@ -17,6 +17,7 @@ inline void handle_error(torch::Error const & ex) {
 
 // keep THP prefix for now to make it easier to compare code
 
+extern VALUE THPGeneratorClass;
 extern VALUE THPVariableClass;
 
 inline VALUE THPUtils_internSymbol(const std::string& str) {
@@ -36,10 +37,18 @@ inline bool THPUtils_checkIndex(VALUE obj) {
   return FIXNUM_P(obj);
 }
 
+inline bool THPUtils_checkLong(VALUE obj) {
+  return FIXNUM_P(obj);
+}
+
 inline bool THPUtils_checkScalar(VALUE obj) {
   return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
 }
 
+inline bool THPGenerator_Check(VALUE obj) {
+  return rb_obj_is_kind_of(obj, THPGeneratorClass);
+}
+
 inline bool THPVariable_Check(VALUE obj) {
   return rb_obj_is_kind_of(obj, THPVariableClass);
 }
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.11.2"
+  VERSION = "0.12.1"
 end
metadata CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.11.2
+  version: 0.12.1
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2022-09-25 00:00:00.000000000 Z
+date: 2023-01-30 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -44,6 +44,7 @@ files:
 - ext/torch/extconf.rb
 - ext/torch/fft.cpp
 - ext/torch/fft_functions.h
+- ext/torch/generator.cpp
 - ext/torch/ivalue.cpp
 - ext/torch/linalg.cpp
 - ext/torch/linalg_functions.h
@@ -229,7 +230,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
 - !ruby/object:Gem::Version
   version: '0'
 requirements: []
-rubygems_version: 3.3.7
+rubygems_version: 3.4.1
 signing_key:
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch