torch-rb 0.10.0 → 0.11.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/README.md +16 -3
- data/codegen/function.rb +1 -1
- data/codegen/generate_functions.rb +46 -11
- data/codegen/native_functions.yaml +1362 -199
- data/ext/torch/extconf.rb +1 -13
- data/ext/torch/ruby_arg_parser.h +37 -9
- data/ext/torch/utils.h +7 -0
- data/lib/torch/tensor.rb +8 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -0
- metadata +3 -3
data/ext/torch/extconf.rb
CHANGED
@@ -7,21 +7,9 @@ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
|
7
7
|
|
8
8
|
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
9
9
|
|
10
|
-
# check omp first
|
11
|
-
if have_library("omp") || have_library("gomp")
|
12
|
-
$CXXFLAGS += " -Xclang" if apple_clang
|
13
|
-
$CXXFLAGS += " -fopenmp"
|
14
|
-
end
|
15
|
-
|
16
10
|
if apple_clang
|
17
|
-
# silence rice warnings
|
18
|
-
$CXXFLAGS += " -Wno-deprecated-declarations"
|
19
|
-
|
20
|
-
# silence ruby/intern.h warning
|
21
|
-
$CXXFLAGS += " -Wno-deprecated-register"
|
22
|
-
|
23
11
|
# silence torch warnings
|
24
|
-
$CXXFLAGS += " -Wno-
|
12
|
+
$CXXFLAGS += " -Wno-deprecated-declarations"
|
25
13
|
else
|
26
14
|
# silence rice warnings
|
27
15
|
$CXXFLAGS += " -Wno-noexcept-type"
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -83,23 +83,23 @@ struct RubyArgs {
|
|
83
83
|
template<int N>
|
84
84
|
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
85
85
|
inline std::vector<int64_t> intlist(int i);
|
86
|
-
|
87
|
-
|
86
|
+
inline c10::OptionalArray<int64_t> intlistOptional(int i);
|
87
|
+
inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
|
88
88
|
inline c10::optional<at::Generator> generator(int i);
|
89
89
|
inline at::Storage storage(int i);
|
90
90
|
inline at::ScalarType scalartype(int i);
|
91
|
-
|
91
|
+
inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
|
92
92
|
inline c10::optional<at::ScalarType> scalartypeOptional(int i);
|
93
93
|
inline c10::optional<at::Scalar> scalarOptional(int i);
|
94
94
|
inline c10::optional<int64_t> toInt64Optional(int i);
|
95
95
|
inline c10::optional<bool> toBoolOptional(int i);
|
96
96
|
inline c10::optional<double> toDoubleOptional(int i);
|
97
97
|
inline c10::OptionalArray<double> doublelistOptional(int i);
|
98
|
-
|
99
|
-
|
98
|
+
inline at::Layout layout(int i);
|
99
|
+
inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
|
100
100
|
inline c10::optional<at::Layout> layoutOptional(int i);
|
101
101
|
inline at::Device device(int i);
|
102
|
-
|
102
|
+
inline at::Device deviceWithDefault(int i, const at::Device& default_device);
|
103
103
|
// inline c10::optional<at::Device> deviceOptional(int i);
|
104
104
|
// inline at::Dimname dimname(int i);
|
105
105
|
// inline std::vector<at::Dimname> dimnamelist(int i);
|
@@ -166,8 +166,11 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
166
166
|
}
|
167
167
|
|
168
168
|
inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
169
|
-
|
169
|
+
return intlistWithDefault(i, signature.params[i].default_intlist);
|
170
|
+
}
|
170
171
|
|
172
|
+
inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
|
173
|
+
if (NIL_P(args[i])) return default_intlist;
|
171
174
|
VALUE arg = args[i];
|
172
175
|
auto size = signature.params[i].size;
|
173
176
|
if (size > 0 && FIXNUM_P(arg)) {
|
@@ -189,6 +192,11 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
|
189
192
|
return res;
|
190
193
|
}
|
191
194
|
|
195
|
+
inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
|
196
|
+
if (NIL_P(args[i])) return {};
|
197
|
+
return intlist(i);
|
198
|
+
}
|
199
|
+
|
192
200
|
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
|
193
201
|
if (NIL_P(args[i])) return c10::nullopt;
|
194
202
|
throw std::runtime_error("generator not supported yet");
|
@@ -240,6 +248,11 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
240
248
|
return it->second;
|
241
249
|
}
|
242
250
|
|
251
|
+
inline at::ScalarType RubyArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
|
252
|
+
if (NIL_P(args[i])) return default_scalartype;
|
253
|
+
return scalartype(i);
|
254
|
+
}
|
255
|
+
|
243
256
|
inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
|
244
257
|
if (NIL_P(args[i])) return c10::nullopt;
|
245
258
|
return scalartype(i);
|
@@ -284,8 +297,8 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
284
297
|
return res;
|
285
298
|
}
|
286
299
|
|
287
|
-
inline
|
288
|
-
if (NIL_P(args[i])) return
|
300
|
+
inline at::Layout RubyArgs::layout(int i) {
|
301
|
+
if (NIL_P(args[i])) return signature.params[i].default_layout;
|
289
302
|
|
290
303
|
static std::unordered_map<VALUE, Layout> layout_map = {
|
291
304
|
{ID2SYM(rb_intern("strided")), Layout::Strided},
|
@@ -298,6 +311,16 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
298
311
|
return it->second;
|
299
312
|
}
|
300
313
|
|
314
|
+
inline at::Layout RubyArgs::layoutWithDefault(int i, at::Layout default_layout) {
|
315
|
+
if (NIL_P(args[i])) return default_layout;
|
316
|
+
return layout(i);
|
317
|
+
}
|
318
|
+
|
319
|
+
inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
320
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
321
|
+
return layout(i);
|
322
|
+
}
|
323
|
+
|
301
324
|
inline at::Device RubyArgs::device(int i) {
|
302
325
|
if (NIL_P(args[i])) {
|
303
326
|
return at::Device("cpu");
|
@@ -306,6 +329,11 @@ inline at::Device RubyArgs::device(int i) {
|
|
306
329
|
return at::Device(device_str);
|
307
330
|
}
|
308
331
|
|
332
|
+
inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
|
333
|
+
if (NIL_P(args[i])) return default_device;
|
334
|
+
return device(i);
|
335
|
+
}
|
336
|
+
|
309
337
|
inline at::MemoryFormat RubyArgs::memoryformat(int i) {
|
310
338
|
if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
|
311
339
|
throw std::runtime_error("memoryformat not supported yet");
|
data/ext/torch/utils.h
CHANGED
@@ -1,8 +1,15 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
+
#include <torch/torch.h>
|
4
|
+
|
3
5
|
#include <rice/rice.hpp>
|
4
6
|
#include <rice/stl.hpp>
|
5
7
|
|
8
|
+
static_assert(
|
9
|
+
TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 12,
|
10
|
+
"Incompatible LibTorch version"
|
11
|
+
);
|
12
|
+
|
6
13
|
// TODO find better place
|
7
14
|
inline void handle_error(torch::Error const & ex) {
|
8
15
|
throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
data/lib/torch/tensor.rb
CHANGED
@@ -199,5 +199,13 @@ module Torch
|
|
199
199
|
def real
|
200
200
|
Torch.real(self)
|
201
201
|
end
|
202
|
+
|
203
|
+
def coerce(other)
|
204
|
+
if other.is_a?(Numeric)
|
205
|
+
[Torch.tensor(other), self]
|
206
|
+
else
|
207
|
+
raise TypeError, "#{self.class} can't be coerced into #{other.class}"
|
208
|
+
end
|
209
|
+
end
|
202
210
|
end
|
203
211
|
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.11.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-07-06 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -222,7 +222,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
222
222
|
requirements:
|
223
223
|
- - ">="
|
224
224
|
- !ruby/object:Gem::Version
|
225
|
-
version: '2.
|
225
|
+
version: '2.7'
|
226
226
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
227
227
|
requirements:
|
228
228
|
- - ">="
|