torch-rb 0.6.0 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/codegen/generate_functions.rb +2 -2
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +14 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +16 -11
- data/ext/torch/templates.h +110 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +5 -0
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/version.rb +1 -1
- metadata +4 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 859f5858ce45f6a73fbf9afd6073e3af798ac7ea64713405c54df0fc40b0a1e6
|
4
|
+
data.tar.gz: 4c371640902d1226c69135874aaa076251d2539c56dc78479203bc28473b071b
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f8ab6c1af0da9ad36f1fa7fae0cb67b54a60c11783d13cedc4c66164fb1fcb40638a68bbd78fa0f644e6d6d1c7df44d83c68c61bc0e41ea4c361817b0fe7b9cd
|
7
|
+
data.tar.gz: f64913cf8c2566539fef54e29e42173f4e2e0529ea49d381c6cf41af64868653865fc2cfa366f5086e035b45008a07f2d620a000c64dabf92e3c9134489b5a6b
|
data/CHANGELOG.md
CHANGED
data/codegen/generate_functions.rb
CHANGED
@@ -75,7 +75,7 @@ def write_body(type, method_defs, attach_defs)
|
|
75
75
|
// do not edit by hand
|
76
76
|
|
77
77
|
#include <torch/torch.h>
|
78
|
-
#include <rice/
|
78
|
+
#include <rice/rice.hpp>
|
79
79
|
|
80
80
|
#include "ruby_arg_parser.h"
|
81
81
|
#include "templates.h"
|
@@ -119,7 +119,7 @@ def generate_attach_def(name, type, def_method)
|
|
119
119
|
end
|
120
120
|
|
121
121
|
def generate_method_def(name, functions, type, def_method)
|
122
|
-
assign_self = type == "tensor" ? "\n Tensor& self =
|
122
|
+
assign_self = type == "tensor" ? "\n Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);" : ""
|
123
123
|
|
124
124
|
functions = group_overloads(functions, type)
|
125
125
|
signatures = functions.map { |f| f["signature"] }
|
data/ext/torch/cuda.cpp
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "utils.h"
|
6
6
|
|
7
7
|
void init_cuda(Rice::Module& m) {
|
8
8
|
Rice::define_module_under(m, "CUDA")
|
9
9
|
.add_handler<torch::Error>(handle_error)
|
10
|
-
.
|
11
|
-
.
|
12
|
-
.
|
13
|
-
.
|
10
|
+
.define_singleton_function("available?", &torch::cuda::is_available)
|
11
|
+
.define_singleton_function("device_count", &torch::cuda::device_count)
|
12
|
+
.define_singleton_function("manual_seed", &torch::cuda::manual_seed)
|
13
|
+
.define_singleton_function("manual_seed_all", &torch::cuda::manual_seed_all);
|
14
14
|
}
|
data/ext/torch/device.cpp
CHANGED
@@ -1,19 +1,26 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/Module.hpp>
|
3
|
+
#include <rice/rice.hpp>
|
5
4
|
|
6
5
|
#include "utils.h"
|
7
6
|
|
8
7
|
void init_device(Rice::Module& m) {
|
9
8
|
Rice::define_class_under<torch::Device>(m, "Device")
|
10
9
|
.add_handler<torch::Error>(handle_error)
|
11
|
-
.define_constructor(Rice::Constructor<torch::Device, std::string
|
12
|
-
.define_method(
|
13
|
-
|
10
|
+
.define_constructor(Rice::Constructor<torch::Device, const std::string&>())
|
11
|
+
.define_method(
|
12
|
+
"index",
|
13
|
+
[](torch::Device& self) {
|
14
|
+
return self.index();
|
15
|
+
})
|
16
|
+
.define_method(
|
17
|
+
"index?",
|
18
|
+
[](torch::Device& self) {
|
19
|
+
return self.has_index();
|
20
|
+
})
|
14
21
|
.define_method(
|
15
22
|
"type",
|
16
|
-
|
23
|
+
[](torch::Device& self) {
|
17
24
|
std::stringstream s;
|
18
25
|
s << self.type();
|
19
26
|
return s.str();
|
data/ext/torch/ext.cpp
CHANGED
@@ -1,12 +1,14 @@
|
|
1
|
-
#include <
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
2
4
|
|
3
5
|
void init_nn(Rice::Module& m);
|
4
|
-
void init_tensor(Rice::Module& m);
|
6
|
+
void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
|
5
7
|
void init_torch(Rice::Module& m);
|
6
8
|
|
7
9
|
void init_cuda(Rice::Module& m);
|
8
10
|
void init_device(Rice::Module& m);
|
9
|
-
void init_ivalue(Rice::Module& m);
|
11
|
+
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
|
10
12
|
void init_random(Rice::Module& m);
|
11
13
|
|
12
14
|
extern "C"
|
@@ -14,13 +16,20 @@ void Init_ext()
|
|
14
16
|
{
|
15
17
|
auto m = Rice::define_module("Torch");
|
16
18
|
|
19
|
+
// need to define certain classes up front to keep Rice happy
|
20
|
+
auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
|
21
|
+
.define_constructor(Rice::Constructor<torch::IValue>());
|
22
|
+
auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
|
23
|
+
auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
|
24
|
+
.define_constructor(Rice::Constructor<torch::TensorOptions>());
|
25
|
+
|
17
26
|
// keep this order
|
18
27
|
init_torch(m);
|
19
|
-
init_tensor(m);
|
28
|
+
init_tensor(m, rb_cTensor, rb_cTensorOptions);
|
20
29
|
init_nn(m);
|
21
30
|
|
22
31
|
init_cuda(m);
|
23
32
|
init_device(m);
|
24
|
-
init_ivalue(m);
|
33
|
+
init_ivalue(m, rb_cIValue);
|
25
34
|
init_random(m);
|
26
35
|
}
|
data/ext/torch/extconf.rb
CHANGED
data/ext/torch/ivalue.cpp
CHANGED
@@ -1,18 +1,13 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/Constructor.hpp>
|
5
|
-
#include <rice/Hash.hpp>
|
6
|
-
#include <rice/Module.hpp>
|
7
|
-
#include <rice/String.hpp>
|
3
|
+
#include <rice/rice.hpp>
|
8
4
|
|
9
5
|
#include "utils.h"
|
10
6
|
|
11
|
-
void init_ivalue(Rice::Module& m) {
|
7
|
+
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
12
8
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
13
|
-
|
9
|
+
rb_cIValue
|
14
10
|
.add_handler<torch::Error>(handle_error)
|
15
|
-
.define_constructor(Rice::Constructor<torch::IValue>())
|
16
11
|
.define_method("bool?", &torch::IValue::isBool)
|
17
12
|
.define_method("bool_list?", &torch::IValue::isBoolList)
|
18
13
|
.define_method("capsule?", &torch::IValue::isCapsule)
|
@@ -39,95 +34,98 @@ void init_ivalue(Rice::Module& m) {
|
|
39
34
|
.define_method("tuple?", &torch::IValue::isTuple)
|
40
35
|
.define_method(
|
41
36
|
"to_bool",
|
42
|
-
|
37
|
+
[](torch::IValue& self) {
|
43
38
|
return self.toBool();
|
44
39
|
})
|
45
40
|
.define_method(
|
46
41
|
"to_double",
|
47
|
-
|
42
|
+
[](torch::IValue& self) {
|
48
43
|
return self.toDouble();
|
49
44
|
})
|
50
45
|
.define_method(
|
51
46
|
"to_int",
|
52
|
-
|
47
|
+
[](torch::IValue& self) {
|
53
48
|
return self.toInt();
|
54
49
|
})
|
55
50
|
.define_method(
|
56
51
|
"to_list",
|
57
|
-
|
52
|
+
[](torch::IValue& self) {
|
58
53
|
auto list = self.toListRef();
|
59
54
|
Rice::Array obj;
|
60
55
|
for (auto& elem : list) {
|
61
|
-
|
56
|
+
auto v = torch::IValue{elem};
|
57
|
+
obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)));
|
62
58
|
}
|
63
59
|
return obj;
|
64
60
|
})
|
65
61
|
.define_method(
|
66
62
|
"to_string_ref",
|
67
|
-
|
63
|
+
[](torch::IValue& self) {
|
68
64
|
return self.toStringRef();
|
69
65
|
})
|
70
66
|
.define_method(
|
71
67
|
"to_tensor",
|
72
|
-
|
68
|
+
[](torch::IValue& self) {
|
73
69
|
return self.toTensor();
|
74
70
|
})
|
75
71
|
.define_method(
|
76
72
|
"to_generic_dict",
|
77
|
-
|
73
|
+
[](torch::IValue& self) {
|
78
74
|
auto dict = self.toGenericDict();
|
79
75
|
Rice::Hash obj;
|
80
76
|
for (auto& pair : dict) {
|
81
|
-
|
77
|
+
auto k = torch::IValue{pair.key()};
|
78
|
+
auto v = torch::IValue{pair.value()};
|
79
|
+
obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
|
82
80
|
}
|
83
81
|
return obj;
|
84
82
|
})
|
85
|
-
.
|
83
|
+
.define_singleton_function(
|
86
84
|
"from_tensor",
|
87
|
-
|
85
|
+
[](torch::Tensor& v) {
|
88
86
|
return torch::IValue(v);
|
89
87
|
})
|
90
88
|
// TODO create specialized list types?
|
91
|
-
.
|
89
|
+
.define_singleton_function(
|
92
90
|
"from_list",
|
93
|
-
|
91
|
+
[](Rice::Array obj) {
|
94
92
|
c10::impl::GenericList list(c10::AnyType::get());
|
95
93
|
for (auto entry : obj) {
|
96
|
-
list.push_back(
|
94
|
+
list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
|
97
95
|
}
|
98
96
|
return torch::IValue(list);
|
99
97
|
})
|
100
|
-
.
|
98
|
+
.define_singleton_function(
|
101
99
|
"from_string",
|
102
|
-
|
100
|
+
[](Rice::String v) {
|
103
101
|
return torch::IValue(v.str());
|
104
102
|
})
|
105
|
-
.
|
103
|
+
.define_singleton_function(
|
106
104
|
"from_int",
|
107
|
-
|
105
|
+
[](int64_t v) {
|
108
106
|
return torch::IValue(v);
|
109
107
|
})
|
110
|
-
.
|
108
|
+
.define_singleton_function(
|
111
109
|
"from_double",
|
112
|
-
|
110
|
+
[](double v) {
|
113
111
|
return torch::IValue(v);
|
114
112
|
})
|
115
|
-
.
|
113
|
+
.define_singleton_function(
|
116
114
|
"from_bool",
|
117
|
-
|
115
|
+
[](bool v) {
|
118
116
|
return torch::IValue(v);
|
119
117
|
})
|
120
118
|
// see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
|
121
119
|
// createGenericDict and toIValue
|
122
|
-
.
|
120
|
+
.define_singleton_function(
|
123
121
|
"from_dict",
|
124
|
-
|
122
|
+
[](Rice::Hash obj) {
|
125
123
|
auto key_type = c10::AnyType::get();
|
126
124
|
auto value_type = c10::AnyType::get();
|
127
125
|
c10::impl::GenericDict elems(key_type, value_type);
|
128
126
|
elems.reserve(obj.size());
|
129
127
|
for (auto entry : obj) {
|
130
|
-
elems.insert(
|
128
|
+
elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
|
131
129
|
}
|
132
130
|
return torch::IValue(std::move(elems));
|
133
131
|
});
|
data/ext/torch/nn.cpp
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "nn_functions.h"
|
6
6
|
#include "templates.h"
|
@@ -19,74 +19,74 @@ void init_nn(Rice::Module& m) {
|
|
19
19
|
|
20
20
|
Rice::define_module_under(rb_mNN, "Init")
|
21
21
|
.add_handler<torch::Error>(handle_error)
|
22
|
-
.
|
22
|
+
.define_singleton_function(
|
23
23
|
"_calculate_gain",
|
24
|
-
|
24
|
+
[](NonlinearityType nonlinearity, double param) {
|
25
25
|
return torch::nn::init::calculate_gain(nonlinearity, param);
|
26
26
|
})
|
27
|
-
.
|
27
|
+
.define_singleton_function(
|
28
28
|
"_uniform!",
|
29
|
-
|
29
|
+
[](Tensor tensor, double low, double high) {
|
30
30
|
return torch::nn::init::uniform_(tensor, low, high);
|
31
31
|
})
|
32
|
-
.
|
32
|
+
.define_singleton_function(
|
33
33
|
"_normal!",
|
34
|
-
|
34
|
+
[](Tensor tensor, double mean, double std) {
|
35
35
|
return torch::nn::init::normal_(tensor, mean, std);
|
36
36
|
})
|
37
|
-
.
|
37
|
+
.define_singleton_function(
|
38
38
|
"_constant!",
|
39
|
-
|
39
|
+
[](Tensor tensor, Scalar value) {
|
40
40
|
return torch::nn::init::constant_(tensor, value);
|
41
41
|
})
|
42
|
-
.
|
42
|
+
.define_singleton_function(
|
43
43
|
"_ones!",
|
44
|
-
|
44
|
+
[](Tensor tensor) {
|
45
45
|
return torch::nn::init::ones_(tensor);
|
46
46
|
})
|
47
|
-
.
|
47
|
+
.define_singleton_function(
|
48
48
|
"_zeros!",
|
49
|
-
|
49
|
+
[](Tensor tensor) {
|
50
50
|
return torch::nn::init::zeros_(tensor);
|
51
51
|
})
|
52
|
-
.
|
52
|
+
.define_singleton_function(
|
53
53
|
"_eye!",
|
54
|
-
|
54
|
+
[](Tensor tensor) {
|
55
55
|
return torch::nn::init::eye_(tensor);
|
56
56
|
})
|
57
|
-
.
|
57
|
+
.define_singleton_function(
|
58
58
|
"_dirac!",
|
59
|
-
|
59
|
+
[](Tensor tensor) {
|
60
60
|
return torch::nn::init::dirac_(tensor);
|
61
61
|
})
|
62
|
-
.
|
62
|
+
.define_singleton_function(
|
63
63
|
"_xavier_uniform!",
|
64
|
-
|
64
|
+
[](Tensor tensor, double gain) {
|
65
65
|
return torch::nn::init::xavier_uniform_(tensor, gain);
|
66
66
|
})
|
67
|
-
.
|
67
|
+
.define_singleton_function(
|
68
68
|
"_xavier_normal!",
|
69
|
-
|
69
|
+
[](Tensor tensor, double gain) {
|
70
70
|
return torch::nn::init::xavier_normal_(tensor, gain);
|
71
71
|
})
|
72
|
-
.
|
72
|
+
.define_singleton_function(
|
73
73
|
"_kaiming_uniform!",
|
74
|
-
|
74
|
+
[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
75
75
|
return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
|
76
76
|
})
|
77
|
-
.
|
77
|
+
.define_singleton_function(
|
78
78
|
"_kaiming_normal!",
|
79
|
-
|
79
|
+
[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
80
80
|
return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
|
81
81
|
})
|
82
|
-
.
|
82
|
+
.define_singleton_function(
|
83
83
|
"_orthogonal!",
|
84
|
-
|
84
|
+
[](Tensor tensor, double gain) {
|
85
85
|
return torch::nn::init::orthogonal_(tensor, gain);
|
86
86
|
})
|
87
|
-
.
|
87
|
+
.define_singleton_function(
|
88
88
|
"_sparse!",
|
89
|
-
|
89
|
+
[](Tensor tensor, double sparsity, double std) {
|
90
90
|
return torch::nn::init::sparse_(tensor, sparsity, std);
|
91
91
|
});
|
92
92
|
|
@@ -94,18 +94,18 @@ void init_nn(Rice::Module& m) {
|
|
94
94
|
.add_handler<torch::Error>(handle_error)
|
95
95
|
.define_method(
|
96
96
|
"grad",
|
97
|
-
|
97
|
+
[](Parameter& self) {
|
98
98
|
auto grad = self.grad();
|
99
|
-
return grad.defined() ?
|
99
|
+
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
100
100
|
})
|
101
101
|
.define_method(
|
102
102
|
"grad=",
|
103
|
-
|
103
|
+
[](Parameter& self, torch::Tensor& grad) {
|
104
104
|
self.mutable_grad() = grad;
|
105
105
|
})
|
106
|
-
.
|
106
|
+
.define_singleton_function(
|
107
107
|
"_make_subclass",
|
108
|
-
|
108
|
+
[](Tensor& rd, bool requires_grad) {
|
109
109
|
auto data = rd.detach();
|
110
110
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
111
111
|
auto var = data.set_requires_grad(requires_grad);
|
data/ext/torch/random.cpp
CHANGED
@@ -1,20 +1,20 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "utils.h"
|
6
6
|
|
7
7
|
void init_random(Rice::Module& m) {
|
8
8
|
Rice::define_module_under(m, "Random")
|
9
9
|
.add_handler<torch::Error>(handle_error)
|
10
|
-
.
|
10
|
+
.define_singleton_function(
|
11
11
|
"initial_seed",
|
12
|
-
|
12
|
+
[]() {
|
13
13
|
return at::detail::getDefaultCPUGenerator().current_seed();
|
14
14
|
})
|
15
|
-
.
|
15
|
+
.define_singleton_function(
|
16
16
|
"seed",
|
17
|
-
|
17
|
+
[]() {
|
18
18
|
// TODO set for CUDA when available
|
19
19
|
auto generator = at::detail::getDefaultCPUGenerator();
|
20
20
|
return generator.seed();
|