torch-rb 0.10.2 → 0.11.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +2 -1
- data/codegen/generate_functions.rb +24 -7
- data/codegen/native_functions.yaml +1362 -199
- data/ext/torch/extconf.rb +1 -13
- data/ext/torch/ruby_arg_parser.h +11 -3
- data/ext/torch/utils.h +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -0
- metadata +3 -3
data/ext/torch/extconf.rb
CHANGED
@@ -7,21 +7,9 @@ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
|
7
7
|
|
8
8
|
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
9
9
|
|
10
|
-
# check omp first
|
11
|
-
if have_library("omp") || have_library("gomp")
|
12
|
-
$CXXFLAGS += " -Xclang" if apple_clang
|
13
|
-
$CXXFLAGS += " -fopenmp"
|
14
|
-
end
|
15
|
-
|
16
10
|
if apple_clang
|
17
|
-
# silence rice warnings
|
18
|
-
$CXXFLAGS += " -Wno-deprecated-declarations"
|
19
|
-
|
20
|
-
# silence ruby/intern.h warning
|
21
|
-
$CXXFLAGS += " -Wno-deprecated-register"
|
22
|
-
|
23
11
|
# silence torch warnings
|
24
|
-
$CXXFLAGS += " -Wno-
|
12
|
+
$CXXFLAGS += " -Wno-deprecated-declarations"
|
25
13
|
else
|
26
14
|
# silence rice warnings
|
27
15
|
$CXXFLAGS += " -Wno-noexcept-type"
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -83,8 +83,8 @@ struct RubyArgs {
|
|
83
83
|
template<int N>
|
84
84
|
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
85
85
|
inline std::vector<int64_t> intlist(int i);
|
86
|
-
|
87
|
-
|
86
|
+
inline c10::OptionalArray<int64_t> intlistOptional(int i);
|
87
|
+
inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
|
88
88
|
inline c10::optional<at::Generator> generator(int i);
|
89
89
|
inline at::Storage storage(int i);
|
90
90
|
inline at::ScalarType scalartype(int i);
|
@@ -166,8 +166,11 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
166
166
|
}
|
167
167
|
|
168
168
|
inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
169
|
-
|
169
|
+
return intlistWithDefault(i, signature.params[i].default_intlist);
|
170
|
+
}
|
170
171
|
|
172
|
+
inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
|
173
|
+
if (NIL_P(args[i])) return default_intlist;
|
171
174
|
VALUE arg = args[i];
|
172
175
|
auto size = signature.params[i].size;
|
173
176
|
if (size > 0 && FIXNUM_P(arg)) {
|
@@ -189,6 +192,11 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
|
189
192
|
return res;
|
190
193
|
}
|
191
194
|
|
195
|
+
inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
|
196
|
+
if (NIL_P(args[i])) return {};
|
197
|
+
return intlist(i);
|
198
|
+
}
|
199
|
+
|
192
200
|
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
|
193
201
|
if (NIL_P(args[i])) return c10::nullopt;
|
194
202
|
throw std::runtime_error("generator not supported yet");
|
data/ext/torch/utils.h
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.11.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-06
|
11
|
+
date: 2022-07-06 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -222,7 +222,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
222
222
|
requirements:
|
223
223
|
- - ">="
|
224
224
|
- !ruby/object:Gem::Version
|
225
|
-
version: '2.
|
225
|
+
version: '2.7'
|
226
226
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
227
227
|
requirements:
|
228
228
|
- - ">="
|