torch-rb 0.8.3 → 0.9.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -1
- data/codegen/generate_functions.rb +11 -4
- data/codegen/native_functions.yaml +1103 -373
- data/ext/torch/ruby_arg_parser.h +17 -3
- data/ext/torch/templates.h +0 -37
- data/ext/torch/tensor.cpp +8 -8
- data/lib/torch/tensor.rb +12 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -75,7 +75,7 @@ struct RubyArgs {
|
|
75
75
|
int idx;
|
76
76
|
|
77
77
|
inline at::Tensor tensor(int i);
|
78
|
-
inline
|
78
|
+
inline c10::optional<at::Tensor> optionalTensor(int i);
|
79
79
|
inline at::Scalar scalar(int i);
|
80
80
|
// inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
|
81
81
|
inline std::vector<at::Scalar> scalarlist(int i);
|
@@ -109,6 +109,9 @@ struct RubyArgs {
|
|
109
109
|
// inline at::QScheme toQScheme(int i);
|
110
110
|
inline std::string string(int i);
|
111
111
|
inline c10::optional<std::string> stringOptional(int i);
|
112
|
+
inline c10::string_view stringView(int i);
|
113
|
+
// inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
|
114
|
+
inline c10::optional<c10::string_view> stringViewOptional(int i);
|
112
115
|
// inline PyObject* pyobject(int i);
|
113
116
|
inline int64_t toInt64(int i);
|
114
117
|
// inline int64_t toInt64WithDefault(int i, int64_t default_int);
|
@@ -125,8 +128,8 @@ inline at::Tensor RubyArgs::tensor(int i) {
|
|
125
128
|
return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
|
126
129
|
}
|
127
130
|
|
128
|
-
inline
|
129
|
-
if (NIL_P(args[i])) return
|
131
|
+
inline c10::optional<at::Tensor> RubyArgs::optionalTensor(int i) {
|
132
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
130
133
|
return tensor(i);
|
131
134
|
}
|
132
135
|
|
@@ -322,6 +325,17 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
|
322
325
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
323
326
|
}
|
324
327
|
|
328
|
+
inline c10::string_view RubyArgs::stringView(int i) {
|
329
|
+
auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
330
|
+
return c10::string_view(str.data(), str.size());
|
331
|
+
}
|
332
|
+
|
333
|
+
inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
|
334
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
335
|
+
auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
336
|
+
return c10::string_view(str.data(), str.size());
|
337
|
+
}
|
338
|
+
|
325
339
|
inline int64_t RubyArgs::toInt64(int i) {
|
326
340
|
if (NIL_P(args[i])) return signature.params[i].default_int;
|
327
341
|
return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
|
data/ext/torch/templates.h
CHANGED
@@ -41,24 +41,6 @@ using torch::nn::init::NonlinearityType;
|
|
41
41
|
#define RETURN_NIL \
|
42
42
|
return Qnil;
|
43
43
|
|
44
|
-
class OptionalTensor {
|
45
|
-
torch::Tensor value;
|
46
|
-
public:
|
47
|
-
OptionalTensor(Object o) {
|
48
|
-
if (o.is_nil()) {
|
49
|
-
value = {};
|
50
|
-
} else {
|
51
|
-
value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
|
52
|
-
}
|
53
|
-
}
|
54
|
-
OptionalTensor(torch::Tensor o) {
|
55
|
-
value = o;
|
56
|
-
}
|
57
|
-
operator torch::Tensor() const {
|
58
|
-
return value;
|
59
|
-
}
|
60
|
-
};
|
61
|
-
|
62
44
|
namespace Rice::detail
|
63
45
|
{
|
64
46
|
template<>
|
@@ -131,25 +113,6 @@ namespace Rice::detail
|
|
131
113
|
}
|
132
114
|
};
|
133
115
|
|
134
|
-
template<>
|
135
|
-
struct Type<OptionalTensor>
|
136
|
-
{
|
137
|
-
static bool verify()
|
138
|
-
{
|
139
|
-
return true;
|
140
|
-
}
|
141
|
-
};
|
142
|
-
|
143
|
-
template<>
|
144
|
-
class From_Ruby<OptionalTensor>
|
145
|
-
{
|
146
|
-
public:
|
147
|
-
OptionalTensor convert(VALUE x)
|
148
|
-
{
|
149
|
-
return OptionalTensor(x);
|
150
|
-
}
|
151
|
-
};
|
152
|
-
|
153
116
|
template<>
|
154
117
|
struct Type<Scalar>
|
155
118
|
{
|
data/ext/torch/tensor.cpp
CHANGED
@@ -107,7 +107,7 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
|
|
107
107
|
ParsedArgs<4> parsed_args;
|
108
108
|
auto _r = parser.parse(self_, argc, argv, parsed_args);
|
109
109
|
// _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
|
110
|
-
auto dispatch__backward = [](const Tensor & self, TensorList inputs, const
|
110
|
+
auto dispatch__backward = [](const Tensor & self, TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
|
111
111
|
// in future, release GVL
|
112
112
|
self._backward(inputs, gradient, retain_graph, create_graph);
|
113
113
|
};
|
@@ -125,13 +125,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
125
125
|
rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
|
126
126
|
|
127
127
|
rb_cTensor
|
128
|
-
.define_method("cuda?", &
|
129
|
-
.define_method("sparse?", &
|
130
|
-
.define_method("quantized?", &
|
131
|
-
.define_method("dim", &
|
132
|
-
.define_method("numel", &
|
133
|
-
.define_method("element_size", &
|
134
|
-
.define_method("requires_grad", &
|
128
|
+
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
|
129
|
+
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
|
130
|
+
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
|
131
|
+
.define_method("dim", [](Tensor& self) { return self.dim(); })
|
132
|
+
.define_method("numel", [](Tensor& self) { return self.numel(); })
|
133
|
+
.define_method("element_size", [](Tensor& self) { return self.element_size(); })
|
134
|
+
.define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
|
135
135
|
.define_method(
|
136
136
|
"_size",
|
137
137
|
[](Tensor& self, int64_t dim) {
|
data/lib/torch/tensor.rb
CHANGED
@@ -191,5 +191,17 @@ module Torch
|
|
191
191
|
clone
|
192
192
|
end
|
193
193
|
end
|
194
|
+
|
195
|
+
# not a method in native_functions.yaml
|
196
|
+
# attribute in Python rather than method
|
197
|
+
def imag
|
198
|
+
Torch.imag(self)
|
199
|
+
end
|
200
|
+
|
201
|
+
# not a method in native_functions.yaml
|
202
|
+
# attribute in Python rather than method
|
203
|
+
def real
|
204
|
+
Torch.real(self)
|
205
|
+
end
|
194
206
|
end
|
195
207
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.9.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-10-
|
11
|
+
date: 2021-10-23 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|