torch-rb 0.6.0 → 0.8.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +21 -0
  3. data/README.md +23 -41
  4. data/codegen/function.rb +2 -0
  5. data/codegen/generate_functions.rb +43 -6
  6. data/codegen/native_functions.yaml +2007 -1327
  7. data/ext/torch/backends.cpp +17 -0
  8. data/ext/torch/cuda.cpp +5 -5
  9. data/ext/torch/device.cpp +13 -6
  10. data/ext/torch/ext.cpp +22 -5
  11. data/ext/torch/extconf.rb +1 -3
  12. data/ext/torch/fft.cpp +13 -0
  13. data/ext/torch/fft_functions.h +6 -0
  14. data/ext/torch/ivalue.cpp +31 -33
  15. data/ext/torch/linalg.cpp +13 -0
  16. data/ext/torch/linalg_functions.h +6 -0
  17. data/ext/torch/nn.cpp +34 -34
  18. data/ext/torch/random.cpp +5 -5
  19. data/ext/torch/ruby_arg_parser.cpp +2 -2
  20. data/ext/torch/ruby_arg_parser.h +23 -12
  21. data/ext/torch/special.cpp +13 -0
  22. data/ext/torch/special_functions.h +6 -0
  23. data/ext/torch/templates.h +111 -133
  24. data/ext/torch/tensor.cpp +80 -67
  25. data/ext/torch/torch.cpp +30 -21
  26. data/ext/torch/utils.h +3 -4
  27. data/ext/torch/wrap_outputs.h +72 -65
  28. data/lib/torch/inspector.rb +5 -2
  29. data/lib/torch/nn/convnd.rb +2 -0
  30. data/lib/torch/nn/functional_attention.rb +241 -0
  31. data/lib/torch/nn/module.rb +2 -0
  32. data/lib/torch/nn/module_list.rb +49 -0
  33. data/lib/torch/nn/multihead_attention.rb +123 -0
  34. data/lib/torch/nn/transformer.rb +92 -0
  35. data/lib/torch/nn/transformer_decoder.rb +25 -0
  36. data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
  37. data/lib/torch/nn/transformer_encoder.rb +25 -0
  38. data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
  39. data/lib/torch/nn/utils.rb +16 -0
  40. data/lib/torch/tensor.rb +2 -0
  41. data/lib/torch/utils/data/data_loader.rb +2 -0
  42. data/lib/torch/version.rb +1 -1
  43. data/lib/torch.rb +11 -0
  44. metadata +20 -5
data/ext/torch/backends.cpp ADDED
@@ -0,0 +1,17 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "utils.h"
6
+
7
+ void init_backends(Rice::Module& m) {
8
+ auto rb_mBackends = Rice::define_module_under(m, "Backends");
9
+
10
+ Rice::define_module_under(rb_mBackends, "OpenMP")
11
+ .add_handler<torch::Error>(handle_error)
12
+ .define_singleton_function("available?", &torch::hasOpenMP);
13
+
14
+ Rice::define_module_under(rb_mBackends, "MKL")
15
+ .add_handler<torch::Error>(handle_error)
16
+ .define_singleton_function("available?", &torch::hasMKL);
17
+ }
data/ext/torch/cuda.cpp CHANGED
@@ -1,14 +1,14 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Module.hpp>
3
+ #include <rice/rice.hpp>
4
4
 
5
5
  #include "utils.h"
6
6
 
7
7
  void init_cuda(Rice::Module& m) {
8
8
  Rice::define_module_under(m, "CUDA")
9
9
  .add_handler<torch::Error>(handle_error)
10
- .define_singleton_method("available?", &torch::cuda::is_available)
11
- .define_singleton_method("device_count", &torch::cuda::device_count)
12
- .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
13
- .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
10
+ .define_singleton_function("available?", &torch::cuda::is_available)
11
+ .define_singleton_function("device_count", &torch::cuda::device_count)
12
+ .define_singleton_function("manual_seed", &torch::cuda::manual_seed)
13
+ .define_singleton_function("manual_seed_all", &torch::cuda::manual_seed_all);
14
14
  }
data/ext/torch/device.cpp CHANGED
@@ -1,19 +1,26 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Constructor.hpp>
4
- #include <rice/Module.hpp>
3
+ #include <rice/rice.hpp>
5
4
 
6
5
  #include "utils.h"
7
6
 
8
7
  void init_device(Rice::Module& m) {
9
8
  Rice::define_class_under<torch::Device>(m, "Device")
10
9
  .add_handler<torch::Error>(handle_error)
11
- .define_constructor(Rice::Constructor<torch::Device, std::string>())
12
- .define_method("index", &torch::Device::index)
13
- .define_method("index?", &torch::Device::has_index)
10
+ .define_constructor(Rice::Constructor<torch::Device, const std::string&>())
11
+ .define_method(
12
+ "index",
13
+ [](torch::Device& self) {
14
+ return self.index();
15
+ })
16
+ .define_method(
17
+ "index?",
18
+ [](torch::Device& self) {
19
+ return self.has_index();
20
+ })
14
21
  .define_method(
15
22
  "type",
16
- *[](torch::Device& self) {
23
+ [](torch::Device& self) {
17
24
  std::stringstream s;
18
25
  s << self.type();
19
26
  return s.str();
data/ext/torch/ext.cpp CHANGED
@@ -1,12 +1,18 @@
1
- #include <rice/Module.hpp>
1
+ #include <torch/torch.h>
2
2
 
3
+ #include <rice/rice.hpp>
4
+
5
+ void init_fft(Rice::Module& m);
6
+ void init_linalg(Rice::Module& m);
3
7
  void init_nn(Rice::Module& m);
4
- void init_tensor(Rice::Module& m);
8
+ void init_special(Rice::Module& m);
9
+ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
5
10
  void init_torch(Rice::Module& m);
6
11
 
12
+ void init_backends(Rice::Module& m);
7
13
  void init_cuda(Rice::Module& m);
8
14
  void init_device(Rice::Module& m);
9
- void init_ivalue(Rice::Module& m);
15
+ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
10
16
  void init_random(Rice::Module& m);
11
17
 
12
18
  extern "C"
@@ -14,13 +20,24 @@ void Init_ext()
14
20
  {
15
21
  auto m = Rice::define_module("Torch");
16
22
 
23
+ // need to define certain classes up front to keep Rice happy
24
+ auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
25
+ .define_constructor(Rice::Constructor<torch::IValue>());
26
+ auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
27
+ auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
28
+ .define_constructor(Rice::Constructor<torch::TensorOptions>());
29
+
17
30
  // keep this order
18
31
  init_torch(m);
19
- init_tensor(m);
32
+ init_tensor(m, rb_cTensor, rb_cTensorOptions);
20
33
  init_nn(m);
34
+ init_fft(m);
35
+ init_linalg(m);
36
+ init_special(m);
21
37
 
38
+ init_backends(m);
22
39
  init_cuda(m);
23
40
  init_device(m);
24
- init_ivalue(m);
41
+ init_ivalue(m, rb_cIValue);
25
42
  init_random(m);
26
43
  }
data/ext/torch/extconf.rb CHANGED
@@ -1,8 +1,6 @@
1
1
  require "mkmf-rice"
2
2
 
3
- abort "Missing stdc++" unless have_library("stdc++")
4
-
5
- $CXXFLAGS += " -std=c++14"
3
+ $CXXFLAGS += " -std=c++17 $(optflags)"
6
4
 
7
5
  # change to 0 for Linux pre-cxx11 ABI version
8
6
  $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
data/ext/torch/fft.cpp ADDED
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "fft_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_fft(Rice::Module& m) {
10
+ auto rb_mFFT = Rice::define_module_under(m, "FFT");
11
+ rb_mFFT.add_handler<torch::Error>(handle_error);
12
+ add_fft_functions(rb_mFFT);
13
+ }
data/ext/torch/fft_functions.h ADDED
@@ -0,0 +1,6 @@
// generated by rake generate:functions
// do not edit by hand

#pragma once

// Registers the generated FFT function bindings on `m`; init_fft (fft.cpp)
// calls it with the Torch::FFT module.
// NOTE(review): this header names Rice::Module without including
// <rice/rice.hpp>, so it relies on the includer doing so first (fft.cpp
// does). If it should be self-contained, change the generator, not this file.
void add_fft_functions(Rice::Module& m);
data/ext/torch/ivalue.cpp CHANGED
@@ -1,18 +1,13 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Array.hpp>
4
- #include <rice/Constructor.hpp>
5
- #include <rice/Hash.hpp>
6
- #include <rice/Module.hpp>
7
- #include <rice/String.hpp>
3
+ #include <rice/rice.hpp>
8
4
 
9
5
  #include "utils.h"
10
6
 
11
- void init_ivalue(Rice::Module& m) {
7
+ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
12
8
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
13
- Rice::define_class_under<torch::IValue>(m, "IValue")
9
+ rb_cIValue
14
10
  .add_handler<torch::Error>(handle_error)
15
- .define_constructor(Rice::Constructor<torch::IValue>())
16
11
  .define_method("bool?", &torch::IValue::isBool)
17
12
  .define_method("bool_list?", &torch::IValue::isBoolList)
18
13
  .define_method("capsule?", &torch::IValue::isCapsule)
@@ -39,95 +34,98 @@ void init_ivalue(Rice::Module& m) {
39
34
  .define_method("tuple?", &torch::IValue::isTuple)
40
35
  .define_method(
41
36
  "to_bool",
42
- *[](torch::IValue& self) {
37
+ [](torch::IValue& self) {
43
38
  return self.toBool();
44
39
  })
45
40
  .define_method(
46
41
  "to_double",
47
- *[](torch::IValue& self) {
42
+ [](torch::IValue& self) {
48
43
  return self.toDouble();
49
44
  })
50
45
  .define_method(
51
46
  "to_int",
52
- *[](torch::IValue& self) {
47
+ [](torch::IValue& self) {
53
48
  return self.toInt();
54
49
  })
55
50
  .define_method(
56
51
  "to_list",
57
- *[](torch::IValue& self) {
52
+ [](torch::IValue& self) {
58
53
  auto list = self.toListRef();
59
54
  Rice::Array obj;
60
55
  for (auto& elem : list) {
61
- obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
56
+ auto v = torch::IValue{elem};
57
+ obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)));
62
58
  }
63
59
  return obj;
64
60
  })
65
61
  .define_method(
66
62
  "to_string_ref",
67
- *[](torch::IValue& self) {
63
+ [](torch::IValue& self) {
68
64
  return self.toStringRef();
69
65
  })
70
66
  .define_method(
71
67
  "to_tensor",
72
- *[](torch::IValue& self) {
68
+ [](torch::IValue& self) {
73
69
  return self.toTensor();
74
70
  })
75
71
  .define_method(
76
72
  "to_generic_dict",
77
- *[](torch::IValue& self) {
73
+ [](torch::IValue& self) {
78
74
  auto dict = self.toGenericDict();
79
75
  Rice::Hash obj;
80
76
  for (auto& pair : dict) {
81
- obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
77
+ auto k = torch::IValue{pair.key()};
78
+ auto v = torch::IValue{pair.value()};
79
+ obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
82
80
  }
83
81
  return obj;
84
82
  })
85
- .define_singleton_method(
83
+ .define_singleton_function(
86
84
  "from_tensor",
87
- *[](torch::Tensor& v) {
85
+ [](torch::Tensor& v) {
88
86
  return torch::IValue(v);
89
87
  })
90
88
  // TODO create specialized list types?
91
- .define_singleton_method(
89
+ .define_singleton_function(
92
90
  "from_list",
93
- *[](Rice::Array obj) {
91
+ [](Rice::Array obj) {
94
92
  c10::impl::GenericList list(c10::AnyType::get());
95
93
  for (auto entry : obj) {
96
- list.push_back(from_ruby<torch::IValue>(entry));
94
+ list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
97
95
  }
98
96
  return torch::IValue(list);
99
97
  })
100
- .define_singleton_method(
98
+ .define_singleton_function(
101
99
  "from_string",
102
- *[](Rice::String v) {
100
+ [](Rice::String v) {
103
101
  return torch::IValue(v.str());
104
102
  })
105
- .define_singleton_method(
103
+ .define_singleton_function(
106
104
  "from_int",
107
- *[](int64_t v) {
105
+ [](int64_t v) {
108
106
  return torch::IValue(v);
109
107
  })
110
- .define_singleton_method(
108
+ .define_singleton_function(
111
109
  "from_double",
112
- *[](double v) {
110
+ [](double v) {
113
111
  return torch::IValue(v);
114
112
  })
115
- .define_singleton_method(
113
+ .define_singleton_function(
116
114
  "from_bool",
117
- *[](bool v) {
115
+ [](bool v) {
118
116
  return torch::IValue(v);
119
117
  })
120
118
  // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
121
119
  // createGenericDict and toIValue
122
- .define_singleton_method(
120
+ .define_singleton_function(
123
121
  "from_dict",
124
- *[](Rice::Hash obj) {
122
+ [](Rice::Hash obj) {
125
123
  auto key_type = c10::AnyType::get();
126
124
  auto value_type = c10::AnyType::get();
127
125
  c10::impl::GenericDict elems(key_type, value_type);
128
126
  elems.reserve(obj.size());
129
127
  for (auto entry : obj) {
130
- elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Rice::Object) entry.second));
128
+ elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
131
129
  }
132
130
  return torch::IValue(std::move(elems));
133
131
  });
data/ext/torch/linalg.cpp ADDED
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "linalg_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_linalg(Rice::Module& m) {
10
+ auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
11
+ rb_mLinalg.add_handler<torch::Error>(handle_error);
12
+ add_linalg_functions(rb_mLinalg);
13
+ }
data/ext/torch/linalg_functions.h ADDED
@@ -0,0 +1,6 @@
// generated by rake generate:functions
// do not edit by hand

#pragma once

// Registers the generated linalg function bindings on `m`; init_linalg
// (linalg.cpp) calls it with the Torch::Linalg module.
// NOTE(review): this header names Rice::Module without including
// <rice/rice.hpp>, so it relies on the includer doing so first (linalg.cpp
// does). If it should be self-contained, change the generator, not this file.
void add_linalg_functions(Rice::Module& m);
data/ext/torch/nn.cpp CHANGED
@@ -1,6 +1,6 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Module.hpp>
3
+ #include <rice/rice.hpp>
4
4
 
5
5
  #include "nn_functions.h"
6
6
  #include "templates.h"
@@ -19,74 +19,74 @@ void init_nn(Rice::Module& m) {
19
19
 
20
20
  Rice::define_module_under(rb_mNN, "Init")
21
21
  .add_handler<torch::Error>(handle_error)
22
- .define_singleton_method(
22
+ .define_singleton_function(
23
23
  "_calculate_gain",
24
- *[](NonlinearityType nonlinearity, double param) {
24
+ [](NonlinearityType nonlinearity, double param) {
25
25
  return torch::nn::init::calculate_gain(nonlinearity, param);
26
26
  })
27
- .define_singleton_method(
27
+ .define_singleton_function(
28
28
  "_uniform!",
29
- *[](Tensor tensor, double low, double high) {
29
+ [](Tensor tensor, double low, double high) {
30
30
  return torch::nn::init::uniform_(tensor, low, high);
31
31
  })
32
- .define_singleton_method(
32
+ .define_singleton_function(
33
33
  "_normal!",
34
- *[](Tensor tensor, double mean, double std) {
34
+ [](Tensor tensor, double mean, double std) {
35
35
  return torch::nn::init::normal_(tensor, mean, std);
36
36
  })
37
- .define_singleton_method(
37
+ .define_singleton_function(
38
38
  "_constant!",
39
- *[](Tensor tensor, Scalar value) {
39
+ [](Tensor tensor, Scalar value) {
40
40
  return torch::nn::init::constant_(tensor, value);
41
41
  })
42
- .define_singleton_method(
42
+ .define_singleton_function(
43
43
  "_ones!",
44
- *[](Tensor tensor) {
44
+ [](Tensor tensor) {
45
45
  return torch::nn::init::ones_(tensor);
46
46
  })
47
- .define_singleton_method(
47
+ .define_singleton_function(
48
48
  "_zeros!",
49
- *[](Tensor tensor) {
49
+ [](Tensor tensor) {
50
50
  return torch::nn::init::zeros_(tensor);
51
51
  })
52
- .define_singleton_method(
52
+ .define_singleton_function(
53
53
  "_eye!",
54
- *[](Tensor tensor) {
54
+ [](Tensor tensor) {
55
55
  return torch::nn::init::eye_(tensor);
56
56
  })
57
- .define_singleton_method(
57
+ .define_singleton_function(
58
58
  "_dirac!",
59
- *[](Tensor tensor) {
59
+ [](Tensor tensor) {
60
60
  return torch::nn::init::dirac_(tensor);
61
61
  })
62
- .define_singleton_method(
62
+ .define_singleton_function(
63
63
  "_xavier_uniform!",
64
- *[](Tensor tensor, double gain) {
64
+ [](Tensor tensor, double gain) {
65
65
  return torch::nn::init::xavier_uniform_(tensor, gain);
66
66
  })
67
- .define_singleton_method(
67
+ .define_singleton_function(
68
68
  "_xavier_normal!",
69
- *[](Tensor tensor, double gain) {
69
+ [](Tensor tensor, double gain) {
70
70
  return torch::nn::init::xavier_normal_(tensor, gain);
71
71
  })
72
- .define_singleton_method(
72
+ .define_singleton_function(
73
73
  "_kaiming_uniform!",
74
- *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
74
+ [](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
75
75
  return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
76
76
  })
77
- .define_singleton_method(
77
+ .define_singleton_function(
78
78
  "_kaiming_normal!",
79
- *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
79
+ [](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
80
80
  return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
81
81
  })
82
- .define_singleton_method(
82
+ .define_singleton_function(
83
83
  "_orthogonal!",
84
- *[](Tensor tensor, double gain) {
84
+ [](Tensor tensor, double gain) {
85
85
  return torch::nn::init::orthogonal_(tensor, gain);
86
86
  })
87
- .define_singleton_method(
87
+ .define_singleton_function(
88
88
  "_sparse!",
89
- *[](Tensor tensor, double sparsity, double std) {
89
+ [](Tensor tensor, double sparsity, double std) {
90
90
  return torch::nn::init::sparse_(tensor, sparsity, std);
91
91
  });
92
92
 
@@ -94,18 +94,18 @@ void init_nn(Rice::Module& m) {
94
94
  .add_handler<torch::Error>(handle_error)
95
95
  .define_method(
96
96
  "grad",
97
- *[](Parameter& self) {
97
+ [](Parameter& self) {
98
98
  auto grad = self.grad();
99
- return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
99
+ return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
100
100
  })
101
101
  .define_method(
102
102
  "grad=",
103
- *[](Parameter& self, torch::Tensor& grad) {
103
+ [](Parameter& self, torch::Tensor& grad) {
104
104
  self.mutable_grad() = grad;
105
105
  })
106
- .define_singleton_method(
106
+ .define_singleton_function(
107
107
  "_make_subclass",
108
- *[](Tensor& rd, bool requires_grad) {
108
+ [](Tensor& rd, bool requires_grad) {
109
109
  auto data = rd.detach();
110
110
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
111
111
  auto var = data.set_requires_grad(requires_grad);
data/ext/torch/random.cpp CHANGED
@@ -1,20 +1,20 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Module.hpp>
3
+ #include <rice/rice.hpp>
4
4
 
5
5
  #include "utils.h"
6
6
 
7
7
  void init_random(Rice::Module& m) {
8
8
  Rice::define_module_under(m, "Random")
9
9
  .add_handler<torch::Error>(handle_error)
10
- .define_singleton_method(
10
+ .define_singleton_function(
11
11
  "initial_seed",
12
- *[]() {
12
+ []() {
13
13
  return at::detail::getDefaultCPUGenerator().current_seed();
14
14
  })
15
- .define_singleton_method(
15
+ .define_singleton_function(
16
16
  "seed",
17
- *[]() {
17
+ []() {
18
18
  // TODO set for CUDA when available
19
19
  auto generator = at::detail::getDefaultCPUGenerator();
20
20
  return generator.seed();
data/ext/torch/ruby_arg_parser.cpp CHANGED
@@ -137,7 +137,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
137
137
  return true;
138
138
  }
139
139
  if (THPVariable_Check(obj)) {
140
- auto var = from_ruby<torch::Tensor>(obj);
140
+ auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
141
141
  return !var.requires_grad() && var.dim() == 0;
142
142
  }
143
143
  return false;
@@ -147,7 +147,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
147
147
  return true;
148
148
  }
149
149
  if (THPVariable_Check(obj)) {
150
- auto var = from_ruby<torch::Tensor>(obj);
150
+ auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
151
151
  return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
152
152
  }
153
153
  return false;
data/ext/torch/ruby_arg_parser.h CHANGED
@@ -5,7 +5,7 @@
5
5
  #include <sstream>
6
6
 
7
7
  #include <torch/torch.h>
8
- #include <rice/Exception.hpp>
8
+ #include <rice/rice.hpp>
9
9
 
10
10
  #include "templates.h"
11
11
  #include "utils.h"
@@ -78,6 +78,7 @@ struct RubyArgs {
78
78
  inline OptionalTensor optionalTensor(int i);
79
79
  inline at::Scalar scalar(int i);
80
80
  // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
81
+ inline std::vector<at::Scalar> scalarlist(int i);
81
82
  inline std::vector<at::Tensor> tensorlist(int i);
82
83
  template<int N>
83
84
  inline std::array<at::Tensor, N> tensorlist_n(int i);
@@ -121,7 +122,7 @@ struct RubyArgs {
121
122
  };
122
123
 
123
124
  inline at::Tensor RubyArgs::tensor(int i) {
124
- return from_ruby<torch::Tensor>(args[i]);
125
+ return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
125
126
  }
126
127
 
127
128
  inline OptionalTensor RubyArgs::optionalTensor(int i) {
@@ -131,12 +132,17 @@ inline OptionalTensor RubyArgs::optionalTensor(int i) {
131
132
 
132
133
  inline at::Scalar RubyArgs::scalar(int i) {
133
134
  if (NIL_P(args[i])) return signature.params[i].default_scalar;
134
- return from_ruby<torch::Scalar>(args[i]);
135
+ return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
136
+ }
137
+
138
+ inline std::vector<at::Scalar> RubyArgs::scalarlist(int i) {
139
+ if (NIL_P(args[i])) return std::vector<at::Scalar>();
140
+ return Rice::detail::From_Ruby<std::vector<at::Scalar>>().convert(args[i]);
135
141
  }
136
142
 
137
143
  inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
138
144
  if (NIL_P(args[i])) return std::vector<at::Tensor>();
139
- return from_ruby<std::vector<Tensor>>(args[i]);
145
+ return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
140
146
  }
141
147
 
142
148
  template<int N>
@@ -151,7 +157,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
151
157
  }
152
158
  for (int idx = 0; idx < size; idx++) {
153
159
  VALUE obj = rb_ary_entry(arg, idx);
154
- res[idx] = from_ruby<Tensor>(obj);
160
+ res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
155
161
  }
156
162
  return res;
157
163
  }
@@ -170,7 +176,7 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
170
176
  for (idx = 0; idx < size; idx++) {
171
177
  VALUE obj = rb_ary_entry(arg, idx);
172
178
  if (FIXNUM_P(obj)) {
173
- res[idx] = from_ruby<int64_t>(obj);
179
+ res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
174
180
  } else {
175
181
  rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
176
182
  signature.name.c_str(), signature.params[i].name.c_str(),
@@ -210,8 +216,13 @@ inline ScalarType RubyArgs::scalartype(int i) {
210
216
  {ID2SYM(rb_intern("double")), ScalarType::Double},
211
217
  {ID2SYM(rb_intern("float64")), ScalarType::Double},
212
218
  {ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
219
+ {ID2SYM(rb_intern("complex32")), ScalarType::ComplexHalf},
213
220
  {ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
221
+ {ID2SYM(rb_intern("cfloat")), ScalarType::ComplexFloat},
222
+ {ID2SYM(rb_intern("complex64")), ScalarType::ComplexFloat},
214
223
  {ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
224
+ {ID2SYM(rb_intern("cdouble")), ScalarType::ComplexDouble},
225
+ {ID2SYM(rb_intern("complex128")), ScalarType::ComplexDouble},
215
226
  {ID2SYM(rb_intern("bool")), ScalarType::Bool},
216
227
  {ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
217
228
  {ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
@@ -260,7 +271,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
260
271
  for (idx = 0; idx < size; idx++) {
261
272
  VALUE obj = rb_ary_entry(arg, idx);
262
273
  if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
263
- res[idx] = from_ruby<double>(obj);
274
+ res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
264
275
  } else {
265
276
  rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
266
277
  signature.name.c_str(), signature.params[i].name.c_str(),
@@ -303,22 +314,22 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
303
314
  }
304
315
 
305
316
  inline std::string RubyArgs::string(int i) {
306
- return from_ruby<std::string>(args[i]);
317
+ return Rice::detail::From_Ruby<std::string>().convert(args[i]);
307
318
  }
308
319
 
309
320
  inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
310
- if (!args[i]) return c10::nullopt;
311
- return from_ruby<std::string>(args[i]);
321
+ if (NIL_P(args[i])) return c10::nullopt;
322
+ return Rice::detail::From_Ruby<std::string>().convert(args[i]);
312
323
  }
313
324
 
314
325
  inline int64_t RubyArgs::toInt64(int i) {
315
326
  if (NIL_P(args[i])) return signature.params[i].default_int;
316
- return from_ruby<int64_t>(args[i]);
327
+ return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
317
328
  }
318
329
 
319
330
  inline double RubyArgs::toDouble(int i) {
320
331
  if (NIL_P(args[i])) return signature.params[i].default_double;
321
- return from_ruby<double>(args[i]);
332
+ return Rice::detail::From_Ruby<double>().convert(args[i]);
322
333
  }
323
334
 
324
335
  inline bool RubyArgs::toBool(int i) {
data/ext/torch/special.cpp ADDED
@@ -0,0 +1,13 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "special_functions.h"
6
+ #include "templates.h"
7
+ #include "utils.h"
8
+
9
+ void init_special(Rice::Module& m) {
10
+ auto rb_mSpecial = Rice::define_module_under(m, "Special");
11
+ rb_mSpecial.add_handler<torch::Error>(handle_error);
12
+ add_special_functions(rb_mSpecial);
13
+ }
data/ext/torch/special_functions.h ADDED
@@ -0,0 +1,6 @@
// generated by rake generate:functions
// do not edit by hand

#pragma once

// Registers the generated special-function bindings on `m`; init_special
// (special.cpp) calls it with the Torch::Special module.
// NOTE(review): this header names Rice::Module without including
// <rice/rice.hpp>, so it relies on the includer doing so first (special.cpp
// does). If it should be self-contained, change the generator, not this file.
void add_special_functions(Rice::Module& m);