torch-rb 0.9.2 → 0.10.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/README.md +17 -3
- data/codegen/function.rb +3 -3
- data/codegen/generate_functions.rb +27 -5
- data/codegen/native_functions.yaml +951 -362
- data/ext/torch/ruby_arg_parser.h +26 -6
- data/ext/torch/sparse_functions.h +6 -0
- data/ext/torch/utils.h +7 -0
- data/lib/torch/nn/parameter_list.rb +48 -0
- data/lib/torch/tensor.rb +8 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -0
- metadata +5 -3
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -88,18 +88,18 @@ struct RubyArgs {
|
|
88
88
|
inline c10::optional<at::Generator> generator(int i);
|
89
89
|
inline at::Storage storage(int i);
|
90
90
|
inline at::ScalarType scalartype(int i);
|
91
|
-
|
91
|
+
inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
|
92
92
|
inline c10::optional<at::ScalarType> scalartypeOptional(int i);
|
93
93
|
inline c10::optional<at::Scalar> scalarOptional(int i);
|
94
94
|
inline c10::optional<int64_t> toInt64Optional(int i);
|
95
95
|
inline c10::optional<bool> toBoolOptional(int i);
|
96
96
|
inline c10::optional<double> toDoubleOptional(int i);
|
97
97
|
inline c10::OptionalArray<double> doublelistOptional(int i);
|
98
|
-
|
99
|
-
|
98
|
+
inline at::Layout layout(int i);
|
99
|
+
inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
|
100
100
|
inline c10::optional<at::Layout> layoutOptional(int i);
|
101
101
|
inline at::Device device(int i);
|
102
|
-
|
102
|
+
inline at::Device deviceWithDefault(int i, const at::Device& default_device);
|
103
103
|
// inline c10::optional<at::Device> deviceOptional(int i);
|
104
104
|
// inline at::Dimname dimname(int i);
|
105
105
|
// inline std::vector<at::Dimname> dimnamelist(int i);
|
@@ -240,6 +240,11 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
240
240
|
return it->second;
|
241
241
|
}
|
242
242
|
|
243
|
+
// Reads argument i as a ScalarType; a nil Ruby value yields the caller-supplied
// default instead of going through scalartype(i).
inline at::ScalarType RubyArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
  return NIL_P(args[i]) ? default_scalartype : scalartype(i);
}
|
247
|
+
|
243
248
|
inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
|
244
249
|
if (NIL_P(args[i])) return c10::nullopt;
|
245
250
|
return scalartype(i);
|
@@ -284,8 +289,8 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
284
289
|
return res;
|
285
290
|
}
|
286
291
|
|
287
|
-
inline
|
288
|
-
if (NIL_P(args[i])) return
|
292
|
+
inline at::Layout RubyArgs::layout(int i) {
|
293
|
+
if (NIL_P(args[i])) return signature.params[i].default_layout;
|
289
294
|
|
290
295
|
static std::unordered_map<VALUE, Layout> layout_map = {
|
291
296
|
{ID2SYM(rb_intern("strided")), Layout::Strided},
|
@@ -298,6 +303,16 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
298
303
|
return it->second;
|
299
304
|
}
|
300
305
|
|
306
|
+
// Reads argument i as a Layout; a nil Ruby value yields the caller-supplied
// default instead of going through layout(i).
inline at::Layout RubyArgs::layoutWithDefault(int i, at::Layout default_layout) {
  return NIL_P(args[i]) ? default_layout : layout(i);
}
|
310
|
+
|
311
|
+
// Reads argument i as an optional Layout: a non-nil Ruby value is parsed via
// layout(i), while nil maps to an empty optional (not to layout(i)'s default).
inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
  if (!NIL_P(args[i])) {
    return layout(i);
  }
  return c10::nullopt;
}
|
315
|
+
|
301
316
|
inline at::Device RubyArgs::device(int i) {
|
302
317
|
if (NIL_P(args[i])) {
|
303
318
|
return at::Device("cpu");
|
@@ -306,6 +321,11 @@ inline at::Device RubyArgs::device(int i) {
|
|
306
321
|
return at::Device(device_str);
|
307
322
|
}
|
308
323
|
|
324
|
+
// Reads argument i as a Device; a nil Ruby value yields a copy of the
// caller-supplied default instead of device(i)'s fallback.
inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
  return NIL_P(args[i]) ? default_device : device(i);
}
|
328
|
+
|
309
329
|
inline at::MemoryFormat RubyArgs::memoryformat(int i) {
|
310
330
|
if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
|
311
331
|
throw std::runtime_error("memoryformat not supported yet");
|
data/ext/torch/utils.h
CHANGED
@@ -1,8 +1,15 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
+
#include <torch/torch.h>
|
4
|
+
|
3
5
|
#include <rice/rice.hpp>
|
4
6
|
#include <rice/stl.hpp>
|
5
7
|
|
8
|
+
// Fail compilation early unless the LibTorch headers report version 1.11.x.
// The message names the expected version so a mismatched install is easy to diagnose.
static_assert(
  TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 11,
  "Incompatible LibTorch version: this version of torch-rb requires LibTorch 1.11.x"
);
|
12
|
+
|
6
13
|
// TODO find better place
|
7
14
|
inline void handle_error(torch::Error const & ex) {
|
8
15
|
throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
data/lib/torch/nn/parameter_list.rb
@@ -0,0 +1,48 @@
|
|
1
|
+
module Torch
  module NN
    # Ordered container of parameters. Each parameter is registered on the
    # module through register_parameter, keyed by its position as a string
    # ("0", "1", ...).
    class ParameterList < Module
      include Enumerable

      # @param parameters [Enumerable, nil] parameters to append initially;
      #   nil (the default) creates an empty list
      def initialize(parameters = nil)
        super()
        # NOTE(review): flag presumably consumed by Module; kept as in the original
        @initialized = true
        concat(parameters) unless parameters.nil?
      end

      # @return [Integer] number of registered parameters
      def length
        @parameters.length
      end
      alias_method :size, :length

      # With no argument and no block, returns the number of parameters.
      # Otherwise defers to Enumerable#count, so count(item) and
      # count { |p| ... } keep their standard meaning instead of the block
      # being silently ignored.
      def count(*args, &block)
        return length if args.empty? && block.nil?

        super
      end

      # Appends every element of +parameters+, registering each one under the
      # next free index.
      #
      # @param parameters [Enumerable]
      # @return [ParameterList] self
      # @raise [TypeError] if +parameters+ is not Enumerable
      def concat(parameters)
        unless parameters.is_a?(Enumerable)
          raise TypeError, "ParameterList#concat should be called with an enumerable, but got #{parameters.class.name}"
        end

        offset = length
        parameters.each_with_index do |param, i|
          register_parameter((offset + i).to_s, param)
        end
        self
      end

      # Yields each parameter in registration order. Without a block, returns
      # a sized Enumerator.
      def each(&block)
        return enum_for(:each) { length } unless block_given?

        @parameters.values.each(&block)
      end

      # @param idx [Integer, Range] a position (negative values count from the
      #   end, as with Array#[]) or a range of positions
      # @return the object registered at +idx+ (presumably nil when absent),
      #   or a new ParameterList holding the selected parameters for a Range
      def [](idx)
        return self.class.new(@parameters.values[idx]) if idx.is_a?(Range)

        # Map negative positions onto the string keys used by register_parameter.
        idx += length if idx.is_a?(Integer) && idx.negative?
        @parameters[idx.to_s]
      end
    end
  end
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -199,5 +199,13 @@ module Torch
|
|
199
199
|
def real
|
200
200
|
Torch.real(self)
|
201
201
|
end
|
202
|
+
|
203
|
+
# Numeric coercion hook: lets a Numeric stand on the left of a binary operator
# by wrapping it with Torch.tensor; any other operand type raises TypeError.
def coerce(other)
  raise TypeError, "#{self.class} can't be coerced into #{other.class}" unless other.is_a?(Numeric)

  [Torch.tensor(other), self]
end
|
202
210
|
end
|
203
211
|
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.9.2
|
4
|
+
version: 0.10.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-06-14 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -52,6 +52,7 @@ files:
|
|
52
52
|
- ext/torch/random.cpp
|
53
53
|
- ext/torch/ruby_arg_parser.cpp
|
54
54
|
- ext/torch/ruby_arg_parser.h
|
55
|
+
- ext/torch/sparse_functions.h
|
55
56
|
- ext/torch/special.cpp
|
56
57
|
- ext/torch/special_functions.h
|
57
58
|
- ext/torch/templates.h
|
@@ -149,6 +150,7 @@ files:
|
|
149
150
|
- lib/torch/nn/nll_loss.rb
|
150
151
|
- lib/torch/nn/pairwise_distance.rb
|
151
152
|
- lib/torch/nn/parameter.rb
|
153
|
+
- lib/torch/nn/parameter_list.rb
|
152
154
|
- lib/torch/nn/poisson_nll_loss.rb
|
153
155
|
- lib/torch/nn/prelu.rb
|
154
156
|
- lib/torch/nn/reflection_pad1d.rb
|
@@ -227,7 +229,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
227
229
|
- !ruby/object:Gem::Version
|
228
230
|
version: '0'
|
229
231
|
requirements: []
|
230
|
-
rubygems_version: 3.3.
|
232
|
+
rubygems_version: 3.3.7
|
231
233
|
signing_key:
|
232
234
|
specification_version: 4
|
233
235
|
summary: Deep learning for Ruby, powered by LibTorch
|