torch-rb 0.6.0 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/codegen/generate_functions.rb +2 -2
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +14 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +16 -11
- data/ext/torch/templates.h +110 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +5 -0
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/version.rb +1 -1
- metadata +4 -4
data/ext/torch/torch.cpp
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "torch_functions.h"
|
6
6
|
#include "templates.h"
|
@@ -9,69 +9,78 @@
|
|
9
9
|
void init_torch(Rice::Module& m) {
|
10
10
|
m.add_handler<torch::Error>(handle_error);
|
11
11
|
add_torch_functions(m);
|
12
|
-
m.
|
12
|
+
m.define_singleton_function(
|
13
13
|
"grad_enabled?",
|
14
|
-
|
14
|
+
[]() {
|
15
15
|
return torch::GradMode::is_enabled();
|
16
16
|
})
|
17
|
-
.
|
17
|
+
.define_singleton_function(
|
18
18
|
"_set_grad_enabled",
|
19
|
-
|
19
|
+
[](bool enabled) {
|
20
20
|
torch::GradMode::set_enabled(enabled);
|
21
21
|
})
|
22
|
-
.
|
22
|
+
.define_singleton_function(
|
23
23
|
"manual_seed",
|
24
|
-
|
24
|
+
[](uint64_t seed) {
|
25
25
|
return torch::manual_seed(seed);
|
26
26
|
})
|
27
27
|
// config
|
28
|
-
.
|
28
|
+
.define_singleton_function(
|
29
29
|
"show_config",
|
30
|
-
|
30
|
+
[] {
|
31
31
|
return torch::show_config();
|
32
32
|
})
|
33
|
-
.
|
33
|
+
.define_singleton_function(
|
34
34
|
"parallel_info",
|
35
|
-
|
35
|
+
[] {
|
36
36
|
return torch::get_parallel_info();
|
37
37
|
})
|
38
38
|
// begin operations
|
39
|
-
.
|
39
|
+
.define_singleton_function(
|
40
40
|
"_save",
|
41
|
-
|
41
|
+
[](const torch::IValue &value) {
|
42
42
|
auto v = torch::pickle_save(value);
|
43
43
|
std::string str(v.begin(), v.end());
|
44
44
|
return str;
|
45
45
|
})
|
46
|
-
.
|
46
|
+
.define_singleton_function(
|
47
47
|
"_load",
|
48
|
-
|
48
|
+
[](const std::string &s) {
|
49
49
|
std::vector<char> v;
|
50
50
|
std::copy(s.begin(), s.end(), std::back_inserter(v));
|
51
51
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
52
52
|
return torch::pickle_load(v);
|
53
53
|
})
|
54
|
-
.
|
54
|
+
.define_singleton_function(
|
55
55
|
"_from_blob",
|
56
|
-
|
56
|
+
[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
57
57
|
void *data = const_cast<char *>(s.c_str());
|
58
58
|
return torch::from_blob(data, size, options);
|
59
59
|
})
|
60
|
-
.
|
60
|
+
.define_singleton_function(
|
61
61
|
"_tensor",
|
62
|
-
|
62
|
+
[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
63
63
|
auto dtype = options.dtype();
|
64
64
|
torch::Tensor t;
|
65
65
|
if (dtype == torch::kBool) {
|
66
66
|
std::vector<uint8_t> vec;
|
67
67
|
for (long i = 0; i < a.size(); i++) {
|
68
|
-
vec.push_back(
|
68
|
+
vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
|
69
|
+
}
|
70
|
+
t = torch::tensor(vec, options);
|
71
|
+
} else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
|
72
|
+
// TODO use template
|
73
|
+
std::vector<c10::complex<double>> vec;
|
74
|
+
Object obj;
|
75
|
+
for (long i = 0; i < a.size(); i++) {
|
76
|
+
obj = a[i];
|
77
|
+
vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
|
69
78
|
}
|
70
79
|
t = torch::tensor(vec, options);
|
71
80
|
} else {
|
72
81
|
std::vector<float> vec;
|
73
82
|
for (long i = 0; i < a.size(); i++) {
|
74
|
-
vec.push_back(
|
83
|
+
vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
|
75
84
|
}
|
76
85
|
// hack for requires_grad error
|
77
86
|
if (options.requires_grad()) {
|
data/ext/torch/utils.h
CHANGED
@@ -1,11 +1,10 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
#include <rice/stl.hpp>
|
5
5
|
|
6
6
|
// TODO find better place
|
7
|
-
inline void handle_error(torch::Error const & ex)
|
8
|
-
{
|
7
|
+
inline void handle_error(torch::Error const & ex) {
|
9
8
|
throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
10
9
|
}
|
11
10
|
|
data/ext/torch/wrap_outputs.h
CHANGED
@@ -1,99 +1,106 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
3
|
#include <torch/torch.h>
|
4
|
-
#include <rice/
|
4
|
+
#include <rice/rice.hpp>
|
5
5
|
|
6
|
-
inline
|
7
|
-
return
|
6
|
+
inline VALUE wrap(bool x) {
|
7
|
+
return Rice::detail::To_Ruby<bool>().convert(x);
|
8
8
|
}
|
9
9
|
|
10
|
-
inline
|
11
|
-
return
|
10
|
+
inline VALUE wrap(int64_t x) {
|
11
|
+
return Rice::detail::To_Ruby<int64_t>().convert(x);
|
12
12
|
}
|
13
13
|
|
14
|
-
inline
|
15
|
-
return
|
14
|
+
inline VALUE wrap(double x) {
|
15
|
+
return Rice::detail::To_Ruby<double>().convert(x);
|
16
16
|
}
|
17
17
|
|
18
|
-
inline
|
19
|
-
return
|
18
|
+
inline VALUE wrap(torch::Tensor x) {
|
19
|
+
return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
|
20
20
|
}
|
21
21
|
|
22
|
-
inline
|
23
|
-
return
|
22
|
+
inline VALUE wrap(torch::Scalar x) {
|
23
|
+
return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
|
24
24
|
}
|
25
25
|
|
26
|
-
inline
|
27
|
-
return
|
26
|
+
inline VALUE wrap(torch::ScalarType x) {
|
27
|
+
return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
|
28
28
|
}
|
29
29
|
|
30
|
-
inline
|
31
|
-
return
|
30
|
+
inline VALUE wrap(torch::QScheme x) {
|
31
|
+
return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
|
32
32
|
}
|
33
33
|
|
34
|
-
inline
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
34
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
35
|
+
return rb_ary_new3(
|
36
|
+
2,
|
37
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
38
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
|
39
|
+
);
|
39
40
|
}
|
40
41
|
|
41
|
-
inline
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
42
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
43
|
+
return rb_ary_new3(
|
44
|
+
3,
|
45
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
46
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
47
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
|
48
|
+
);
|
47
49
|
}
|
48
50
|
|
49
|
-
inline
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
51
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
52
|
+
return rb_ary_new3(
|
53
|
+
4,
|
54
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
55
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
56
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
|
57
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
|
58
|
+
);
|
56
59
|
}
|
57
60
|
|
58
|
-
inline
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
61
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
62
|
+
return rb_ary_new3(
|
63
|
+
5,
|
64
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
65
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
66
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
|
67
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
|
68
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
|
69
|
+
);
|
66
70
|
}
|
67
71
|
|
68
|
-
inline
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
72
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
73
|
+
return rb_ary_new3(
|
74
|
+
4,
|
75
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
76
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
77
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
|
78
|
+
Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
|
79
|
+
);
|
75
80
|
}
|
76
81
|
|
77
|
-
inline
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
82
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
83
|
+
return rb_ary_new3(
|
84
|
+
4,
|
85
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
86
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
87
|
+
Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
|
88
|
+
Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
|
89
|
+
);
|
84
90
|
}
|
85
91
|
|
86
|
-
inline
|
87
|
-
|
88
|
-
for (auto
|
89
|
-
a
|
92
|
+
inline VALUE wrap(torch::TensorList x) {
|
93
|
+
auto a = rb_ary_new2(x.size());
|
94
|
+
for (auto t : x) {
|
95
|
+
rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
|
90
96
|
}
|
91
|
-
return
|
97
|
+
return a;
|
92
98
|
}
|
93
99
|
|
94
|
-
inline
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
100
|
+
inline VALUE wrap(std::tuple<double, double> x) {
|
101
|
+
return rb_ary_new3(
|
102
|
+
2,
|
103
|
+
Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
|
104
|
+
Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
|
105
|
+
);
|
99
106
|
}
|
data/lib/torch.rb
CHANGED
@@ -238,8 +238,11 @@ module Torch
|
|
238
238
|
double: 7,
|
239
239
|
float64: 7,
|
240
240
|
complex_half: 8,
|
241
|
+
complex32: 8,
|
241
242
|
complex_float: 9,
|
243
|
+
complex64: 9,
|
242
244
|
complex_double: 10,
|
245
|
+
complex128: 10,
|
243
246
|
bool: 11,
|
244
247
|
qint8: 12,
|
245
248
|
quint8: 13,
|
@@ -394,6 +397,8 @@ module Torch
|
|
394
397
|
options[:dtype] = :int64
|
395
398
|
elsif data.all? { |v| v == true || v == false }
|
396
399
|
options[:dtype] = :bool
|
400
|
+
elsif data.any? { |v| v.is_a?(Complex) }
|
401
|
+
options[:dtype] = :complex64
|
397
402
|
end
|
398
403
|
end
|
399
404
|
|
data/lib/torch/inspector.rb
CHANGED
@@ -96,8 +96,11 @@ module Torch
|
|
96
96
|
ret = "%.#{PRINT_OPTS[:precision]}f" % value
|
97
97
|
end
|
98
98
|
elsif @complex_dtype
|
99
|
-
|
100
|
-
|
99
|
+
# TODO use float formatter for each part
|
100
|
+
precision = PRINT_OPTS[:precision]
|
101
|
+
imag = value.imag
|
102
|
+
sign = imag >= 0 ? "+" : "-"
|
103
|
+
ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
|
101
104
|
else
|
102
105
|
ret = value.to_s
|
103
106
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.7.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-
|
11
|
+
date: 2021-05-23 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version:
|
19
|
+
version: 4.0.2
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version:
|
26
|
+
version: 4.0.2
|
27
27
|
description:
|
28
28
|
email: andrew@ankane.org
|
29
29
|
executables: []
|