torch-rb 0.6.0 → 0.7.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 7ef4b5215eabaac0dfa2c95fa19ec47c84ffdc2caaa004b3eed5fd0d88329d4c
4
- data.tar.gz: 3c41ea854a8dc791b6b67631a1e9502f47c26c1afb4eff0e77da3885b5b57e8e
3
+ metadata.gz: 859f5858ce45f6a73fbf9afd6073e3af798ac7ea64713405c54df0fc40b0a1e6
4
+ data.tar.gz: 4c371640902d1226c69135874aaa076251d2539c56dc78479203bc28473b071b
5
5
  SHA512:
6
- metadata.gz: 2378a07b1b7b5470909f920f9f2ace055b2d8f860e1e530ddcfe57b02aa79b61ddd65c9441d62107065d8326f81af90f6741f67eab1f078de85a5a0c140a1f27
7
- data.tar.gz: db2d6c16a388571377d920a5ffa464447501441730feb8b016a7090f7a81e2d0ce0a4626b99081f2b9d07fe9448dc1247adb9e8d1906d4a87c4d6e1023d39be5
6
+ metadata.gz: f8ab6c1af0da9ad36f1fa7fae0cb67b54a60c11783d13cedc4c66164fb1fcb40638a68bbd78fa0f644e6d6d1c7df44d83c68c61bc0e41ea4c361817b0fe7b9cd
7
+ data.tar.gz: f64913cf8c2566539fef54e29e42173f4e2e0529ea49d381c6cf41af64868653865fc2cfa366f5086e035b45008a07f2d620a000c64dabf92e3c9134489b5a6b
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.7.0 (2021-05-23)
2
+
3
+ - Updated to Rice 4
4
+ - Added support for complex numbers
5
+
1
6
  ## 0.6.0 (2021-03-25)
2
7
 
3
8
  - Updated LibTorch to 1.8.0
@@ -75,7 +75,7 @@ def write_body(type, method_defs, attach_defs)
75
75
  // do not edit by hand
76
76
 
77
77
  #include <torch/torch.h>
78
- #include <rice/Module.hpp>
78
+ #include <rice/rice.hpp>
79
79
 
80
80
  #include "ruby_arg_parser.h"
81
81
  #include "templates.h"
@@ -119,7 +119,7 @@ def generate_attach_def(name, type, def_method)
119
119
  end
120
120
 
121
121
  def generate_method_def(name, functions, type, def_method)
122
- assign_self = type == "tensor" ? "\n Tensor& self = from_ruby<Tensor&>(self_);" : ""
122
+ assign_self = type == "tensor" ? "\n Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);" : ""
123
123
 
124
124
  functions = group_overloads(functions, type)
125
125
  signatures = functions.map { |f| f["signature"] }
data/ext/torch/cuda.cpp CHANGED
@@ -1,14 +1,14 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Module.hpp>
3
+ #include <rice/rice.hpp>
4
4
 
5
5
  #include "utils.h"
6
6
 
7
7
  void init_cuda(Rice::Module& m) {
8
8
  Rice::define_module_under(m, "CUDA")
9
9
  .add_handler<torch::Error>(handle_error)
10
- .define_singleton_method("available?", &torch::cuda::is_available)
11
- .define_singleton_method("device_count", &torch::cuda::device_count)
12
- .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
13
- .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
10
+ .define_singleton_function("available?", &torch::cuda::is_available)
11
+ .define_singleton_function("device_count", &torch::cuda::device_count)
12
+ .define_singleton_function("manual_seed", &torch::cuda::manual_seed)
13
+ .define_singleton_function("manual_seed_all", &torch::cuda::manual_seed_all);
14
14
  }
data/ext/torch/device.cpp CHANGED
@@ -1,19 +1,26 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Constructor.hpp>
4
- #include <rice/Module.hpp>
3
+ #include <rice/rice.hpp>
5
4
 
6
5
  #include "utils.h"
7
6
 
8
7
  void init_device(Rice::Module& m) {
9
8
  Rice::define_class_under<torch::Device>(m, "Device")
10
9
  .add_handler<torch::Error>(handle_error)
11
- .define_constructor(Rice::Constructor<torch::Device, std::string>())
12
- .define_method("index", &torch::Device::index)
13
- .define_method("index?", &torch::Device::has_index)
10
+ .define_constructor(Rice::Constructor<torch::Device, const std::string&>())
11
+ .define_method(
12
+ "index",
13
+ [](torch::Device& self) {
14
+ return self.index();
15
+ })
16
+ .define_method(
17
+ "index?",
18
+ [](torch::Device& self) {
19
+ return self.has_index();
20
+ })
14
21
  .define_method(
15
22
  "type",
16
- *[](torch::Device& self) {
23
+ [](torch::Device& self) {
17
24
  std::stringstream s;
18
25
  s << self.type();
19
26
  return s.str();
data/ext/torch/ext.cpp CHANGED
@@ -1,12 +1,14 @@
1
- #include <rice/Module.hpp>
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
2
4
 
3
5
  void init_nn(Rice::Module& m);
4
- void init_tensor(Rice::Module& m);
6
+ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
5
7
  void init_torch(Rice::Module& m);
6
8
 
7
9
  void init_cuda(Rice::Module& m);
8
10
  void init_device(Rice::Module& m);
9
- void init_ivalue(Rice::Module& m);
11
+ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
10
12
  void init_random(Rice::Module& m);
11
13
 
12
14
  extern "C"
@@ -14,13 +16,20 @@ void Init_ext()
14
16
  {
15
17
  auto m = Rice::define_module("Torch");
16
18
 
19
+ // need to define certain classes up front to keep Rice happy
20
+ auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
21
+ .define_constructor(Rice::Constructor<torch::IValue>());
22
+ auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
23
+ auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
24
+ .define_constructor(Rice::Constructor<torch::TensorOptions>());
25
+
17
26
  // keep this order
18
27
  init_torch(m);
19
- init_tensor(m);
28
+ init_tensor(m, rb_cTensor, rb_cTensorOptions);
20
29
  init_nn(m);
21
30
 
22
31
  init_cuda(m);
23
32
  init_device(m);
24
- init_ivalue(m);
33
+ init_ivalue(m, rb_cIValue);
25
34
  init_random(m);
26
35
  }
data/ext/torch/extconf.rb CHANGED
@@ -1,8 +1,6 @@
1
1
  require "mkmf-rice"
2
2
 
3
- abort "Missing stdc++" unless have_library("stdc++")
4
-
5
- $CXXFLAGS += " -std=c++14"
3
+ $CXXFLAGS += " -std=c++17 $(optflags)"
6
4
 
7
5
  # change to 0 for Linux pre-cxx11 ABI version
8
6
  $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
data/ext/torch/ivalue.cpp CHANGED
@@ -1,18 +1,13 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Array.hpp>
4
- #include <rice/Constructor.hpp>
5
- #include <rice/Hash.hpp>
6
- #include <rice/Module.hpp>
7
- #include <rice/String.hpp>
3
+ #include <rice/rice.hpp>
8
4
 
9
5
  #include "utils.h"
10
6
 
11
- void init_ivalue(Rice::Module& m) {
7
+ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
12
8
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
13
- Rice::define_class_under<torch::IValue>(m, "IValue")
9
+ rb_cIValue
14
10
  .add_handler<torch::Error>(handle_error)
15
- .define_constructor(Rice::Constructor<torch::IValue>())
16
11
  .define_method("bool?", &torch::IValue::isBool)
17
12
  .define_method("bool_list?", &torch::IValue::isBoolList)
18
13
  .define_method("capsule?", &torch::IValue::isCapsule)
@@ -39,95 +34,98 @@ void init_ivalue(Rice::Module& m) {
39
34
  .define_method("tuple?", &torch::IValue::isTuple)
40
35
  .define_method(
41
36
  "to_bool",
42
- *[](torch::IValue& self) {
37
+ [](torch::IValue& self) {
43
38
  return self.toBool();
44
39
  })
45
40
  .define_method(
46
41
  "to_double",
47
- *[](torch::IValue& self) {
42
+ [](torch::IValue& self) {
48
43
  return self.toDouble();
49
44
  })
50
45
  .define_method(
51
46
  "to_int",
52
- *[](torch::IValue& self) {
47
+ [](torch::IValue& self) {
53
48
  return self.toInt();
54
49
  })
55
50
  .define_method(
56
51
  "to_list",
57
- *[](torch::IValue& self) {
52
+ [](torch::IValue& self) {
58
53
  auto list = self.toListRef();
59
54
  Rice::Array obj;
60
55
  for (auto& elem : list) {
61
- obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
56
+ auto v = torch::IValue{elem};
57
+ obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)));
62
58
  }
63
59
  return obj;
64
60
  })
65
61
  .define_method(
66
62
  "to_string_ref",
67
- *[](torch::IValue& self) {
63
+ [](torch::IValue& self) {
68
64
  return self.toStringRef();
69
65
  })
70
66
  .define_method(
71
67
  "to_tensor",
72
- *[](torch::IValue& self) {
68
+ [](torch::IValue& self) {
73
69
  return self.toTensor();
74
70
  })
75
71
  .define_method(
76
72
  "to_generic_dict",
77
- *[](torch::IValue& self) {
73
+ [](torch::IValue& self) {
78
74
  auto dict = self.toGenericDict();
79
75
  Rice::Hash obj;
80
76
  for (auto& pair : dict) {
81
- obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
77
+ auto k = torch::IValue{pair.key()};
78
+ auto v = torch::IValue{pair.value()};
79
+ obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
82
80
  }
83
81
  return obj;
84
82
  })
85
- .define_singleton_method(
83
+ .define_singleton_function(
86
84
  "from_tensor",
87
- *[](torch::Tensor& v) {
85
+ [](torch::Tensor& v) {
88
86
  return torch::IValue(v);
89
87
  })
90
88
  // TODO create specialized list types?
91
- .define_singleton_method(
89
+ .define_singleton_function(
92
90
  "from_list",
93
- *[](Rice::Array obj) {
91
+ [](Rice::Array obj) {
94
92
  c10::impl::GenericList list(c10::AnyType::get());
95
93
  for (auto entry : obj) {
96
- list.push_back(from_ruby<torch::IValue>(entry));
94
+ list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
97
95
  }
98
96
  return torch::IValue(list);
99
97
  })
100
- .define_singleton_method(
98
+ .define_singleton_function(
101
99
  "from_string",
102
- *[](Rice::String v) {
100
+ [](Rice::String v) {
103
101
  return torch::IValue(v.str());
104
102
  })
105
- .define_singleton_method(
103
+ .define_singleton_function(
106
104
  "from_int",
107
- *[](int64_t v) {
105
+ [](int64_t v) {
108
106
  return torch::IValue(v);
109
107
  })
110
- .define_singleton_method(
108
+ .define_singleton_function(
111
109
  "from_double",
112
- *[](double v) {
110
+ [](double v) {
113
111
  return torch::IValue(v);
114
112
  })
115
- .define_singleton_method(
113
+ .define_singleton_function(
116
114
  "from_bool",
117
- *[](bool v) {
115
+ [](bool v) {
118
116
  return torch::IValue(v);
119
117
  })
120
118
  // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
121
119
  // createGenericDict and toIValue
122
- .define_singleton_method(
120
+ .define_singleton_function(
123
121
  "from_dict",
124
- *[](Rice::Hash obj) {
122
+ [](Rice::Hash obj) {
125
123
  auto key_type = c10::AnyType::get();
126
124
  auto value_type = c10::AnyType::get();
127
125
  c10::impl::GenericDict elems(key_type, value_type);
128
126
  elems.reserve(obj.size());
129
127
  for (auto entry : obj) {
130
- elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Rice::Object) entry.second));
128
+ elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
131
129
  }
132
130
  return torch::IValue(std::move(elems));
133
131
  });
data/ext/torch/nn.cpp CHANGED
@@ -1,6 +1,6 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Module.hpp>
3
+ #include <rice/rice.hpp>
4
4
 
5
5
  #include "nn_functions.h"
6
6
  #include "templates.h"
@@ -19,74 +19,74 @@ void init_nn(Rice::Module& m) {
19
19
 
20
20
  Rice::define_module_under(rb_mNN, "Init")
21
21
  .add_handler<torch::Error>(handle_error)
22
- .define_singleton_method(
22
+ .define_singleton_function(
23
23
  "_calculate_gain",
24
- *[](NonlinearityType nonlinearity, double param) {
24
+ [](NonlinearityType nonlinearity, double param) {
25
25
  return torch::nn::init::calculate_gain(nonlinearity, param);
26
26
  })
27
- .define_singleton_method(
27
+ .define_singleton_function(
28
28
  "_uniform!",
29
- *[](Tensor tensor, double low, double high) {
29
+ [](Tensor tensor, double low, double high) {
30
30
  return torch::nn::init::uniform_(tensor, low, high);
31
31
  })
32
- .define_singleton_method(
32
+ .define_singleton_function(
33
33
  "_normal!",
34
- *[](Tensor tensor, double mean, double std) {
34
+ [](Tensor tensor, double mean, double std) {
35
35
  return torch::nn::init::normal_(tensor, mean, std);
36
36
  })
37
- .define_singleton_method(
37
+ .define_singleton_function(
38
38
  "_constant!",
39
- *[](Tensor tensor, Scalar value) {
39
+ [](Tensor tensor, Scalar value) {
40
40
  return torch::nn::init::constant_(tensor, value);
41
41
  })
42
- .define_singleton_method(
42
+ .define_singleton_function(
43
43
  "_ones!",
44
- *[](Tensor tensor) {
44
+ [](Tensor tensor) {
45
45
  return torch::nn::init::ones_(tensor);
46
46
  })
47
- .define_singleton_method(
47
+ .define_singleton_function(
48
48
  "_zeros!",
49
- *[](Tensor tensor) {
49
+ [](Tensor tensor) {
50
50
  return torch::nn::init::zeros_(tensor);
51
51
  })
52
- .define_singleton_method(
52
+ .define_singleton_function(
53
53
  "_eye!",
54
- *[](Tensor tensor) {
54
+ [](Tensor tensor) {
55
55
  return torch::nn::init::eye_(tensor);
56
56
  })
57
- .define_singleton_method(
57
+ .define_singleton_function(
58
58
  "_dirac!",
59
- *[](Tensor tensor) {
59
+ [](Tensor tensor) {
60
60
  return torch::nn::init::dirac_(tensor);
61
61
  })
62
- .define_singleton_method(
62
+ .define_singleton_function(
63
63
  "_xavier_uniform!",
64
- *[](Tensor tensor, double gain) {
64
+ [](Tensor tensor, double gain) {
65
65
  return torch::nn::init::xavier_uniform_(tensor, gain);
66
66
  })
67
- .define_singleton_method(
67
+ .define_singleton_function(
68
68
  "_xavier_normal!",
69
- *[](Tensor tensor, double gain) {
69
+ [](Tensor tensor, double gain) {
70
70
  return torch::nn::init::xavier_normal_(tensor, gain);
71
71
  })
72
- .define_singleton_method(
72
+ .define_singleton_function(
73
73
  "_kaiming_uniform!",
74
- *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
74
+ [](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
75
75
  return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
76
76
  })
77
- .define_singleton_method(
77
+ .define_singleton_function(
78
78
  "_kaiming_normal!",
79
- *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
79
+ [](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
80
80
  return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
81
81
  })
82
- .define_singleton_method(
82
+ .define_singleton_function(
83
83
  "_orthogonal!",
84
- *[](Tensor tensor, double gain) {
84
+ [](Tensor tensor, double gain) {
85
85
  return torch::nn::init::orthogonal_(tensor, gain);
86
86
  })
87
- .define_singleton_method(
87
+ .define_singleton_function(
88
88
  "_sparse!",
89
- *[](Tensor tensor, double sparsity, double std) {
89
+ [](Tensor tensor, double sparsity, double std) {
90
90
  return torch::nn::init::sparse_(tensor, sparsity, std);
91
91
  });
92
92
 
@@ -94,18 +94,18 @@ void init_nn(Rice::Module& m) {
94
94
  .add_handler<torch::Error>(handle_error)
95
95
  .define_method(
96
96
  "grad",
97
- *[](Parameter& self) {
97
+ [](Parameter& self) {
98
98
  auto grad = self.grad();
99
- return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
99
+ return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
100
100
  })
101
101
  .define_method(
102
102
  "grad=",
103
- *[](Parameter& self, torch::Tensor& grad) {
103
+ [](Parameter& self, torch::Tensor& grad) {
104
104
  self.mutable_grad() = grad;
105
105
  })
106
- .define_singleton_method(
106
+ .define_singleton_function(
107
107
  "_make_subclass",
108
- *[](Tensor& rd, bool requires_grad) {
108
+ [](Tensor& rd, bool requires_grad) {
109
109
  auto data = rd.detach();
110
110
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
111
111
  auto var = data.set_requires_grad(requires_grad);
data/ext/torch/random.cpp CHANGED
@@ -1,20 +1,20 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Module.hpp>
3
+ #include <rice/rice.hpp>
4
4
 
5
5
  #include "utils.h"
6
6
 
7
7
  void init_random(Rice::Module& m) {
8
8
  Rice::define_module_under(m, "Random")
9
9
  .add_handler<torch::Error>(handle_error)
10
- .define_singleton_method(
10
+ .define_singleton_function(
11
11
  "initial_seed",
12
- *[]() {
12
+ []() {
13
13
  return at::detail::getDefaultCPUGenerator().current_seed();
14
14
  })
15
- .define_singleton_method(
15
+ .define_singleton_function(
16
16
  "seed",
17
- *[]() {
17
+ []() {
18
18
  // TODO set for CUDA when available
19
19
  auto generator = at::detail::getDefaultCPUGenerator();
20
20
  return generator.seed();