torch-rb 0.10.2 → 0.11.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -1
- data/codegen/generate_functions.rb +24 -7
- data/codegen/native_functions.yaml +1362 -199
- data/ext/torch/extconf.rb +1 -13
- data/ext/torch/ruby_arg_parser.h +11 -3
- data/ext/torch/utils.h +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -0
- metadata +3 -3
data/ext/torch/extconf.rb
CHANGED
@@ -7,21 +7,9 @@ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
|
7
7
|
|
8
8
|
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
9
9
|
|
10
|
-
# check omp first
|
11
|
-
if have_library("omp") || have_library("gomp")
|
12
|
-
$CXXFLAGS += " -Xclang" if apple_clang
|
13
|
-
$CXXFLAGS += " -fopenmp"
|
14
|
-
end
|
15
|
-
|
16
10
|
if apple_clang
|
17
|
-
# silence rice warnings
|
18
|
-
$CXXFLAGS += " -Wno-deprecated-declarations"
|
19
|
-
|
20
|
-
# silence ruby/intern.h warning
|
21
|
-
$CXXFLAGS += " -Wno-deprecated-register"
|
22
|
-
|
23
11
|
# silence torch warnings
|
24
|
-
$CXXFLAGS += " -Wno-
|
12
|
+
$CXXFLAGS += " -Wno-deprecated-declarations"
|
25
13
|
else
|
26
14
|
# silence rice warnings
|
27
15
|
$CXXFLAGS += " -Wno-noexcept-type"
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -83,8 +83,8 @@ struct RubyArgs {
|
|
83
83
|
template<int N>
|
84
84
|
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
85
85
|
inline std::vector<int64_t> intlist(int i);
|
86
|
-
|
87
|
-
|
86
|
+
inline c10::OptionalArray<int64_t> intlistOptional(int i);
|
87
|
+
inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
|
88
88
|
inline c10::optional<at::Generator> generator(int i);
|
89
89
|
inline at::Storage storage(int i);
|
90
90
|
inline at::ScalarType scalartype(int i);
|
@@ -166,8 +166,11 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
166
166
|
}
|
167
167
|
|
168
168
|
inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
169
|
-
|
169
|
+
return intlistWithDefault(i, signature.params[i].default_intlist);
|
170
|
+
}
|
170
171
|
|
172
|
+
inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
|
173
|
+
if (NIL_P(args[i])) return default_intlist;
|
171
174
|
VALUE arg = args[i];
|
172
175
|
auto size = signature.params[i].size;
|
173
176
|
if (size > 0 && FIXNUM_P(arg)) {
|
@@ -189,6 +192,11 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
|
189
192
|
return res;
|
190
193
|
}
|
191
194
|
|
195
|
+
inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
|
196
|
+
if (NIL_P(args[i])) return {};
|
197
|
+
return intlist(i);
|
198
|
+
}
|
199
|
+
|
192
200
|
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
|
193
201
|
if (NIL_P(args[i])) return c10::nullopt;
|
194
202
|
throw std::runtime_error("generator not supported yet");
|
data/ext/torch/utils.h
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.10.2
|
4
|
+
version: 0.11.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-06
|
11
|
+
date: 2022-07-06 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -222,7 +222,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
222
222
|
requirements:
|
223
223
|
- - ">="
|
224
224
|
- !ruby/object:Gem::Version
|
225
|
-
version: '2.
|
225
|
+
version: '2.7'
|
226
226
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
227
227
|
requirements:
|
228
228
|
- - ">="
|