torch-rb 0.23.1 → 0.24.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/torch/ivalue.cpp CHANGED
@@ -53,7 +53,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
53
53
  [](torch::IValue& self) {
54
54
  auto list = self.toListRef();
55
55
  Rice::Array obj;
56
- for (auto& elem : list) {
56
+ for (const auto& elem : list) {
57
57
  auto v = torch::IValue{elem};
58
58
  obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)), false);
59
59
  }
@@ -74,7 +74,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
74
74
  [](torch::IValue& self) {
75
75
  auto dict = self.toGenericDict();
76
76
  Rice::Hash obj;
77
- for (auto& pair : dict) {
77
+ for (const auto& pair : dict) {
78
78
  auto k = torch::IValue{pair.key()};
79
79
  auto v = torch::IValue{pair.value()};
80
80
  obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
@@ -91,7 +91,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
91
91
  "from_list",
92
92
  [](Rice::Array obj) {
93
93
  c10::impl::GenericList list(c10::AnyType::get());
94
- for (auto entry : obj) {
94
+ for (const auto& entry : obj) {
95
95
  list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
96
96
  }
97
97
  return torch::IValue(list);
@@ -125,7 +125,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
125
125
  auto value_type = c10::AnyType::get();
126
126
  c10::impl::GenericDict elems(key_type, value_type);
127
127
  elems.reserve(obj.size());
128
- for (auto entry : obj) {
128
+ for (const auto& entry : obj) {
129
129
  elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
130
130
  }
131
131
  return torch::IValue(std::move(elems));
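[missing file header in this rendering — the hunks below are not from data/ext/torch/ivalue.cpp (line numbers restart at 117 after 125); presumably data/ext/torch/ruby_arg_parser.cpp]
@@ -117,11 +117,11 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {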
117
117
  }
118
118
  auto size = RARRAY_LEN(obj);
119
119
  for (int idx = 0; idx < size; idx++) {
120
- VALUE iobj = rb_ary_entry(obj, idx);
120
+ VALUE iobj = Rice::detail::protect(rb_ary_entry, obj, idx);
121
121
  if (!THPVariable_Check(iobj)) {
122
122
  if (throw_error) {
123
123
  throw Rice::Exception(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
124
- static_cast<int>(idx), argnum, rb_obj_classname(obj));
124
+ static_cast<int>(idx), argnum, Rice::detail::protect(rb_obj_classname, obj));
125
125
  }
126
126
  return false;
127
127
  }
@@ -136,7 +136,7 @@ static bool is_int_list(VALUE obj, int broadcast_size) {
136
136
  return true;
137
137
  }
138
138
 
139
- auto item = rb_ary_entry(obj, 0);
139
+ auto item = Rice::detail::protect(rb_ary_entry, obj, 0);
140
140
  bool int_first = false;
141
141
  if (THPUtils_checkIndex(item)) {
142
142
  // we still have to check that the rest of items are NOT symint nodes
@@ -178,7 +178,7 @@ static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
178
178
  if (RARRAY_LEN(obj) == 0) {
179
179
  return true;
180
180
  }
181
- auto item = rb_ary_entry(obj, 0);
181
+ auto item = Rice::detail::protect(rb_ary_entry, obj, 0);
182
182
 
183
183
  if (is_int_or_symint(item)) {
184
184
  return true;
@@ -631,23 +631,23 @@ static ssize_t find_param(FunctionSignature& signature, VALUE name) {
631
631
  static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num_pos_args) {
632
632
  VALUE key;
633
633
 
634
- VALUE keys = rb_funcall(kwargs, rb_intern("keys"), 0);
634
+ VALUE keys = Rice::detail::protect(rb_funcall, kwargs, rb_intern("keys"), 0);
635
635
  if (RARRAY_LEN(keys) > 0) {
636
- key = rb_ary_entry(keys, 0);
636
+ key = Rice::detail::protect(rb_ary_entry, keys, 0);
637
637
 
638
638
  if (!THPUtils_checkSymbol(key)) {
639
- throw Rice::Exception(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
639
+ throw Rice::Exception(rb_eArgError, "keywords must be symbols, not %s", Rice::detail::protect(rb_obj_classname, key));
640
640
  }
641
641
 
642
642
  auto param_idx = find_param(signature, key);
643
643
  if (param_idx < 0) {
644
644
  throw Rice::Exception(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
645
- signature.name.c_str(), rb_id2name(rb_to_id(key)));
645
+ signature.name.c_str(), Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, key)));
646
646
  }
647
647
 
648
648
  if (param_idx < num_pos_args) {
649
649
  throw Rice::Exception(rb_eArgError, "%s() got multiple values for argument '%s'",
650
- signature.name.c_str(), rb_id2name(rb_to_id(key)));
650
+ signature.name.c_str(), Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, key)));
651
651
  }
652
652
  }
653
653
 
@@ -703,14 +703,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
703
703
  }
704
704
  return false;
705
705
  }
706
- obj = rb_ary_entry(args, arg_pos);
706
+ obj = Rice::detail::protect(rb_ary_entry, args, arg_pos);
707
707
  } else if (!NIL_P(kwargs)) {
708
- obj = rb_hash_lookup2(kwargs, param.ruby_name, missing);
708
+ obj = Rice::detail::protect(rb_hash_lookup2, kwargs, param.ruby_name, missing);
709
709
  // for (VALUE numpy_name: param.numpy_python_names) {
710
710
  // if (obj) {
711
711
  // break;
712
712
  // }
713
- // obj = rb_hash_aref(kwargs, numpy_name);
713
+ // obj = Rice::detail::protect(rb_hash_aref, kwargs, numpy_name);
714
714
  // }
715
715
  is_kwd = true;
716
716
  }
@@ -742,12 +742,12 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
742
742
  // foo(): argument 'other' must be str, not int
743
743
  throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
744
744
  name.c_str(), param.name.c_str(), param.type_name().c_str(),
745
- rb_obj_classname(obj));
745
+ Rice::detail::protect(rb_obj_classname, obj));
746
746
  } else {
747
747
  // foo(): argument 'other' (position 2) must be str, not int
748
748
  throw Rice::Exception(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
749
749
  name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
750
- param.type_name().c_str(), rb_obj_classname(obj));
750
+ param.type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj));
751
751
  }
752
752
  } else {
753
753
  return false;
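[missing file header in this rendering — the hunks below come from a different file than the preceding hunks (line numbers restart at 168 after 742); presumably data/ext/torch/ruby_arg_parser.h]
@@ -168,7 +168,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {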
168
168
  throw Rice::Exception(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
169
169
  }
170
170
  for (int idx = 0; idx < size; idx++) {
171
- VALUE obj = rb_ary_entry(arg, idx);
171
+ VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
172
172
  res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
173
173
  }
174
174
  return res;
@@ -202,13 +202,13 @@ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int6
202
202
  size = RARRAY_LEN(arg);
203
203
  std::vector<int64_t> res(size);
204
204
  for (idx = 0; idx < size; idx++) {
205
- VALUE obj = rb_ary_entry(arg, idx);
205
+ VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
206
206
  if (FIXNUM_P(obj)) {
207
207
  res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
208
208
  } else {
209
209
  throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
210
210
  signature.name.c_str(), signature.params[i].name.c_str(),
211
- signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
211
+ signature.params[i].type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj), idx + 1);
212
212
  }
213
213
  }
214
214
  return res;
@@ -270,7 +270,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
270
270
 
271
271
  auto it = dtype_map.find(args[i]);
272
272
  if (it == dtype_map.end()) {
273
- throw Rice::Exception(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
273
+ throw Rice::Exception(rb_eArgError, "invalid dtype: %s", Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, args[i])));
274
274
  }
275
275
  return it->second;
276
276
  }
@@ -317,13 +317,13 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
317
317
  auto size = RARRAY_LEN(arg);
318
318
  std::vector<double> res(size);
319
319
  for (idx = 0; idx < size; idx++) {
320
- VALUE obj = rb_ary_entry(arg, idx);
320
+ VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
321
321
  if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
322
322
  res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
323
323
  } else {
324
324
  throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
325
325
  signature.name.c_str(), signature.params[i].name.c_str(),
326
- signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
326
+ signature.params[i].type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj), idx + 1);
327
327
  }
328
328
  }
329
329
  return res;
@@ -338,7 +338,7 @@ inline at::Layout RubyArgs::layout(int i) {
338
338
 
339
339
  auto it = layout_map.find(args[i]);
340
340
  if (it == layout_map.end()) {
341
- throw Rice::Exception(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
341
+ throw Rice::Exception(rb_eArgError, "invalid layout: %s", Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, args[i])));
342
342
  }
343
343
  return it->second;
344
344
  }
@@ -358,7 +358,7 @@ inline at::Device RubyArgs::device(int i) {
358
358
  return at::Device("cpu");
359
359
  }
360
360
  if (RB_TYPE_P(args[i], T_STRING)) {
361
- const std::string &device_str = THPUtils_unpackString(args[i]);
361
+ const std::string& device_str = THPUtils_unpackString(args[i]);
362
362
  return at::Device(device_str);
363
363
  }
364
364
  return Rice::detail::From_Ruby<at::Device>().convert(args[i]);
@@ -461,13 +461,13 @@ struct RubyArgParser {
461
461
 
462
462
  // Check deprecated signatures last
463
463
  std::stable_partition(signatures_.begin(), signatures_.end(),
464
- [](const FunctionSignature & sig) {
464
+ [](const FunctionSignature& sig) {
465
465
  return !sig.deprecated;
466
466
  });
467
467
  }
468
468
 
469
469
  template<int N>
470
- inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
470
+ inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N>& dst) {
471
471
  if (N < max_args) {
472
472
  throw Rice::Exception(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
473
473
  }
@@ -476,7 +476,7 @@ struct RubyArgParser {
476
476
 
477
477
  inline RubyArgs raw_parse(VALUE self, int argc, VALUE* argv, VALUE parsed_args[]) {
478
478
  VALUE args, kwargs;
479
- rb_scan_args(argc, argv, "*:", &args, &kwargs);
479
+ Rice::detail::protect(rb_scan_args, argc, argv, "*:", &args, &kwargs);
480
480
 
481
481
  if (signatures_.size() == 1) {
482
482
  auto& signature = signatures_[0];
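[missing file header in this rendering — the hunk below comes from a different file than the preceding hunks (line numbers restart at 54 after 476); presumably data/ext/torch/templates.h]
@@ -54,7 +54,7 @@ namespace Rice::detail {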
54
54
 
55
55
  explicit To_Ruby(Arg* arg) : arg_(arg) { }
56
56
 
57
- VALUE convert(c10::complex<T> const& x) {
57
+ VALUE convert(const c10::complex<T>& x) {
58
58
  return protect(rb_dbl_complex_new, x.real(), x.imag());
59
59
  }
60
60
 
data/ext/torch/tensor.cpp CHANGED
@@ -1,5 +1,6 @@
1
1
  #include <optional>
2
2
  #include <string>
3
+ #include <string_view>
3
4
  #include <vector>
4
5
 
5
6
  #include <torch/torch.h>
@@ -68,7 +69,7 @@ std::vector<TensorIndex> index_vector(Array a) {
68
69
  } else if (obj.value() == Qtrue || obj.value() == Qfalse) {
69
70
  indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
70
71
  } else {
71
- throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj.value()));
72
+ throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", Rice::detail::protect(rb_obj_classname, obj.value()));
72
73
  }
73
74
  }
74
75
  return indices;
@@ -108,8 +109,6 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
108
109
  .define_method("mps?", [](Tensor& self) { return self.is_mps(); })
109
110
  .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
110
111
  .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
111
- .define_method("dim", [](Tensor& self) { return self.dim(); })
112
- .define_method("numel", [](Tensor& self) { return self.numel(); })
113
112
  .define_method("element_size", [](Tensor& self) { return self.element_size(); })
114
113
  .define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
115
114
  .define_method(
@@ -127,7 +126,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
127
126
  "shape",
128
127
  [](Tensor& self) {
129
128
  Array a;
130
- for (auto &size : self.sizes()) {
129
+ for (const auto& size : self.sizes()) {
131
130
  a.push(size, false);
132
131
  }
133
132
  return a;
@@ -136,7 +135,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
136
135
  "_strides",
137
136
  [](Tensor& self) {
138
137
  Array a;
139
- for (auto &stride : self.strides()) {
138
+ for (const auto& stride : self.strides()) {
140
139
  a.push(stride, false);
141
140
  }
142
141
  return a;
@@ -153,11 +152,6 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
153
152
  auto vec = index_vector(indices);
154
153
  return self.index_put_(vec, value);
155
154
  })
156
- .define_method(
157
- "contiguous?",
158
- [](Tensor& self) {
159
- return self.is_contiguous();
160
- })
161
155
  .define_method(
162
156
  "_requires_grad!",
163
157
  [](Tensor& self, bool requires_grad) {
@@ -233,8 +227,8 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
233
227
  tensor = tensor.contiguous();
234
228
  }
235
229
 
236
- auto data_ptr = (const char *) tensor.data_ptr();
237
- return std::string(data_ptr, tensor.numel() * tensor.element_size());
230
+ auto data_ptr = static_cast<const char *>(tensor.data_ptr());
231
+ return Rice::String(std::string_view(data_ptr, tensor.numel() * tensor.element_size()));
238
232
  })
239
233
  // for TorchVision
240
234
  .define_method(
data/ext/torch/torch.cpp CHANGED
@@ -12,7 +12,7 @@
12
12
  #include "utils.h"
13
13
 
14
14
  template<typename T>
15
- torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
15
+ torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
16
16
  std::vector<T> vec;
17
17
  vec.reserve(a.size());
18
18
  for (long i = 0; i < a.size(); i++) {
@@ -61,13 +61,13 @@ void init_torch(Rice::Module& m) {
61
61
  // begin operations
62
62
  .define_singleton_function(
63
63
  "_save",
64
- [](const torch::IValue &value) {
64
+ [](const torch::IValue& value) {
65
65
  auto v = torch::pickle_save(value);
66
- return Rice::Object(rb_str_new(v.data(), v.size()));
66
+ return Rice::Object(Rice::detail::protect(rb_str_new, v.data(), v.size()));
67
67
  })
68
68
  .define_singleton_function(
69
69
  "_load",
70
- [](const std::string &filename) {
70
+ [](const std::string& filename) {
71
71
  // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
72
72
  std::ifstream input(filename, std::ios::binary);
73
73
  std::vector<char> bytes(
@@ -78,13 +78,13 @@ void init_torch(Rice::Module& m) {
78
78
  })
79
79
  .define_singleton_function(
80
80
  "_from_blob",
81
- [](Rice::String s, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
81
+ [](Rice::String s, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
82
82
  void *data = const_cast<char *>(s.c_str());
83
83
  return torch::from_blob(data, size, options);
84
84
  })
85
85
  .define_singleton_function(
86
86
  "_tensor",
87
- [](Rice::Array a, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
87
+ [](Rice::Array a, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
88
88
  auto dtype = options.dtype();
89
89
  if (dtype == torch::kByte) {
90
90
  return make_tensor<uint8_t>(a, size, options);
data/ext/torch/utils.h CHANGED
@@ -8,7 +8,7 @@
8
8
  #include <rice/stl.hpp>
9
9
 
10
10
  static_assert(
11
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 10,
11
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 12,
12
12
  "Incompatible LibTorch version"
13
13
  );
14
14
 
@@ -50,17 +50,17 @@ inline bool THPUtils_checkScalar(VALUE obj) {
50
50
  }
51
51
 
52
52
  inline bool THPDevice_Check(VALUE obj) {
53
- return rb_obj_is_kind_of(obj, THPDeviceClass);
53
+ return Rice::detail::protect(rb_obj_is_kind_of, obj, THPDeviceClass);
54
54
  }
55
55
 
56
56
  inline bool THPGenerator_Check(VALUE obj) {
57
- return rb_obj_is_kind_of(obj, THPGeneratorClass);
57
+ return Rice::detail::protect(rb_obj_is_kind_of, obj, THPGeneratorClass);
58
58
  }
59
59
 
60
60
  inline bool THPVariable_Check(VALUE obj) {
61
- return rb_obj_is_kind_of(obj, THPVariableClass);
61
+ return Rice::detail::protect(rb_obj_is_kind_of, obj, THPVariableClass);
62
62
  }
63
63
 
64
64
  inline bool THPVariable_CheckExact(VALUE obj) {
65
- return rb_obj_is_instance_of(obj, THPVariableClass);
65
+ return Rice::detail::protect(rb_obj_is_instance_of, obj, THPVariableClass);
66
66
  }
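[missing file header in this rendering — the hunks below are not from data/ext/torch/utils.h (line numbers restart at 15 after 50); presumably data/ext/torch/wrap_outputs.h]
@@ -15,32 +15,34 @@ inline VALUE wrap(double x) {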
15
15
  return Rice::detail::To_Ruby<double>().convert(x);
16
16
  }
17
17
 
18
- inline VALUE wrap(torch::Tensor x) {
18
+ inline VALUE wrap(const torch::Tensor& x) {
19
19
  return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
20
20
  }
21
21
 
22
- inline VALUE wrap(torch::Scalar x) {
22
+ inline VALUE wrap(const torch::Scalar& x) {
23
23
  return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
24
24
  }
25
25
 
26
- inline VALUE wrap(torch::ScalarType x) {
26
+ inline VALUE wrap(const torch::ScalarType& x) {
27
27
  return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
28
28
  }
29
29
 
30
- inline VALUE wrap(torch::QScheme x) {
30
+ inline VALUE wrap(const torch::QScheme& x) {
31
31
  return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
32
32
  }
33
33
 
34
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
35
- return rb_ary_new3(
34
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor>& x) {
35
+ return Rice::detail::protect(
36
+ rb_ary_new3,
36
37
  2,
37
38
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
38
39
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
39
40
  );
40
41
  }
41
42
 
42
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
43
- return rb_ary_new3(
43
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>& x) {
44
+ return Rice::detail::protect(
45
+ rb_ary_new3,
44
46
  3,
45
47
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
46
48
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
@@ -48,8 +50,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
48
50
  );
49
51
  }
50
52
 
51
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
52
- return rb_ary_new3(
53
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>& x) {
54
+ return Rice::detail::protect(
55
+ rb_ary_new3,
53
56
  4,
54
57
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
55
58
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
@@ -58,8 +61,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch:
58
61
  );
59
62
  }
60
63
 
61
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
62
- return rb_ary_new3(
64
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>& x) {
65
+ return Rice::detail::protect(
66
+ rb_ary_new3,
63
67
  5,
64
68
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
65
69
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
@@ -69,8 +73,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch:
69
73
  );
70
74
  }
71
75
 
72
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
73
- return rb_ary_new3(
76
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t>& x) {
77
+ return Rice::detail::protect(
78
+ rb_ary_new3,
74
79
  4,
75
80
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
76
81
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
@@ -79,8 +84,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_
79
84
  );
80
85
  }
81
86
 
82
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
83
- return rb_ary_new3(
87
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, double, int64_t>& x) {
88
+ return Rice::detail::protect(
89
+ rb_ary_new3,
84
90
  4,
85
91
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
86
92
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
@@ -89,16 +95,17 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
89
95
  );
90
96
  }
91
97
 
92
- inline VALUE wrap(torch::TensorList x) {
93
- auto a = rb_ary_new2(x.size());
94
- for (auto t : x) {
95
- rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
98
+ inline VALUE wrap(const torch::TensorList& x) {
99
+ auto a = Rice::detail::protect(rb_ary_new2, x.size());
100
+ for (const auto& t : x) {
101
+ Rice::detail::protect(rb_ary_push, a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
96
102
  }
97
103
  return a;
98
104
  }
99
105
 
100
- inline VALUE wrap(std::tuple<double, double> x) {
101
- return rb_ary_new3(
106
+ inline VALUE wrap(const std::tuple<double, double>& x) {
107
+ return Rice::detail::protect(
108
+ rb_ary_new3,
102
109
  2,
103
110
  Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
104
111
  Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
data/lib/torch/hub.rb CHANGED
@@ -6,37 +6,17 @@ module Torch
6
6
  end
7
7
 
8
8
  def download_url_to_file(url, dst)
9
- uri = URI(url)
10
- tmp = nil
11
- location = nil
9
+ require "open-uri"
12
10
 
13
- puts "Downloading #{url}..."
14
- Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
15
- request = Net::HTTP::Get.new(uri)
16
-
17
- http.request(request) do |response|
18
- case response
19
- when Net::HTTPRedirection
20
- location = response["location"]
21
- when Net::HTTPSuccess
22
- tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
23
- File.open(tmp, "wb") do |f|
24
- response.read_body do |chunk|
25
- f.write(chunk)
26
- end
27
- end
28
- else
29
- raise Error, "Bad response"
30
- end
31
- end
32
- end
11
+ uri = URI.parse(url)
12
+ raise "Invalid URL" unless uri.is_a?(URI::HTTP) # includes https
33
13
 
34
- if location
35
- download_url_to_file(location, dst)
36
- else
37
- FileUtils.mv(tmp, dst)
38
- nil
14
+ puts "Downloading #{url}..."
15
+ uri.open(max_redirects: 10) do |download|
16
+ # TODO move file when possible
17
+ IO.copy_stream(download, dst.to_str)
39
18
  end
19
+ nil
40
20
  end
41
21
 
42
22
  def load_state_dict_from_url(url, model_dir: nil)
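[missing file header in this rendering — the hunk below (named_buffers) appears to come from a different file than data/lib/torch/hub.rb; presumably data/lib/torch/nn/module.rb]
@@ -197,7 +197,7 @@ module Torch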
197
197
  named_buffers.values
198
198
  end
199
199
 
200
- # TODO set recurse: true in 0.18.0
200
+ # TODO set recurse: true in future
201
201
  def named_buffers(prefix: "", recurse: false)
202
202
  buffers = {}
203
203
  if recurse
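[missing file header in this rendering — the hunk below comes from a different file than the preceding hunk (line numbers restart at 161 after 197); presumably data/lib/torch/nn/rnn_base.rb]
@@ -161,7 +161,7 @@ module Torch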
161
161
  private
162
162
 
163
163
  def _flat_weights
164
- @all_weights.flatten.map { |v| instance_variable_get("@#{v}") }.compact
164
+ @all_weights.flatten.filter_map { |v| instance_variable_get("@#{v}") }
165
165
  end
166
166
 
167
167
  def _get_flat_weights
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.23.1"
2
+ VERSION = "0.24.0"
3
3
  end
data/lib/torch.rb CHANGED
@@ -279,7 +279,7 @@ module Torch
279
279
  def self._make_tensor_class(dtype, cuda = false)
280
280
  cls = Class.new
281
281
  device = cuda ? "cuda" : "cpu"
282
- cls.define_singleton_method("new") do |*args|
282
+ cls.define_singleton_method(:new) do |*args|
283
283
  if args.size == 1 && args.first.is_a?(Tensor)
284
284
  args.first.send(dtype).to(device)
285
285
  elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
@@ -347,7 +347,7 @@ module Torch
347
347
  # from_blob does not own the data, so we need to keep
348
348
  # a reference to it for duration of tensor
349
349
  # can remove when passing pointer directly
350
- tensor.instance_variable_set("@_numo_data", data)
350
+ tensor.instance_variable_set(:@_numo_data, data)
351
351
  tensor
352
352
  end
353
353
 
@@ -426,11 +426,11 @@ module Torch
426
426
  end
427
427
 
428
428
  if options[:dtype].nil?
429
- if data.all? { |v| v.is_a?(Integer) }
429
+ if data.all?(Integer)
430
430
  options[:dtype] = :int64
431
431
  elsif data.all? { |v| v == true || v == false }
432
432
  options[:dtype] = :bool
433
- elsif data.any? { |v| v.is_a?(Complex) }
433
+ elsif data.any?(Complex)
434
434
  options[:dtype] = :complex64
435
435
  end
436
436
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.23.1
4
+ version: 0.24.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
@@ -234,14 +234,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
234
234
  requirements:
235
235
  - - ">="
236
236
  - !ruby/object:Gem::Version
237
- version: '3.2'
237
+ version: '3.3'
238
238
  required_rubygems_version: !ruby/object:Gem::Requirement
239
239
  requirements:
240
240
  - - ">="
241
241
  - !ruby/object:Gem::Version
242
242
  version: '0'
243
243
  requirements: []
244
- rubygems_version: 4.0.3
244
+ rubygems_version: 4.0.6
245
245
  specification_version: 4
246
246
  summary: Deep learning for Ruby, powered by LibTorch
247
247
  test_files: []