torch-rb 0.6.0 → 0.7.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/codegen/generate_functions.rb +2 -2
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +14 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +16 -11
- data/ext/torch/templates.h +110 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +5 -0
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/version.rb +1 -1
- metadata +4 -4
data/ext/torch/torch.cpp
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "torch_functions.h"
|
6
6
|
#include "templates.h"
|
@@ -9,69 +9,78 @@
|
|
9
9
|
void init_torch(Rice::Module& m) {
|
10
10
|
m.add_handler<torch::Error>(handle_error);
|
11
11
|
add_torch_functions(m);
|
12
|
-
m.
|
12
|
+
m.define_singleton_function(
|
13
13
|
"grad_enabled?",
|
14
|
-
|
14
|
+
[]() {
|
15
15
|
return torch::GradMode::is_enabled();
|
16
16
|
})
|
17
|
-
.
|
17
|
+
.define_singleton_function(
|
18
18
|
"_set_grad_enabled",
|
19
|
-
|
19
|
+
[](bool enabled) {
|
20
20
|
torch::GradMode::set_enabled(enabled);
|
21
21
|
})
|
22
|
-
.
|
22
|
+
.define_singleton_function(
|
23
23
|
"manual_seed",
|
24
|
-
|
24
|
+
[](uint64_t seed) {
|
25
25
|
return torch::manual_seed(seed);
|
26
26
|
})
|
27
27
|
// config
|
28
|
-
.
|
28
|
+
.define_singleton_function(
|
29
29
|
"show_config",
|
30
|
-
|
30
|
+
[] {
|
31
31
|
return torch::show_config();
|
32
32
|
})
|
33
|
-
.
|
33
|
+
.define_singleton_function(
|
34
34
|
"parallel_info",
|
35
|
-
|
35
|
+
[] {
|
36
36
|
return torch::get_parallel_info();
|
37
37
|
})
|
38
38
|
// begin operations
|
39
|
-
.
|
39
|
+
.define_singleton_function(
|
40
40
|
"_save",
|
41
|
-
|
41
|
+
[](const torch::IValue &value) {
|
42
42
|
auto v = torch::pickle_save(value);
|
43
43
|
std::string str(v.begin(), v.end());
|
44
44
|
return str;
|
45
45
|
})
|
46
|
-
.
|
46
|
+
.define_singleton_function(
|
47
47
|
"_load",
|
48
|
-
|
48
|
+
[](const std::string &s) {
|
49
49
|
std::vector<char> v;
|
50
50
|
std::copy(s.begin(), s.end(), std::back_inserter(v));
|
51
51
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
52
52
|
return torch::pickle_load(v);
|
53
53
|
})
|
54
|
-
.
|
54
|
+
.define_singleton_function(
|
55
55
|
"_from_blob",
|
56
|
-
|
56
|
+
[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
57
57
|
void *data = const_cast<char *>(s.c_str());
|
58
58
|
return torch::from_blob(data, size, options);
|
59
59
|
})
|
60
|
-
.
|
60
|
+
.define_singleton_function(
|
61
61
|
"_tensor",
|
62
|
-
|
62
|
+
[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
63
63
|
auto dtype = options.dtype();
|
64
64
|
torch::Tensor t;
|
65
65
|
if (dtype == torch::kBool) {
|
66
66
|
std::vector<uint8_t> vec;
|
67
67
|
for (long i = 0; i < a.size(); i++) {
|
68
|
-
vec.push_back(
|
68
|
+
vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
|
69
|
+
}
|
70
|
+
t = torch::tensor(vec, options);
|
71
|
+
} else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
|
72
|
+
// TODO use template
|
73
|
+
std::vector<c10::complex<double>> vec;
|
74
|
+
Object obj;
|
75
|
+
for (long i = 0; i < a.size(); i++) {
|
76
|
+
obj = a[i];
|
77
|
+
vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
|
69
78
|
}
|
70
79
|
t = torch::tensor(vec, options);
|
71
80
|
} else {
|
72
81
|
std::vector<float> vec;
|
73
82
|
for (long i = 0; i < a.size(); i++) {
|
74
|
-
vec.push_back(
|
83
|
+
vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
|
75
84
|
}
|
76
85
|
// hack for requires_grad error
|
77
86
|
if (options.requires_grad()) {
|
data/ext/torch/utils.h
CHANGED
@@ -1,11 +1,10 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
#include <rice/stl.hpp>
|
5
5
|
|
6
6
|
// TODO find better place
|
7
|
-
inline void handle_error(torch::Error const & ex)
|
8
|
-
{
|
7
|
+
inline void handle_error(torch::Error const & ex) {
|
9
8
|
throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
10
9
|
}
|
11
10
|
|
data/ext/torch/wrap_outputs.h
CHANGED
@@ -1,99 +1,106 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
3
|
#include <torch/torch.h>
|
4
|
-
#include <rice/
|
4
|
+
#include <rice/rice.hpp>
|
5
5
|
|
6
|
-
inline
|
7
|
-
return
|
6
|
+
inline VALUE wrap(bool x) {
|
7
|
+
return Rice::detail::To_Ruby<bool>().convert(x);
|
8
8
|
}
|
9
9
|
|
10
|
-
inline
|
11
|
-
return
|
10
|
+
inline VALUE wrap(int64_t x) {
|
11
|
+
return Rice::detail::To_Ruby<int64_t>().convert(x);
|
12
12
|
}
|
13
13
|
|
14
|
-
inline
|
15
|
-
return
|
14
|
+
inline VALUE wrap(double x) {
|
15
|
+
return Rice::detail::To_Ruby<double>().convert(x);
|
16
16
|
}
|
17
17
|
|
18
|
-
inline
|
19
|
-
return
|
18
|
+
inline VALUE wrap(torch::Tensor x) {
|
19
|
+
return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
|
20
20
|
}
|
21
21
|
|
22
|
-
inline
|
23
|
-
return
|
22
|
+
inline VALUE wrap(torch::Scalar x) {
|
23
|
+
return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
|
24
24
|
}
|
25
25
|
|
26
|
-
inline
|
27
|
-
return
|
26
|
+
inline VALUE wrap(torch::ScalarType x) {
|
27
|
+
return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
|
28
28
|
}
|
29
29
|
|
30
|
-
inline
|
31
|
-
return
|
30
|
+
inline VALUE wrap(torch::QScheme x) {
|
31
|
+
return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
|
32
32
|
}
|
33
33
|
|
34
|
-
inline
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
34
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
35
|
+
return rb_ary_new3(
|
36
|
+
2,
|
37
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
38
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
|
39
|
+
);
|
39
40
|
}
|
40
41
|
|
41
|
-
inline
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
42
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
43
|
+
return rb_ary_new3(
|
44
|
+
3,
|
45
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
46
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
47
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
|
48
|
+
);
|
47
49
|
}
|
48
50
|
|
49
|
-
inline
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
51
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
52
|
+
return rb_ary_new3(
|
53
|
+
4,
|
54
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
55
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
56
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
|
57
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
|
58
|
+
);
|
56
59
|
}
|
57
60
|
|
58
|
-
inline
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
61
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
62
|
+
return rb_ary_new3(
|
63
|
+
5,
|
64
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
65
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
66
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
|
67
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
|
68
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
|
69
|
+
);
|
66
70
|
}
|
67
71
|
|
68
|
-
inline
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
72
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
73
|
+
return rb_ary_new3(
|
74
|
+
4,
|
75
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
76
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
77
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
|
78
|
+
Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
|
79
|
+
);
|
75
80
|
}
|
76
81
|
|
77
|
-
inline
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
82
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
83
|
+
return rb_ary_new3(
|
84
|
+
4,
|
85
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
86
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
87
|
+
Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
|
88
|
+
Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
|
89
|
+
);
|
84
90
|
}
|
85
91
|
|
86
|
-
inline
|
87
|
-
|
88
|
-
for (auto
|
89
|
-
a
|
92
|
+
inline VALUE wrap(torch::TensorList x) {
|
93
|
+
auto a = rb_ary_new2(x.size());
|
94
|
+
for (auto t : x) {
|
95
|
+
rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
|
90
96
|
}
|
91
|
-
return
|
97
|
+
return a;
|
92
98
|
}
|
93
99
|
|
94
|
-
inline
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
100
|
+
inline VALUE wrap(std::tuple<double, double> x) {
|
101
|
+
return rb_ary_new3(
|
102
|
+
2,
|
103
|
+
Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
|
104
|
+
Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
|
105
|
+
);
|
99
106
|
}
|
data/lib/torch.rb
CHANGED
@@ -238,8 +238,11 @@ module Torch
|
|
238
238
|
double: 7,
|
239
239
|
float64: 7,
|
240
240
|
complex_half: 8,
|
241
|
+
complex32: 8,
|
241
242
|
complex_float: 9,
|
243
|
+
complex64: 9,
|
242
244
|
complex_double: 10,
|
245
|
+
complex128: 10,
|
243
246
|
bool: 11,
|
244
247
|
qint8: 12,
|
245
248
|
quint8: 13,
|
@@ -394,6 +397,8 @@ module Torch
|
|
394
397
|
options[:dtype] = :int64
|
395
398
|
elsif data.all? { |v| v == true || v == false }
|
396
399
|
options[:dtype] = :bool
|
400
|
+
elsif data.any? { |v| v.is_a?(Complex) }
|
401
|
+
options[:dtype] = :complex64
|
397
402
|
end
|
398
403
|
end
|
399
404
|
|
data/lib/torch/inspector.rb
CHANGED
@@ -96,8 +96,11 @@ module Torch
|
|
96
96
|
ret = "%.#{PRINT_OPTS[:precision]}f" % value
|
97
97
|
end
|
98
98
|
elsif @complex_dtype
|
99
|
-
|
100
|
-
|
99
|
+
# TODO use float formatter for each part
|
100
|
+
precision = PRINT_OPTS[:precision]
|
101
|
+
imag = value.imag
|
102
|
+
sign = imag >= 0 ? "+" : "-"
|
103
|
+
ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
|
101
104
|
else
|
102
105
|
ret = value.to_s
|
103
106
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.7.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-
|
11
|
+
date: 2021-05-23 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version:
|
19
|
+
version: 4.0.2
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version:
|
26
|
+
version: 4.0.2
|
27
27
|
description:
|
28
28
|
email: andrew@ankane.org
|
29
29
|
executables: []
|