torch-rb 0.6.0 → 0.8.2

Files changed (44)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +21 -0
  3. data/README.md +23 -41
  4. data/codegen/function.rb +2 -0
  5. data/codegen/generate_functions.rb +43 -6
  6. data/codegen/native_functions.yaml +2007 -1327
  7. data/ext/torch/backends.cpp +17 -0
  8. data/ext/torch/cuda.cpp +5 -5
  9. data/ext/torch/device.cpp +13 -6
  10. data/ext/torch/ext.cpp +22 -5
  11. data/ext/torch/extconf.rb +1 -3
  12. data/ext/torch/fft.cpp +13 -0
  13. data/ext/torch/fft_functions.h +6 -0
  14. data/ext/torch/ivalue.cpp +31 -33
  15. data/ext/torch/linalg.cpp +13 -0
  16. data/ext/torch/linalg_functions.h +6 -0
  17. data/ext/torch/nn.cpp +34 -34
  18. data/ext/torch/random.cpp +5 -5
  19. data/ext/torch/ruby_arg_parser.cpp +2 -2
  20. data/ext/torch/ruby_arg_parser.h +23 -12
  21. data/ext/torch/special.cpp +13 -0
  22. data/ext/torch/special_functions.h +6 -0
  23. data/ext/torch/templates.h +111 -133
  24. data/ext/torch/tensor.cpp +80 -67
  25. data/ext/torch/torch.cpp +30 -21
  26. data/ext/torch/utils.h +3 -4
  27. data/ext/torch/wrap_outputs.h +72 -65
  28. data/lib/torch/inspector.rb +5 -2
  29. data/lib/torch/nn/convnd.rb +2 -0
  30. data/lib/torch/nn/functional_attention.rb +241 -0
  31. data/lib/torch/nn/module.rb +2 -0
  32. data/lib/torch/nn/module_list.rb +49 -0
  33. data/lib/torch/nn/multihead_attention.rb +123 -0
  34. data/lib/torch/nn/transformer.rb +92 -0
  35. data/lib/torch/nn/transformer_decoder.rb +25 -0
  36. data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
  37. data/lib/torch/nn/transformer_encoder.rb +25 -0
  38. data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
  39. data/lib/torch/nn/utils.rb +16 -0
  40. data/lib/torch/tensor.rb +2 -0
  41. data/lib/torch/utils/data/data_loader.rb +2 -0
  42. data/lib/torch/version.rb +1 -1
  43. data/lib/torch.rb +11 -0
  44. metadata +20 -5
data/ext/torch/backends.cpp ADDED
@@ -0,0 +1,17 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "utils.h"
+
+ void init_backends(Rice::Module& m) {
+   auto rb_mBackends = Rice::define_module_under(m, "Backends");
+
+   Rice::define_module_under(rb_mBackends, "OpenMP")
+     .add_handler<torch::Error>(handle_error)
+     .define_singleton_function("available?", &torch::hasOpenMP);
+
+   Rice::define_module_under(rb_mBackends, "MKL")
+     .add_handler<torch::Error>(handle_error)
+     .define_singleton_function("available?", &torch::hasMKL);
+ }
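
As a usage sketch, these bindings expose two predicate methods on the Ruby side (return values depend on how libtorch was compiled):

  # check which CPU backends this libtorch build supports
  Torch::Backends::OpenMP.available?  # => true or false
  Torch::Backends::MKL.available?     # => true or false
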
data/ext/torch/cuda.cpp CHANGED
@@ -1,14 +1,14 @@
  #include <torch/torch.h>
 
- #include <rice/Module.hpp>
+ #include <rice/rice.hpp>
 
  #include "utils.h"
 
  void init_cuda(Rice::Module& m) {
    Rice::define_module_under(m, "CUDA")
      .add_handler<torch::Error>(handle_error)
-     .define_singleton_method("available?", &torch::cuda::is_available)
-     .define_singleton_method("device_count", &torch::cuda::device_count)
-     .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
-     .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
+     .define_singleton_function("available?", &torch::cuda::is_available)
+     .define_singleton_function("device_count", &torch::cuda::device_count)
+     .define_singleton_function("manual_seed", &torch::cuda::manual_seed)
+     .define_singleton_function("manual_seed_all", &torch::cuda::manual_seed_all);
  }
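
The rename from define_singleton_method to define_singleton_function follows the Rice 4 API; the Ruby-facing methods are unchanged. A quick sketch:

  Torch::CUDA.available?       # => false on a CPU-only build
  Torch::CUDA.device_count     # => 0 when no GPU is visible
  Torch::CUDA.manual_seed(42)  # 42 is an arbitrary example seed
  Torch::CUDA.manual_seed_all(42)
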
data/ext/torch/device.cpp CHANGED
@@ -1,19 +1,26 @@
  #include <torch/torch.h>
 
- #include <rice/Constructor.hpp>
- #include <rice/Module.hpp>
+ #include <rice/rice.hpp>
 
  #include "utils.h"
 
  void init_device(Rice::Module& m) {
    Rice::define_class_under<torch::Device>(m, "Device")
      .add_handler<torch::Error>(handle_error)
-     .define_constructor(Rice::Constructor<torch::Device, std::string>())
-     .define_method("index", &torch::Device::index)
-     .define_method("index?", &torch::Device::has_index)
+     .define_constructor(Rice::Constructor<torch::Device, const std::string&>())
+     .define_method(
+       "index",
+       [](torch::Device& self) {
+         return self.index();
+       })
+     .define_method(
+       "index?",
+       [](torch::Device& self) {
+         return self.has_index();
+       })
      .define_method(
        "type",
-       *[](torch::Device& self) {
+       [](torch::Device& self) {
          std::stringstream s;
          s << self.type();
          return s.str();
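
The index accessors are now wrapped in lambdas rather than bound directly, which sidesteps Rice 4's handling of the C++ return types; Ruby behavior stays the same. Sketch (the "cuda:0" device string is just an example):

  device = Torch::Device.new("cuda:0")
  device.type    # => "cuda"
  device.index   # => 0
  device.index?  # => true
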
data/ext/torch/ext.cpp CHANGED
@@ -1,12 +1,18 @@
- #include <rice/Module.hpp>
+ #include <torch/torch.h>
 
+ #include <rice/rice.hpp>
+
+ void init_fft(Rice::Module& m);
+ void init_linalg(Rice::Module& m);
  void init_nn(Rice::Module& m);
- void init_tensor(Rice::Module& m);
+ void init_special(Rice::Module& m);
+ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
  void init_torch(Rice::Module& m);
 
+ void init_backends(Rice::Module& m);
  void init_cuda(Rice::Module& m);
  void init_device(Rice::Module& m);
- void init_ivalue(Rice::Module& m);
+ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
  void init_random(Rice::Module& m);
 
  extern "C"
@@ -14,13 +20,24 @@ void Init_ext()
  {
    auto m = Rice::define_module("Torch");
 
+   // need to define certain classes up front to keep Rice happy
+   auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
+     .define_constructor(Rice::Constructor<torch::IValue>());
+   auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
+   auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
+     .define_constructor(Rice::Constructor<torch::TensorOptions>());
+
    // keep this order
    init_torch(m);
-   init_tensor(m);
+   init_tensor(m, rb_cTensor, rb_cTensorOptions);
    init_nn(m);
+   init_fft(m);
+   init_linalg(m);
+   init_special(m);
 
+   init_backends(m);
    init_cuda(m);
    init_device(m);
-   init_ivalue(m);
+   init_ivalue(m, rb_cIValue);
    init_random(m);
  }
data/ext/torch/extconf.rb CHANGED
@@ -1,8 +1,6 @@
  require "mkmf-rice"
 
- abort "Missing stdc++" unless have_library("stdc++")
-
- $CXXFLAGS += " -std=c++14"
+ $CXXFLAGS += " -std=c++17 $(optflags)"
 
  # change to 0 for Linux pre-cxx11 ABI version
  $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
data/ext/torch/fft.cpp ADDED
@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "fft_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_fft(Rice::Module& m) {
+   auto rb_mFFT = Rice::define_module_under(m, "FFT");
+   rb_mFFT.add_handler<torch::Error>(handle_error);
+   add_fft_functions(rb_mFFT);
+ }
data/ext/torch/fft_functions.h ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_fft_functions(Rice::Module& m);
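
FFT is a thin namespace that receives code-generated bindings from native_functions.yaml. A hedged usage sketch, assuming fft is among the generated method names in this release:

  x = Torch.tensor([1.0, 2.0, 3.0, 4.0])
  Torch::FFT.fft(x)  # complex-valued result, mirroring torch.fft.fft
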
data/ext/torch/ivalue.cpp CHANGED
@@ -1,18 +1,13 @@
  #include <torch/torch.h>
 
- #include <rice/Array.hpp>
- #include <rice/Constructor.hpp>
- #include <rice/Hash.hpp>
- #include <rice/Module.hpp>
- #include <rice/String.hpp>
+ #include <rice/rice.hpp>
 
  #include "utils.h"
 
- void init_ivalue(Rice::Module& m) {
+ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
    // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
-   Rice::define_class_under<torch::IValue>(m, "IValue")
+   rb_cIValue
      .add_handler<torch::Error>(handle_error)
-     .define_constructor(Rice::Constructor<torch::IValue>())
      .define_method("bool?", &torch::IValue::isBool)
      .define_method("bool_list?", &torch::IValue::isBoolList)
      .define_method("capsule?", &torch::IValue::isCapsule)
@@ -39,95 +34,98 @@ void init_ivalue(Rice::Module& m) {
      .define_method("tuple?", &torch::IValue::isTuple)
      .define_method(
        "to_bool",
-       *[](torch::IValue& self) {
+       [](torch::IValue& self) {
          return self.toBool();
        })
      .define_method(
        "to_double",
-       *[](torch::IValue& self) {
+       [](torch::IValue& self) {
          return self.toDouble();
        })
      .define_method(
        "to_int",
-       *[](torch::IValue& self) {
+       [](torch::IValue& self) {
          return self.toInt();
        })
      .define_method(
        "to_list",
-       *[](torch::IValue& self) {
+       [](torch::IValue& self) {
          auto list = self.toListRef();
          Rice::Array obj;
          for (auto& elem : list) {
-           obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
+           auto v = torch::IValue{elem};
+           obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)));
          }
          return obj;
        })
      .define_method(
        "to_string_ref",
-       *[](torch::IValue& self) {
+       [](torch::IValue& self) {
          return self.toStringRef();
        })
      .define_method(
        "to_tensor",
-       *[](torch::IValue& self) {
+       [](torch::IValue& self) {
          return self.toTensor();
        })
      .define_method(
        "to_generic_dict",
-       *[](torch::IValue& self) {
+       [](torch::IValue& self) {
          auto dict = self.toGenericDict();
          Rice::Hash obj;
          for (auto& pair : dict) {
-           obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
+           auto k = torch::IValue{pair.key()};
+           auto v = torch::IValue{pair.value()};
+           obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
          }
          return obj;
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "from_tensor",
-       *[](torch::Tensor& v) {
+       [](torch::Tensor& v) {
          return torch::IValue(v);
        })
      // TODO create specialized list types?
-     .define_singleton_method(
+     .define_singleton_function(
        "from_list",
-       *[](Rice::Array obj) {
+       [](Rice::Array obj) {
          c10::impl::GenericList list(c10::AnyType::get());
          for (auto entry : obj) {
-           list.push_back(from_ruby<torch::IValue>(entry));
+           list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
          }
          return torch::IValue(list);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "from_string",
-       *[](Rice::String v) {
+       [](Rice::String v) {
          return torch::IValue(v.str());
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "from_int",
-       *[](int64_t v) {
+       [](int64_t v) {
          return torch::IValue(v);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "from_double",
-       *[](double v) {
+       [](double v) {
          return torch::IValue(v);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "from_bool",
-       *[](bool v) {
+       [](bool v) {
          return torch::IValue(v);
        })
      // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
      // createGenericDict and toIValue
-     .define_singleton_method(
+     .define_singleton_function(
        "from_dict",
-       *[](Rice::Hash obj) {
+       [](Rice::Hash obj) {
          auto key_type = c10::AnyType::get();
          auto value_type = c10::AnyType::get();
          c10::impl::GenericDict elems(key_type, value_type);
          elems.reserve(obj.size());
          for (auto entry : obj) {
-           elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Rice::Object) entry.second));
+           elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
          }
          return torch::IValue(std::move(elems));
        });
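
These conversions round-trip Ruby values through torch::IValue, which is mainly useful when calling TorchScript models. A sketch using only the constructors and accessors defined above:

  iv = Torch::IValue.from_bool(true)
  iv.bool?    # => true
  iv.to_bool  # => true

  list = Torch::IValue.from_list([Torch::IValue.from_int(1), Torch::IValue.from_int(2)])
  list.to_list.map(&:to_int)  # => [1, 2]
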
data/ext/torch/linalg.cpp ADDED
@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "linalg_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_linalg(Rice::Module& m) {
+   auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
+   rb_mLinalg.add_handler<torch::Error>(handle_error);
+   add_linalg_functions(rb_mLinalg);
+ }
data/ext/torch/linalg_functions.h ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_linalg_functions(Rice::Module& m);
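
Like FFT, Linalg only collects generated functions. A hedged sketch, assuming norm is among the generated names (mirroring torch.linalg.norm):

  a = Torch.tensor([[1.0, 2.0], [3.0, 4.0]])
  Torch::Linalg.norm(a)  # Frobenius norm by default in torch.linalg
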
data/ext/torch/nn.cpp CHANGED
@@ -1,6 +1,6 @@
  #include <torch/torch.h>
 
- #include <rice/Module.hpp>
+ #include <rice/rice.hpp>
 
  #include "nn_functions.h"
  #include "templates.h"
@@ -19,74 +19,74 @@ void init_nn(Rice::Module& m) {
 
    Rice::define_module_under(rb_mNN, "Init")
      .add_handler<torch::Error>(handle_error)
-     .define_singleton_method(
+     .define_singleton_function(
        "_calculate_gain",
-       *[](NonlinearityType nonlinearity, double param) {
+       [](NonlinearityType nonlinearity, double param) {
          return torch::nn::init::calculate_gain(nonlinearity, param);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_uniform!",
-       *[](Tensor tensor, double low, double high) {
+       [](Tensor tensor, double low, double high) {
          return torch::nn::init::uniform_(tensor, low, high);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_normal!",
-       *[](Tensor tensor, double mean, double std) {
+       [](Tensor tensor, double mean, double std) {
          return torch::nn::init::normal_(tensor, mean, std);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_constant!",
-       *[](Tensor tensor, Scalar value) {
+       [](Tensor tensor, Scalar value) {
          return torch::nn::init::constant_(tensor, value);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_ones!",
-       *[](Tensor tensor) {
+       [](Tensor tensor) {
          return torch::nn::init::ones_(tensor);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_zeros!",
-       *[](Tensor tensor) {
+       [](Tensor tensor) {
          return torch::nn::init::zeros_(tensor);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_eye!",
-       *[](Tensor tensor) {
+       [](Tensor tensor) {
          return torch::nn::init::eye_(tensor);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_dirac!",
-       *[](Tensor tensor) {
+       [](Tensor tensor) {
          return torch::nn::init::dirac_(tensor);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_xavier_uniform!",
-       *[](Tensor tensor, double gain) {
+       [](Tensor tensor, double gain) {
          return torch::nn::init::xavier_uniform_(tensor, gain);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_xavier_normal!",
-       *[](Tensor tensor, double gain) {
+       [](Tensor tensor, double gain) {
          return torch::nn::init::xavier_normal_(tensor, gain);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_kaiming_uniform!",
-       *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+       [](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
          return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_kaiming_normal!",
-       *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+       [](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
          return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_orthogonal!",
-       *[](Tensor tensor, double gain) {
+       [](Tensor tensor, double gain) {
          return torch::nn::init::orthogonal_(tensor, gain);
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_sparse!",
-       *[](Tensor tensor, double sparsity, double std) {
+       [](Tensor tensor, double sparsity, double std) {
          return torch::nn::init::sparse_(tensor, sparsity, std);
        });
 
@@ -94,18 +94,18 @@ void init_nn(Rice::Module& m) {
      .add_handler<torch::Error>(handle_error)
      .define_method(
        "grad",
-       *[](Parameter& self) {
+       [](Parameter& self) {
          auto grad = self.grad();
-         return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+         return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
        })
      .define_method(
        "grad=",
-       *[](Parameter& self, torch::Tensor& grad) {
+       [](Parameter& self, torch::Tensor& grad) {
          self.mutable_grad() = grad;
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "_make_subclass",
-       *[](Tensor& rd, bool requires_grad) {
+       [](Tensor& rd, bool requires_grad) {
          auto data = rd.detach();
          data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
          auto var = data.set_requires_grad(requires_grad);
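
The underscore-prefixed functions are the C++ halves of Torch::NN::Init; the public wrappers live in lib/torch. A hedged sketch of the Ruby-facing calls (wrapper names and keyword defaults assumed):

  w = Torch.empty(3, 5)
  Torch::NN::Init.kaiming_uniform!(w, a: Math.sqrt(5))
  Torch::NN::Init.zeros!(Torch.empty(3))
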
data/ext/torch/random.cpp CHANGED
@@ -1,20 +1,20 @@
  #include <torch/torch.h>
 
- #include <rice/Module.hpp>
+ #include <rice/rice.hpp>
 
  #include "utils.h"
 
  void init_random(Rice::Module& m) {
    Rice::define_module_under(m, "Random")
      .add_handler<torch::Error>(handle_error)
-     .define_singleton_method(
+     .define_singleton_function(
        "initial_seed",
-       *[]() {
+       []() {
          return at::detail::getDefaultCPUGenerator().current_seed();
        })
-     .define_singleton_method(
+     .define_singleton_function(
        "seed",
-       *[]() {
+       []() {
          // TODO set for CUDA when available
          auto generator = at::detail::getDefaultCPUGenerator();
          return generator.seed();
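
Both functions operate on the default CPU generator:

  Torch::Random.initial_seed  # seed the default CPU generator was initialized with
  Torch::Random.seed          # re-seeds with a non-deterministic value and returns it
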
data/ext/torch/ruby_arg_parser.cpp CHANGED
@@ -137,7 +137,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
        return true;
      }
      if (THPVariable_Check(obj)) {
-       auto var = from_ruby<torch::Tensor>(obj);
+       auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
        return !var.requires_grad() && var.dim() == 0;
      }
      return false;
@@ -147,7 +147,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
        return true;
      }
      if (THPVariable_Check(obj)) {
-       auto var = from_ruby<torch::Tensor>(obj);
+       auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
        return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
      }
      return false;
data/ext/torch/ruby_arg_parser.h CHANGED
@@ -5,7 +5,7 @@
  #include <sstream>
 
  #include <torch/torch.h>
- #include <rice/Exception.hpp>
+ #include <rice/rice.hpp>
 
  #include "templates.h"
  #include "utils.h"
@@ -78,6 +78,7 @@ struct RubyArgs {
    inline OptionalTensor optionalTensor(int i);
    inline at::Scalar scalar(int i);
    // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
+   inline std::vector<at::Scalar> scalarlist(int i);
    inline std::vector<at::Tensor> tensorlist(int i);
    template<int N>
    inline std::array<at::Tensor, N> tensorlist_n(int i);
@@ -121,7 +122,7 @@ struct RubyArgs {
  };
 
  inline at::Tensor RubyArgs::tensor(int i) {
-   return from_ruby<torch::Tensor>(args[i]);
+   return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
  }
 
  inline OptionalTensor RubyArgs::optionalTensor(int i) {
@@ -131,12 +132,17 @@ inline OptionalTensor RubyArgs::optionalTensor(int i) {
 
  inline at::Scalar RubyArgs::scalar(int i) {
    if (NIL_P(args[i])) return signature.params[i].default_scalar;
-   return from_ruby<torch::Scalar>(args[i]);
+   return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
+ }
+
+ inline std::vector<at::Scalar> RubyArgs::scalarlist(int i) {
+   if (NIL_P(args[i])) return std::vector<at::Scalar>();
+   return Rice::detail::From_Ruby<std::vector<at::Scalar>>().convert(args[i]);
  }
 
  inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
    if (NIL_P(args[i])) return std::vector<at::Tensor>();
-   return from_ruby<std::vector<Tensor>>(args[i]);
+   return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
  }
 
  template<int N>
@@ -151,7 +157,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
    }
    for (int idx = 0; idx < size; idx++) {
      VALUE obj = rb_ary_entry(arg, idx);
-     res[idx] = from_ruby<Tensor>(obj);
+     res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
    }
    return res;
  }
@@ -170,7 +176,7 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
    for (idx = 0; idx < size; idx++) {
      VALUE obj = rb_ary_entry(arg, idx);
      if (FIXNUM_P(obj)) {
-       res[idx] = from_ruby<int64_t>(obj);
+       res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
      } else {
        rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
          signature.name.c_str(), signature.params[i].name.c_str(),
@@ -210,8 +216,13 @@ inline ScalarType RubyArgs::scalartype(int i) {
    {ID2SYM(rb_intern("double")), ScalarType::Double},
    {ID2SYM(rb_intern("float64")), ScalarType::Double},
    {ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
+   {ID2SYM(rb_intern("complex32")), ScalarType::ComplexHalf},
    {ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
+   {ID2SYM(rb_intern("cfloat")), ScalarType::ComplexFloat},
+   {ID2SYM(rb_intern("complex64")), ScalarType::ComplexFloat},
    {ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
+   {ID2SYM(rb_intern("cdouble")), ScalarType::ComplexDouble},
+   {ID2SYM(rb_intern("complex128")), ScalarType::ComplexDouble},
    {ID2SYM(rb_intern("bool")), ScalarType::Bool},
    {ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
    {ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
@@ -260,7 +271,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
    for (idx = 0; idx < size; idx++) {
      VALUE obj = rb_ary_entry(arg, idx);
      if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
-       res[idx] = from_ruby<double>(obj);
+       res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
      } else {
        rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
          signature.name.c_str(), signature.params[i].name.c_str(),
@@ -303,22 +314,22 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
  }
 
  inline std::string RubyArgs::string(int i) {
-   return from_ruby<std::string>(args[i]);
+   return Rice::detail::From_Ruby<std::string>().convert(args[i]);
  }
 
  inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
-   if (!args[i]) return c10::nullopt;
-   return from_ruby<std::string>(args[i]);
+   if (NIL_P(args[i])) return c10::nullopt;
+   return Rice::detail::From_Ruby<std::string>().convert(args[i]);
  }
 
  inline int64_t RubyArgs::toInt64(int i) {
    if (NIL_P(args[i])) return signature.params[i].default_int;
-   return from_ruby<int64_t>(args[i]);
+   return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
  }
 
  inline double RubyArgs::toDouble(int i) {
    if (NIL_P(args[i])) return signature.params[i].default_double;
-   return from_ruby<double>(args[i]);
+   return Rice::detail::From_Ruby<double>().convert(args[i]);
  }
 
  inline bool RubyArgs::toBool(int i) {
data/ext/torch/special.cpp ADDED
@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "special_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_special(Rice::Module& m) {
+   auto rb_mSpecial = Rice::define_module_under(m, "Special");
+   rb_mSpecial.add_handler<torch::Error>(handle_error);
+   add_special_functions(rb_mSpecial);
+ }
data/ext/torch/special_functions.h ADDED
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_special_functions(Rice::Module& m);
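
Special mirrors torch.special via generated bindings. A hedged sketch, assuming expit (the logistic sigmoid) is among the generated names:

  x = Torch.tensor([-1.0, 0.0, 1.0])
  Torch::Special.expit(x)  # => tensor of sigmoid values
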