torch-rb 0.9.1 → 0.10.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/README.md +3 -1
- data/codegen/function.rb +2 -2
- data/codegen/generate_functions.rb +27 -5
- data/codegen/native_functions.yaml +951 -362
- data/ext/torch/nn.cpp +4 -1
- data/ext/torch/ruby_arg_parser.h +26 -6
- data/ext/torch/sparse_functions.h +6 -0
- data/ext/torch/templates.h +34 -0
- data/ext/torch/tensor.cpp +25 -25
- data/ext/torch/torch.cpp +38 -28
- data/ext/torch/utils.h +7 -0
- data/lib/torch/nn/parameter.rb +3 -0
- data/lib/torch/nn/parameter_list.rb +48 -0
- data/lib/torch/tensor.rb +3 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +16 -4
- metadata +5 -3
data/ext/torch/nn.cpp
CHANGED
@@ -98,8 +98,11 @@ void init_nn(Rice::Module& m) {
|
|
98
98
|
auto grad = self.grad();
|
99
99
|
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
100
100
|
})
|
101
|
+
// can't use grad=
|
102
|
+
// assignment methods fail with Ruby 3.0
|
103
|
+
// TODO add checks like Tensor
|
101
104
|
.define_method(
|
102
|
-
"
|
105
|
+
"_set_grad",
|
103
106
|
[](Parameter& self, torch::Tensor& grad) {
|
104
107
|
self.mutable_grad() = grad;
|
105
108
|
})
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -88,18 +88,18 @@ struct RubyArgs {
|
|
88
88
|
inline c10::optional<at::Generator> generator(int i);
|
89
89
|
inline at::Storage storage(int i);
|
90
90
|
inline at::ScalarType scalartype(int i);
|
91
|
-
|
91
|
+
inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
|
92
92
|
inline c10::optional<at::ScalarType> scalartypeOptional(int i);
|
93
93
|
inline c10::optional<at::Scalar> scalarOptional(int i);
|
94
94
|
inline c10::optional<int64_t> toInt64Optional(int i);
|
95
95
|
inline c10::optional<bool> toBoolOptional(int i);
|
96
96
|
inline c10::optional<double> toDoubleOptional(int i);
|
97
97
|
inline c10::OptionalArray<double> doublelistOptional(int i);
|
98
|
-
|
99
|
-
|
98
|
+
inline at::Layout layout(int i);
|
99
|
+
inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
|
100
100
|
inline c10::optional<at::Layout> layoutOptional(int i);
|
101
101
|
inline at::Device device(int i);
|
102
|
-
|
102
|
+
inline at::Device deviceWithDefault(int i, const at::Device& default_device);
|
103
103
|
// inline c10::optional<at::Device> deviceOptional(int i);
|
104
104
|
// inline at::Dimname dimname(int i);
|
105
105
|
// inline std::vector<at::Dimname> dimnamelist(int i);
|
@@ -240,6 +240,11 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
240
240
|
return it->second;
|
241
241
|
}
|
242
242
|
|
243
|
+
// Reads argument i as a ScalarType. A nil Ruby value yields the supplied
// default_scalartype; anything else is delegated to scalartype(i).
inline at::ScalarType RubyArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
  const bool is_missing = NIL_P(args[i]);
  return is_missing ? default_scalartype : scalartype(i);
}
|
247
|
+
|
243
248
|
inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
|
244
249
|
if (NIL_P(args[i])) return c10::nullopt;
|
245
250
|
return scalartype(i);
|
@@ -284,8 +289,8 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
284
289
|
return res;
|
285
290
|
}
|
286
291
|
|
287
|
-
inline
|
288
|
-
if (NIL_P(args[i])) return
|
292
|
+
inline at::Layout RubyArgs::layout(int i) {
|
293
|
+
if (NIL_P(args[i])) return signature.params[i].default_layout;
|
289
294
|
|
290
295
|
static std::unordered_map<VALUE, Layout> layout_map = {
|
291
296
|
{ID2SYM(rb_intern("strided")), Layout::Strided},
|
@@ -298,6 +303,16 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
298
303
|
return it->second;
|
299
304
|
}
|
300
305
|
|
306
|
+
// Reads argument i as a Layout. A nil Ruby value yields default_layout;
// otherwise the conversion is handled by layout(i).
inline at::Layout RubyArgs::layoutWithDefault(int i, at::Layout default_layout) {
  const bool is_missing = NIL_P(args[i]);
  return is_missing ? default_layout : layout(i);
}
|
310
|
+
|
311
|
+
// Reads argument i as an optional Layout: nil maps to c10::nullopt and any
// other value is converted through layout(i).
inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
  if (!NIL_P(args[i])) {
    return layout(i);
  }
  return c10::nullopt;
}
|
315
|
+
|
301
316
|
inline at::Device RubyArgs::device(int i) {
|
302
317
|
if (NIL_P(args[i])) {
|
303
318
|
return at::Device("cpu");
|
@@ -306,6 +321,11 @@ inline at::Device RubyArgs::device(int i) {
|
|
306
321
|
return at::Device(device_str);
|
307
322
|
}
|
308
323
|
|
324
|
+
// Reads argument i as a Device. A nil Ruby value yields a copy of
// default_device; anything else is converted by device(i).
inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
  if (!NIL_P(args[i])) {
    return device(i);
  }
  return default_device;
}
|
328
|
+
|
309
329
|
inline at::MemoryFormat RubyArgs::memoryformat(int i) {
|
310
330
|
if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
|
311
331
|
throw std::runtime_error("memoryformat not supported yet");
|
data/ext/torch/templates.h
CHANGED
@@ -41,6 +41,40 @@ using torch::nn::init::NonlinearityType;
|
|
41
41
|
#define RETURN_NIL \
|
42
42
|
return Qnil;
|
43
43
|
|
44
|
+
namespace Rice::detail
|
45
|
+
{
|
46
|
+
template<typename T>
|
47
|
+
struct Type<c10::complex<T>>
|
48
|
+
{
|
49
|
+
static bool verify()
|
50
|
+
{
|
51
|
+
return true;
|
52
|
+
}
|
53
|
+
};
|
54
|
+
|
55
|
+
template<typename T>
|
56
|
+
class To_Ruby<c10::complex<T>>
|
57
|
+
{
|
58
|
+
public:
|
59
|
+
VALUE convert(c10::complex<T> const& x)
|
60
|
+
{
|
61
|
+
return rb_dbl_complex_new(x.real(), x.imag());
|
62
|
+
}
|
63
|
+
};
|
64
|
+
|
65
|
+
template<typename T>
|
66
|
+
class From_Ruby<c10::complex<T>>
|
67
|
+
{
|
68
|
+
public:
|
69
|
+
c10::complex<T> convert(VALUE x)
|
70
|
+
{
|
71
|
+
VALUE real = rb_funcall(x, rb_intern("real"), 0);
|
72
|
+
VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
|
73
|
+
return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
|
74
|
+
}
|
75
|
+
};
|
76
|
+
}
|
77
|
+
|
44
78
|
namespace Rice::detail
|
45
79
|
{
|
46
80
|
template<>
|
data/ext/torch/tensor.cpp
CHANGED
@@ -10,28 +10,6 @@
|
|
10
10
|
using namespace Rice;
|
11
11
|
using torch::indexing::TensorIndex;
|
12
12
|
|
13
|
-
namespace Rice::detail
|
14
|
-
{
|
15
|
-
template<typename T>
|
16
|
-
struct Type<c10::complex<T>>
|
17
|
-
{
|
18
|
-
static bool verify()
|
19
|
-
{
|
20
|
-
return true;
|
21
|
-
}
|
22
|
-
};
|
23
|
-
|
24
|
-
template<typename T>
|
25
|
-
class To_Ruby<c10::complex<T>>
|
26
|
-
{
|
27
|
-
public:
|
28
|
-
VALUE convert(c10::complex<T> const& x)
|
29
|
-
{
|
30
|
-
return rb_dbl_complex_new(x.real(), x.imag());
|
31
|
-
}
|
32
|
-
};
|
33
|
-
}
|
34
|
-
|
35
13
|
template<typename T>
|
36
14
|
Array flat_data(Tensor& tensor) {
|
37
15
|
Tensor view = tensor.reshape({tensor.numel()});
|
@@ -189,9 +167,31 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
189
167
|
auto grad = self.grad();
|
190
168
|
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
191
169
|
})
|
170
|
+
// can't use grad=
|
171
|
+
// assignment methods fail with Ruby 3.0
|
192
172
|
.define_method(
|
193
|
-
"
|
194
|
-
[](Tensor& self,
|
173
|
+
"_set_grad",
|
174
|
+
[](Tensor& self, Rice::Object value) {
|
175
|
+
if (value.is_nil()) {
|
176
|
+
self.mutable_grad().reset();
|
177
|
+
return;
|
178
|
+
}
|
179
|
+
|
180
|
+
const auto& grad = Rice::detail::From_Ruby<torch::Tensor>().convert(value.value());
|
181
|
+
|
182
|
+
// TODO support sparse grad
|
183
|
+
if (!grad.options().type_equal(self.options())) {
|
184
|
+
rb_raise(rb_eArgError, "assigned grad has data of a different type");
|
185
|
+
}
|
186
|
+
|
187
|
+
if (self.is_cuda() && grad.get_device() != self.get_device()) {
|
188
|
+
rb_raise(rb_eArgError, "assigned grad has data located on a different device");
|
189
|
+
}
|
190
|
+
|
191
|
+
if (!self.sizes().equals(grad.sizes())) {
|
192
|
+
rb_raise(rb_eArgError, "assigned grad has data of a different size");
|
193
|
+
}
|
194
|
+
|
195
195
|
self.mutable_grad() = grad;
|
196
196
|
})
|
197
197
|
.define_method(
|
@@ -281,7 +281,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
281
281
|
})
|
282
282
|
.define_method(
|
283
283
|
"_to",
|
284
|
-
[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
284
|
+
[](Tensor& self, torch::Device& device, int dtype, bool non_blocking, bool copy) {
|
285
285
|
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
286
286
|
});
|
287
287
|
|
data/ext/torch/torch.cpp
CHANGED
@@ -6,6 +6,23 @@
|
|
6
6
|
#include "templates.h"
|
7
7
|
#include "utils.h"
|
8
8
|
|
9
|
+
template<typename T>
|
10
|
+
torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
11
|
+
std::vector<T> vec;
|
12
|
+
for (long i = 0; i < a.size(); i++) {
|
13
|
+
vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value()));
|
14
|
+
}
|
15
|
+
|
16
|
+
// hack for requires_grad error
|
17
|
+
auto requires_grad = options.requires_grad();
|
18
|
+
torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
19
|
+
if (requires_grad) {
|
20
|
+
t.set_requires_grad(true);
|
21
|
+
}
|
22
|
+
|
23
|
+
return t.reshape(size);
|
24
|
+
}
|
25
|
+
|
9
26
|
void init_torch(Rice::Module& m) {
|
10
27
|
m.add_handler<torch::Error>(handle_error);
|
11
28
|
add_torch_functions(m);
|
@@ -61,35 +78,28 @@ void init_torch(Rice::Module& m) {
|
|
61
78
|
"_tensor",
|
62
79
|
[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
63
80
|
auto dtype = options.dtype();
|
64
|
-
torch::
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
81
|
+
if (dtype == torch::kByte) {
|
82
|
+
return make_tensor<uint8_t>(a, size, options);
|
83
|
+
} else if (dtype == torch::kChar) {
|
84
|
+
return make_tensor<int8_t>(a, size, options);
|
85
|
+
} else if (dtype == torch::kShort) {
|
86
|
+
return make_tensor<int16_t>(a, size, options);
|
87
|
+
} else if (dtype == torch::kInt) {
|
88
|
+
return make_tensor<int32_t>(a, size, options);
|
89
|
+
} else if (dtype == torch::kLong) {
|
90
|
+
return make_tensor<int64_t>(a, size, options);
|
91
|
+
} else if (dtype == torch::kFloat) {
|
92
|
+
return make_tensor<float>(a, size, options);
|
93
|
+
} else if (dtype == torch::kDouble) {
|
94
|
+
return make_tensor<double>(a, size, options);
|
95
|
+
} else if (dtype == torch::kBool) {
|
96
|
+
return make_tensor<uint8_t>(a, size, options);
|
97
|
+
} else if (dtype == torch::kComplexFloat) {
|
98
|
+
return make_tensor<c10::complex<float>>(a, size, options);
|
99
|
+
} else if (dtype == torch::kComplexDouble) {
|
100
|
+
return make_tensor<c10::complex<double>>(a, size, options);
|
80
101
|
} else {
|
81
|
-
std::
|
82
|
-
for (long i = 0; i < a.size(); i++) {
|
83
|
-
vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
|
84
|
-
}
|
85
|
-
// hack for requires_grad error
|
86
|
-
if (options.requires_grad()) {
|
87
|
-
t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
88
|
-
t.set_requires_grad(true);
|
89
|
-
} else {
|
90
|
-
t = torch::tensor(vec, options);
|
91
|
-
}
|
102
|
+
throw std::runtime_error("Unsupported type");
|
92
103
|
}
|
93
|
-
return t.reshape(size);
|
94
104
|
});
|
95
105
|
}
|
data/ext/torch/utils.h
CHANGED
@@ -1,8 +1,15 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
+
#include <torch/torch.h>
|
4
|
+
|
3
5
|
#include <rice/rice.hpp>
|
4
6
|
#include <rice/stl.hpp>
|
5
7
|
|
8
|
+
// Refuse to compile against any LibTorch other than 1.11.x.
static_assert(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 11, "Incompatible LibTorch version");
|
12
|
+
|
6
13
|
// TODO find better place
|
7
14
|
inline void handle_error(torch::Error const & ex) {
|
8
15
|
throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
data/lib/torch/nn/parameter_list.rb
CHANGED
@@ -0,0 +1,48 @@
|
|
1
|
+
module Torch
  module NN
    # Container module holding an ordered list of parameters. Each parameter
    # is registered on the module under its positional index as a string
    # ("0", "1", ...).
    class ParameterList < Module
      include Enumerable

      # parameters - optional enumerable of parameters to register up front.
      #              The body already handles nil, so it now defaults to nil
      #              and ParameterList.new works without arguments.
      def initialize(parameters = nil)
        super()
        @initialized = true
        unless parameters.nil?
          concat(parameters)
        end
      end

      # Number of registered parameters.
      def length
        @parameters.length
      end
      alias_method :count, :length
      alias_method :size, :length

      # Appends every element of an enumerable, continuing the index numbering
      # from the current length. Returns self.
      # Raises TypeError when +parameters+ is not Enumerable.
      def concat(parameters)
        unless parameters.is_a?(Enumerable)
          raise TypeError, "ParameterList#concat should be called with an enumerable, but got #{parameters.class.name}"
        end
        offset = length
        parameters.each_with_index do |param, i|
          register_parameter((offset + i).to_s, param)
        end
        self
      end

      # Yields each registered parameter in insertion order; returns an
      # Enumerator when no block is given.
      def each(&block)
        if block_given?
          @parameters.values.each(&block)
        else
          to_enum(:each)
        end
      end

      # Range   -> a new ParameterList containing the selected parameters
      # Integer -> the parameter at that position; negative indices count
      #            from the end, as with Array#[] (previously they silently
      #            returned nil because "-1" is never a registered key)
      # other   -> looked up by its string form
      def [](idx)
        if idx.is_a?(Range)
          self.class.new(@parameters.values[idx])
        else
          idx += length if idx.is_a?(Integer) && idx.negative?
          @parameters[idx.to_s]
        end
      end
    end
  end
end
|
data/lib/torch/tensor.rb
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -40,6 +40,7 @@ require "torch/nn/utils"
|
|
40
40
|
# nn containers
|
41
41
|
require "torch/nn/module"
|
42
42
|
require "torch/nn/module_list"
|
43
|
+
require "torch/nn/parameter_list"
|
43
44
|
require "torch/nn/sequential"
|
44
45
|
|
45
46
|
# nn convolution layers
|
@@ -267,7 +268,7 @@ module Torch
|
|
267
268
|
args.first.send(dtype).to(device)
|
268
269
|
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
269
270
|
bytes = args.first.bytes
|
270
|
-
Torch.
|
271
|
+
Torch._from_blob_ref(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
|
271
272
|
elsif args.size == 1 && args.first.is_a?(Array)
|
272
273
|
Torch.tensor(args.first, dtype: dtype, device: device)
|
273
274
|
elsif args.size == 0
|
@@ -320,12 +321,17 @@ module Torch
|
|
320
321
|
raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
|
321
322
|
options = tensor_options(device: "cpu", dtype: dtype[0])
|
322
323
|
# TODO pass pointer to array instead of creating string
|
323
|
-
|
324
|
-
|
324
|
+
_from_blob_ref(ndarray.to_string, ndarray.shape, options)
|
325
|
+
end
|
326
|
+
|
327
|
+
# private
|
328
|
+
# TODO use keepAlive in Rice (currently segfaults)
|
329
|
+
def _from_blob_ref(data, size, options)
  # _from_blob does not take ownership of the buffer, so the returned tensor
  # stores a reference to `data` to keep it reachable for as long as the
  # tensor lives. This can go away once a pointer is passed directly.
  _from_blob(data, size, options).tap do |blob_tensor|
    blob_tensor.instance_variable_set("@_numo_data", data)
  end
end
|
331
337
|
|
@@ -406,6 +412,12 @@ module Torch
|
|
406
412
|
end
|
407
413
|
end
|
408
414
|
|
415
|
+
# TODO check each dimensions for consistency in future
|
416
|
+
raise Error, "Inconsistent dimensions" if data.size != size.inject(1, :*)
|
417
|
+
|
418
|
+
# TODO move to C++
|
419
|
+
data = data.map { |v| v ? 1 : 0 } if options[:dtype] == :bool
|
420
|
+
|
409
421
|
_tensor(data, size, tensor_options(**options))
|
410
422
|
end
|
411
423
|
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.10.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-04-12 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -52,6 +52,7 @@ files:
|
|
52
52
|
- ext/torch/random.cpp
|
53
53
|
- ext/torch/ruby_arg_parser.cpp
|
54
54
|
- ext/torch/ruby_arg_parser.h
|
55
|
+
- ext/torch/sparse_functions.h
|
55
56
|
- ext/torch/special.cpp
|
56
57
|
- ext/torch/special_functions.h
|
57
58
|
- ext/torch/templates.h
|
@@ -149,6 +150,7 @@ files:
|
|
149
150
|
- lib/torch/nn/nll_loss.rb
|
150
151
|
- lib/torch/nn/pairwise_distance.rb
|
151
152
|
- lib/torch/nn/parameter.rb
|
153
|
+
- lib/torch/nn/parameter_list.rb
|
152
154
|
- lib/torch/nn/poisson_nll_loss.rb
|
153
155
|
- lib/torch/nn/prelu.rb
|
154
156
|
- lib/torch/nn/reflection_pad1d.rb
|
@@ -227,7 +229,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
227
229
|
- !ruby/object:Gem::Version
|
228
230
|
version: '0'
|
229
231
|
requirements: []
|
230
|
-
rubygems_version: 3.3.
|
232
|
+
rubygems_version: 3.3.7
|
231
233
|
signing_key:
|
232
234
|
specification_version: 4
|
233
235
|
summary: Deep learning for Ruby, powered by LibTorch
|