torch-rb 0.10.0 → 0.10.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/codegen/generate_functions.rb +22 -4
- data/ext/torch/ruby_arg_parser.h +26 -6
- data/ext/torch/utils.h +7 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f4665eec43d85fbf02ce75f4b268dbf001bfad7e3ae1ecace0e9911b651e2cc2
|
4
|
+
data.tar.gz: d11ee1386ce7feeea68333de6c361d8737a7164cfd1626abce3a511deecb2963
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: cf346bc03f36d4fc920151b0554c93c33a59d0fecd35c6f110dc862ad8b35c6b8641d306124505ddd012ce9d903363672e99772cc2bd7b981d962c9d00f08d3e
|
7
|
+
data.tar.gz: 265157846417fdc3c024e0f50d0b0a663d345ca423ad9233a7d0e722bf5134a8e15e1afce30248992155787484e80311e484c64c5787945b5f5a88625115479a
|
data/CHANGELOG.md
CHANGED
@@ -293,7 +293,13 @@ def split_opt_params(params)
|
|
293
293
|
end
|
294
294
|
|
295
295
|
def generate_tensor_options(function, opt_params)
|
296
|
-
|
296
|
+
new_function = function.base_name.start_with?("new_")
|
297
|
+
like_function = function.base_name.end_with?("_like")
|
298
|
+
|
299
|
+
code = String.new("")
|
300
|
+
code << "\n auto self = _r.tensor(0);" if like_function
|
301
|
+
code << "\n const auto options = TensorOptions()"
|
302
|
+
|
297
303
|
order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
|
298
304
|
opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt|
|
299
305
|
i = opt[:position]
|
@@ -304,12 +310,24 @@ def generate_tensor_options(function, opt_params)
|
|
304
310
|
if function.base_name == "arange"
|
305
311
|
"dtype(_r.scalartypeOptional(#{i}))"
|
306
312
|
else
|
307
|
-
|
313
|
+
if new_function || like_function
|
314
|
+
"dtype(_r.scalartypeWithDefault(#{i}, self.scalar_type()))"
|
315
|
+
else
|
316
|
+
"dtype(_r.scalartype(#{i}))"
|
317
|
+
end
|
308
318
|
end
|
309
319
|
when "device"
|
310
|
-
|
320
|
+
if new_function || like_function
|
321
|
+
"device(_r.deviceWithDefault(#{i}, self.device()))"
|
322
|
+
else
|
323
|
+
"device(_r.device(#{i}))"
|
324
|
+
end
|
311
325
|
when "layout"
|
312
|
-
|
326
|
+
if new_function || like_function
|
327
|
+
"layout(_r.layoutWithDefault(#{i}, self.layout()))"
|
328
|
+
else
|
329
|
+
"layout(_r.layoutOptional(#{i}))"
|
330
|
+
end
|
313
331
|
when "requires_grad"
|
314
332
|
"requires_grad(_r.toBool(#{i}))"
|
315
333
|
when "pin_memory"
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -88,18 +88,18 @@ struct RubyArgs {
|
|
88
88
|
inline c10::optional<at::Generator> generator(int i);
|
89
89
|
inline at::Storage storage(int i);
|
90
90
|
inline at::ScalarType scalartype(int i);
|
91
|
-
|
91
|
+
inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
|
92
92
|
inline c10::optional<at::ScalarType> scalartypeOptional(int i);
|
93
93
|
inline c10::optional<at::Scalar> scalarOptional(int i);
|
94
94
|
inline c10::optional<int64_t> toInt64Optional(int i);
|
95
95
|
inline c10::optional<bool> toBoolOptional(int i);
|
96
96
|
inline c10::optional<double> toDoubleOptional(int i);
|
97
97
|
inline c10::OptionalArray<double> doublelistOptional(int i);
|
98
|
-
|
99
|
-
|
98
|
+
inline at::Layout layout(int i);
|
99
|
+
inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
|
100
100
|
inline c10::optional<at::Layout> layoutOptional(int i);
|
101
101
|
inline at::Device device(int i);
|
102
|
-
|
102
|
+
inline at::Device deviceWithDefault(int i, const at::Device& default_device);
|
103
103
|
// inline c10::optional<at::Device> deviceOptional(int i);
|
104
104
|
// inline at::Dimname dimname(int i);
|
105
105
|
// inline std::vector<at::Dimname> dimnamelist(int i);
|
@@ -240,6 +240,11 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
240
240
|
return it->second;
|
241
241
|
}
|
242
242
|
|
243
|
+
inline at::ScalarType RubyArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
|
244
|
+
if (NIL_P(args[i])) return default_scalartype;
|
245
|
+
return scalartype(i);
|
246
|
+
}
|
247
|
+
|
243
248
|
inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
|
244
249
|
if (NIL_P(args[i])) return c10::nullopt;
|
245
250
|
return scalartype(i);
|
@@ -284,8 +289,8 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
284
289
|
return res;
|
285
290
|
}
|
286
291
|
|
287
|
-
inline
|
288
|
-
if (NIL_P(args[i])) return
|
292
|
+
inline at::Layout RubyArgs::layout(int i) {
|
293
|
+
if (NIL_P(args[i])) return signature.params[i].default_layout;
|
289
294
|
|
290
295
|
static std::unordered_map<VALUE, Layout> layout_map = {
|
291
296
|
{ID2SYM(rb_intern("strided")), Layout::Strided},
|
@@ -298,6 +303,16 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
298
303
|
return it->second;
|
299
304
|
}
|
300
305
|
|
306
|
+
inline at::Layout RubyArgs::layoutWithDefault(int i, at::Layout default_layout) {
|
307
|
+
if (NIL_P(args[i])) return default_layout;
|
308
|
+
return layout(i);
|
309
|
+
}
|
310
|
+
|
311
|
+
inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
312
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
313
|
+
return layout(i);
|
314
|
+
}
|
315
|
+
|
301
316
|
inline at::Device RubyArgs::device(int i) {
|
302
317
|
if (NIL_P(args[i])) {
|
303
318
|
return at::Device("cpu");
|
@@ -306,6 +321,11 @@ inline at::Device RubyArgs::device(int i) {
|
|
306
321
|
return at::Device(device_str);
|
307
322
|
}
|
308
323
|
|
324
|
+
inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
|
325
|
+
if (NIL_P(args[i])) return default_device;
|
326
|
+
return device(i);
|
327
|
+
}
|
328
|
+
|
309
329
|
inline at::MemoryFormat RubyArgs::memoryformat(int i) {
|
310
330
|
if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
|
311
331
|
throw std::runtime_error("memoryformat not supported yet");
|
data/ext/torch/utils.h
CHANGED
@@ -1,8 +1,15 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
+
#include <torch/torch.h>
|
4
|
+
|
3
5
|
#include <rice/rice.hpp>
|
4
6
|
#include <rice/stl.hpp>
|
5
7
|
|
8
|
+
static_assert(
|
9
|
+
TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 11,
|
10
|
+
"Incompatible LibTorch version"
|
11
|
+
);
|
12
|
+
|
6
13
|
// TODO find better place
|
7
14
|
inline void handle_error(torch::Error const & ex) {
|
8
15
|
throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.10.
|
4
|
+
version: 0.10.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-04-12 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|