torch-rb 0.6.0 → 0.7.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/codegen/generate_functions.rb +2 -2
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +14 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +16 -11
- data/ext/torch/templates.h +110 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +5 -0
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/version.rb +1 -1
- metadata +4 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 859f5858ce45f6a73fbf9afd6073e3af798ac7ea64713405c54df0fc40b0a1e6
|
4
|
+
data.tar.gz: 4c371640902d1226c69135874aaa076251d2539c56dc78479203bc28473b071b
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f8ab6c1af0da9ad36f1fa7fae0cb67b54a60c11783d13cedc4c66164fb1fcb40638a68bbd78fa0f644e6d6d1c7df44d83c68c61bc0e41ea4c361817b0fe7b9cd
|
7
|
+
data.tar.gz: f64913cf8c2566539fef54e29e42173f4e2e0529ea49d381c6cf41af64868653865fc2cfa366f5086e035b45008a07f2d620a000c64dabf92e3c9134489b5a6b
|
data/CHANGELOG.md
CHANGED
@@ -75,7 +75,7 @@ def write_body(type, method_defs, attach_defs)
|
|
75
75
|
// do not edit by hand
|
76
76
|
|
77
77
|
#include <torch/torch.h>
|
78
|
-
#include <rice/
|
78
|
+
#include <rice/rice.hpp>
|
79
79
|
|
80
80
|
#include "ruby_arg_parser.h"
|
81
81
|
#include "templates.h"
|
@@ -119,7 +119,7 @@ def generate_attach_def(name, type, def_method)
|
|
119
119
|
end
|
120
120
|
|
121
121
|
def generate_method_def(name, functions, type, def_method)
|
122
|
-
assign_self = type == "tensor" ? "\n Tensor& self =
|
122
|
+
assign_self = type == "tensor" ? "\n Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);" : ""
|
123
123
|
|
124
124
|
functions = group_overloads(functions, type)
|
125
125
|
signatures = functions.map { |f| f["signature"] }
|
data/ext/torch/cuda.cpp
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "utils.h"
|
6
6
|
|
7
7
|
void init_cuda(Rice::Module& m) {
|
8
8
|
Rice::define_module_under(m, "CUDA")
|
9
9
|
.add_handler<torch::Error>(handle_error)
|
10
|
-
.
|
11
|
-
.
|
12
|
-
.
|
13
|
-
.
|
10
|
+
.define_singleton_function("available?", &torch::cuda::is_available)
|
11
|
+
.define_singleton_function("device_count", &torch::cuda::device_count)
|
12
|
+
.define_singleton_function("manual_seed", &torch::cuda::manual_seed)
|
13
|
+
.define_singleton_function("manual_seed_all", &torch::cuda::manual_seed_all);
|
14
14
|
}
|
data/ext/torch/device.cpp
CHANGED
@@ -1,19 +1,26 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/Module.hpp>
|
3
|
+
#include <rice/rice.hpp>
|
5
4
|
|
6
5
|
#include "utils.h"
|
7
6
|
|
8
7
|
void init_device(Rice::Module& m) {
|
9
8
|
Rice::define_class_under<torch::Device>(m, "Device")
|
10
9
|
.add_handler<torch::Error>(handle_error)
|
11
|
-
.define_constructor(Rice::Constructor<torch::Device, std::string
|
12
|
-
.define_method(
|
13
|
-
|
10
|
+
.define_constructor(Rice::Constructor<torch::Device, const std::string&>())
|
11
|
+
.define_method(
|
12
|
+
"index",
|
13
|
+
[](torch::Device& self) {
|
14
|
+
return self.index();
|
15
|
+
})
|
16
|
+
.define_method(
|
17
|
+
"index?",
|
18
|
+
[](torch::Device& self) {
|
19
|
+
return self.has_index();
|
20
|
+
})
|
14
21
|
.define_method(
|
15
22
|
"type",
|
16
|
-
|
23
|
+
[](torch::Device& self) {
|
17
24
|
std::stringstream s;
|
18
25
|
s << self.type();
|
19
26
|
return s.str();
|
data/ext/torch/ext.cpp
CHANGED
@@ -1,12 +1,14 @@
|
|
1
|
-
#include <
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
2
4
|
|
3
5
|
void init_nn(Rice::Module& m);
|
4
|
-
void init_tensor(Rice::Module& m);
|
6
|
+
void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
|
5
7
|
void init_torch(Rice::Module& m);
|
6
8
|
|
7
9
|
void init_cuda(Rice::Module& m);
|
8
10
|
void init_device(Rice::Module& m);
|
9
|
-
void init_ivalue(Rice::Module& m);
|
11
|
+
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
|
10
12
|
void init_random(Rice::Module& m);
|
11
13
|
|
12
14
|
extern "C"
|
@@ -14,13 +16,20 @@ void Init_ext()
|
|
14
16
|
{
|
15
17
|
auto m = Rice::define_module("Torch");
|
16
18
|
|
19
|
+
// need to define certain classes up front to keep Rice happy
|
20
|
+
auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
|
21
|
+
.define_constructor(Rice::Constructor<torch::IValue>());
|
22
|
+
auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
|
23
|
+
auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
|
24
|
+
.define_constructor(Rice::Constructor<torch::TensorOptions>());
|
25
|
+
|
17
26
|
// keep this order
|
18
27
|
init_torch(m);
|
19
|
-
init_tensor(m);
|
28
|
+
init_tensor(m, rb_cTensor, rb_cTensorOptions);
|
20
29
|
init_nn(m);
|
21
30
|
|
22
31
|
init_cuda(m);
|
23
32
|
init_device(m);
|
24
|
-
init_ivalue(m);
|
33
|
+
init_ivalue(m, rb_cIValue);
|
25
34
|
init_random(m);
|
26
35
|
}
|
data/ext/torch/extconf.rb
CHANGED
data/ext/torch/ivalue.cpp
CHANGED
@@ -1,18 +1,13 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/Constructor.hpp>
|
5
|
-
#include <rice/Hash.hpp>
|
6
|
-
#include <rice/Module.hpp>
|
7
|
-
#include <rice/String.hpp>
|
3
|
+
#include <rice/rice.hpp>
|
8
4
|
|
9
5
|
#include "utils.h"
|
10
6
|
|
11
|
-
void init_ivalue(Rice::Module& m) {
|
7
|
+
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
12
8
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
13
|
-
|
9
|
+
rb_cIValue
|
14
10
|
.add_handler<torch::Error>(handle_error)
|
15
|
-
.define_constructor(Rice::Constructor<torch::IValue>())
|
16
11
|
.define_method("bool?", &torch::IValue::isBool)
|
17
12
|
.define_method("bool_list?", &torch::IValue::isBoolList)
|
18
13
|
.define_method("capsule?", &torch::IValue::isCapsule)
|
@@ -39,95 +34,98 @@ void init_ivalue(Rice::Module& m) {
|
|
39
34
|
.define_method("tuple?", &torch::IValue::isTuple)
|
40
35
|
.define_method(
|
41
36
|
"to_bool",
|
42
|
-
|
37
|
+
[](torch::IValue& self) {
|
43
38
|
return self.toBool();
|
44
39
|
})
|
45
40
|
.define_method(
|
46
41
|
"to_double",
|
47
|
-
|
42
|
+
[](torch::IValue& self) {
|
48
43
|
return self.toDouble();
|
49
44
|
})
|
50
45
|
.define_method(
|
51
46
|
"to_int",
|
52
|
-
|
47
|
+
[](torch::IValue& self) {
|
53
48
|
return self.toInt();
|
54
49
|
})
|
55
50
|
.define_method(
|
56
51
|
"to_list",
|
57
|
-
|
52
|
+
[](torch::IValue& self) {
|
58
53
|
auto list = self.toListRef();
|
59
54
|
Rice::Array obj;
|
60
55
|
for (auto& elem : list) {
|
61
|
-
|
56
|
+
auto v = torch::IValue{elem};
|
57
|
+
obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)));
|
62
58
|
}
|
63
59
|
return obj;
|
64
60
|
})
|
65
61
|
.define_method(
|
66
62
|
"to_string_ref",
|
67
|
-
|
63
|
+
[](torch::IValue& self) {
|
68
64
|
return self.toStringRef();
|
69
65
|
})
|
70
66
|
.define_method(
|
71
67
|
"to_tensor",
|
72
|
-
|
68
|
+
[](torch::IValue& self) {
|
73
69
|
return self.toTensor();
|
74
70
|
})
|
75
71
|
.define_method(
|
76
72
|
"to_generic_dict",
|
77
|
-
|
73
|
+
[](torch::IValue& self) {
|
78
74
|
auto dict = self.toGenericDict();
|
79
75
|
Rice::Hash obj;
|
80
76
|
for (auto& pair : dict) {
|
81
|
-
|
77
|
+
auto k = torch::IValue{pair.key()};
|
78
|
+
auto v = torch::IValue{pair.value()};
|
79
|
+
obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
|
82
80
|
}
|
83
81
|
return obj;
|
84
82
|
})
|
85
|
-
.
|
83
|
+
.define_singleton_function(
|
86
84
|
"from_tensor",
|
87
|
-
|
85
|
+
[](torch::Tensor& v) {
|
88
86
|
return torch::IValue(v);
|
89
87
|
})
|
90
88
|
// TODO create specialized list types?
|
91
|
-
.
|
89
|
+
.define_singleton_function(
|
92
90
|
"from_list",
|
93
|
-
|
91
|
+
[](Rice::Array obj) {
|
94
92
|
c10::impl::GenericList list(c10::AnyType::get());
|
95
93
|
for (auto entry : obj) {
|
96
|
-
list.push_back(
|
94
|
+
list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
|
97
95
|
}
|
98
96
|
return torch::IValue(list);
|
99
97
|
})
|
100
|
-
.
|
98
|
+
.define_singleton_function(
|
101
99
|
"from_string",
|
102
|
-
|
100
|
+
[](Rice::String v) {
|
103
101
|
return torch::IValue(v.str());
|
104
102
|
})
|
105
|
-
.
|
103
|
+
.define_singleton_function(
|
106
104
|
"from_int",
|
107
|
-
|
105
|
+
[](int64_t v) {
|
108
106
|
return torch::IValue(v);
|
109
107
|
})
|
110
|
-
.
|
108
|
+
.define_singleton_function(
|
111
109
|
"from_double",
|
112
|
-
|
110
|
+
[](double v) {
|
113
111
|
return torch::IValue(v);
|
114
112
|
})
|
115
|
-
.
|
113
|
+
.define_singleton_function(
|
116
114
|
"from_bool",
|
117
|
-
|
115
|
+
[](bool v) {
|
118
116
|
return torch::IValue(v);
|
119
117
|
})
|
120
118
|
// see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
|
121
119
|
// createGenericDict and toIValue
|
122
|
-
.
|
120
|
+
.define_singleton_function(
|
123
121
|
"from_dict",
|
124
|
-
|
122
|
+
[](Rice::Hash obj) {
|
125
123
|
auto key_type = c10::AnyType::get();
|
126
124
|
auto value_type = c10::AnyType::get();
|
127
125
|
c10::impl::GenericDict elems(key_type, value_type);
|
128
126
|
elems.reserve(obj.size());
|
129
127
|
for (auto entry : obj) {
|
130
|
-
elems.insert(
|
128
|
+
elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
|
131
129
|
}
|
132
130
|
return torch::IValue(std::move(elems));
|
133
131
|
});
|
data/ext/torch/nn.cpp
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "nn_functions.h"
|
6
6
|
#include "templates.h"
|
@@ -19,74 +19,74 @@ void init_nn(Rice::Module& m) {
|
|
19
19
|
|
20
20
|
Rice::define_module_under(rb_mNN, "Init")
|
21
21
|
.add_handler<torch::Error>(handle_error)
|
22
|
-
.
|
22
|
+
.define_singleton_function(
|
23
23
|
"_calculate_gain",
|
24
|
-
|
24
|
+
[](NonlinearityType nonlinearity, double param) {
|
25
25
|
return torch::nn::init::calculate_gain(nonlinearity, param);
|
26
26
|
})
|
27
|
-
.
|
27
|
+
.define_singleton_function(
|
28
28
|
"_uniform!",
|
29
|
-
|
29
|
+
[](Tensor tensor, double low, double high) {
|
30
30
|
return torch::nn::init::uniform_(tensor, low, high);
|
31
31
|
})
|
32
|
-
.
|
32
|
+
.define_singleton_function(
|
33
33
|
"_normal!",
|
34
|
-
|
34
|
+
[](Tensor tensor, double mean, double std) {
|
35
35
|
return torch::nn::init::normal_(tensor, mean, std);
|
36
36
|
})
|
37
|
-
.
|
37
|
+
.define_singleton_function(
|
38
38
|
"_constant!",
|
39
|
-
|
39
|
+
[](Tensor tensor, Scalar value) {
|
40
40
|
return torch::nn::init::constant_(tensor, value);
|
41
41
|
})
|
42
|
-
.
|
42
|
+
.define_singleton_function(
|
43
43
|
"_ones!",
|
44
|
-
|
44
|
+
[](Tensor tensor) {
|
45
45
|
return torch::nn::init::ones_(tensor);
|
46
46
|
})
|
47
|
-
.
|
47
|
+
.define_singleton_function(
|
48
48
|
"_zeros!",
|
49
|
-
|
49
|
+
[](Tensor tensor) {
|
50
50
|
return torch::nn::init::zeros_(tensor);
|
51
51
|
})
|
52
|
-
.
|
52
|
+
.define_singleton_function(
|
53
53
|
"_eye!",
|
54
|
-
|
54
|
+
[](Tensor tensor) {
|
55
55
|
return torch::nn::init::eye_(tensor);
|
56
56
|
})
|
57
|
-
.
|
57
|
+
.define_singleton_function(
|
58
58
|
"_dirac!",
|
59
|
-
|
59
|
+
[](Tensor tensor) {
|
60
60
|
return torch::nn::init::dirac_(tensor);
|
61
61
|
})
|
62
|
-
.
|
62
|
+
.define_singleton_function(
|
63
63
|
"_xavier_uniform!",
|
64
|
-
|
64
|
+
[](Tensor tensor, double gain) {
|
65
65
|
return torch::nn::init::xavier_uniform_(tensor, gain);
|
66
66
|
})
|
67
|
-
.
|
67
|
+
.define_singleton_function(
|
68
68
|
"_xavier_normal!",
|
69
|
-
|
69
|
+
[](Tensor tensor, double gain) {
|
70
70
|
return torch::nn::init::xavier_normal_(tensor, gain);
|
71
71
|
})
|
72
|
-
.
|
72
|
+
.define_singleton_function(
|
73
73
|
"_kaiming_uniform!",
|
74
|
-
|
74
|
+
[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
75
75
|
return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
|
76
76
|
})
|
77
|
-
.
|
77
|
+
.define_singleton_function(
|
78
78
|
"_kaiming_normal!",
|
79
|
-
|
79
|
+
[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
80
80
|
return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
|
81
81
|
})
|
82
|
-
.
|
82
|
+
.define_singleton_function(
|
83
83
|
"_orthogonal!",
|
84
|
-
|
84
|
+
[](Tensor tensor, double gain) {
|
85
85
|
return torch::nn::init::orthogonal_(tensor, gain);
|
86
86
|
})
|
87
|
-
.
|
87
|
+
.define_singleton_function(
|
88
88
|
"_sparse!",
|
89
|
-
|
89
|
+
[](Tensor tensor, double sparsity, double std) {
|
90
90
|
return torch::nn::init::sparse_(tensor, sparsity, std);
|
91
91
|
});
|
92
92
|
|
@@ -94,18 +94,18 @@ void init_nn(Rice::Module& m) {
|
|
94
94
|
.add_handler<torch::Error>(handle_error)
|
95
95
|
.define_method(
|
96
96
|
"grad",
|
97
|
-
|
97
|
+
[](Parameter& self) {
|
98
98
|
auto grad = self.grad();
|
99
|
-
return grad.defined() ?
|
99
|
+
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
100
100
|
})
|
101
101
|
.define_method(
|
102
102
|
"grad=",
|
103
|
-
|
103
|
+
[](Parameter& self, torch::Tensor& grad) {
|
104
104
|
self.mutable_grad() = grad;
|
105
105
|
})
|
106
|
-
.
|
106
|
+
.define_singleton_function(
|
107
107
|
"_make_subclass",
|
108
|
-
|
108
|
+
[](Tensor& rd, bool requires_grad) {
|
109
109
|
auto data = rd.detach();
|
110
110
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
111
111
|
auto var = data.set_requires_grad(requires_grad);
|
data/ext/torch/random.cpp
CHANGED
@@ -1,20 +1,20 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "utils.h"
|
6
6
|
|
7
7
|
void init_random(Rice::Module& m) {
|
8
8
|
Rice::define_module_under(m, "Random")
|
9
9
|
.add_handler<torch::Error>(handle_error)
|
10
|
-
.
|
10
|
+
.define_singleton_function(
|
11
11
|
"initial_seed",
|
12
|
-
|
12
|
+
[]() {
|
13
13
|
return at::detail::getDefaultCPUGenerator().current_seed();
|
14
14
|
})
|
15
|
-
.
|
15
|
+
.define_singleton_function(
|
16
16
|
"seed",
|
17
|
-
|
17
|
+
[]() {
|
18
18
|
// TODO set for CUDA when available
|
19
19
|
auto generator = at::detail::getDefaultCPUGenerator();
|
20
20
|
return generator.seed();
|