torch-rb 0.10.0 → 0.10.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/codegen/generate_functions.rb +22 -4
- data/ext/torch/ruby_arg_parser.h +26 -6
- data/ext/torch/utils.h +7 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: f4665eec43d85fbf02ce75f4b268dbf001bfad7e3ae1ecace0e9911b651e2cc2
|
4
|
+
data.tar.gz: d11ee1386ce7feeea68333de6c361d8737a7164cfd1626abce3a511deecb2963
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: cf346bc03f36d4fc920151b0554c93c33a59d0fecd35c6f110dc862ad8b35c6b8641d306124505ddd012ce9d903363672e99772cc2bd7b981d962c9d00f08d3e
|
7
|
+
data.tar.gz: 265157846417fdc3c024e0f50d0b0a663d345ca423ad9233a7d0e722bf5134a8e15e1afce30248992155787484e80311e484c64c5787945b5f5a88625115479a
|
data/CHANGELOG.md
CHANGED
data/codegen/generate_functions.rb
CHANGED
@@ -293,7 +293,13 @@ def split_opt_params(params)
|
|
293
293
|
end
|
294
294
|
|
295
295
|
def generate_tensor_options(function, opt_params)
|
296
|
-
|
296
|
+
new_function = function.base_name.start_with?("new_")
|
297
|
+
like_function = function.base_name.end_with?("_like")
|
298
|
+
|
299
|
+
code = String.new("")
|
300
|
+
code << "\n auto self = _r.tensor(0);" if like_function
|
301
|
+
code << "\n const auto options = TensorOptions()"
|
302
|
+
|
297
303
|
order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
|
298
304
|
opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt|
|
299
305
|
i = opt[:position]
|
@@ -304,12 +310,24 @@ def generate_tensor_options(function, opt_params)
|
|
304
310
|
if function.base_name == "arange"
|
305
311
|
"dtype(_r.scalartypeOptional(#{i}))"
|
306
312
|
else
|
307
|
-
|
313
|
+
if new_function || like_function
|
314
|
+
"dtype(_r.scalartypeWithDefault(#{i}, self.scalar_type()))"
|
315
|
+
else
|
316
|
+
"dtype(_r.scalartype(#{i}))"
|
317
|
+
end
|
308
318
|
end
|
309
319
|
when "device"
|
310
|
-
|
320
|
+
if new_function || like_function
|
321
|
+
"device(_r.deviceWithDefault(#{i}, self.device()))"
|
322
|
+
else
|
323
|
+
"device(_r.device(#{i}))"
|
324
|
+
end
|
311
325
|
when "layout"
|
312
|
-
|
326
|
+
if new_function || like_function
|
327
|
+
"layout(_r.layoutWithDefault(#{i}, self.layout()))"
|
328
|
+
else
|
329
|
+
"layout(_r.layoutOptional(#{i}))"
|
330
|
+
end
|
313
331
|
when "requires_grad"
|
314
332
|
"requires_grad(_r.toBool(#{i}))"
|
315
333
|
when "pin_memory"
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -88,18 +88,18 @@ struct RubyArgs {
|
|
88
88
|
inline c10::optional<at::Generator> generator(int i);
|
89
89
|
inline at::Storage storage(int i);
|
90
90
|
inline at::ScalarType scalartype(int i);
|
91
|
-
|
91
|
+
inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
|
92
92
|
inline c10::optional<at::ScalarType> scalartypeOptional(int i);
|
93
93
|
inline c10::optional<at::Scalar> scalarOptional(int i);
|
94
94
|
inline c10::optional<int64_t> toInt64Optional(int i);
|
95
95
|
inline c10::optional<bool> toBoolOptional(int i);
|
96
96
|
inline c10::optional<double> toDoubleOptional(int i);
|
97
97
|
inline c10::OptionalArray<double> doublelistOptional(int i);
|
98
|
-
|
99
|
-
|
98
|
+
inline at::Layout layout(int i);
|
99
|
+
inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
|
100
100
|
inline c10::optional<at::Layout> layoutOptional(int i);
|
101
101
|
inline at::Device device(int i);
|
102
|
-
|
102
|
+
inline at::Device deviceWithDefault(int i, const at::Device& default_device);
|
103
103
|
// inline c10::optional<at::Device> deviceOptional(int i);
|
104
104
|
// inline at::Dimname dimname(int i);
|
105
105
|
// inline std::vector<at::Dimname> dimnamelist(int i);
|
@@ -240,6 +240,11 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
240
240
|
return it->second;
|
241
241
|
}
|
242
242
|
|
243
|
+
inline at::ScalarType RubyArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
|
244
|
+
if (NIL_P(args[i])) return default_scalartype;
|
245
|
+
return scalartype(i);
|
246
|
+
}
|
247
|
+
|
243
248
|
inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
|
244
249
|
if (NIL_P(args[i])) return c10::nullopt;
|
245
250
|
return scalartype(i);
|
@@ -284,8 +289,8 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
284
289
|
return res;
|
285
290
|
}
|
286
291
|
|
287
|
-
inline
|
288
|
-
if (NIL_P(args[i])) return
|
292
|
+
inline at::Layout RubyArgs::layout(int i) {
|
293
|
+
if (NIL_P(args[i])) return signature.params[i].default_layout;
|
289
294
|
|
290
295
|
static std::unordered_map<VALUE, Layout> layout_map = {
|
291
296
|
{ID2SYM(rb_intern("strided")), Layout::Strided},
|
@@ -298,6 +303,16 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
298
303
|
return it->second;
|
299
304
|
}
|
300
305
|
|
306
|
+
inline at::Layout RubyArgs::layoutWithDefault(int i, at::Layout default_layout) {
|
307
|
+
if (NIL_P(args[i])) return default_layout;
|
308
|
+
return layout(i);
|
309
|
+
}
|
310
|
+
|
311
|
+
inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
312
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
313
|
+
return layout(i);
|
314
|
+
}
|
315
|
+
|
301
316
|
inline at::Device RubyArgs::device(int i) {
|
302
317
|
if (NIL_P(args[i])) {
|
303
318
|
return at::Device("cpu");
|
@@ -306,6 +321,11 @@ inline at::Device RubyArgs::device(int i) {
|
|
306
321
|
return at::Device(device_str);
|
307
322
|
}
|
308
323
|
|
324
|
+
inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
|
325
|
+
if (NIL_P(args[i])) return default_device;
|
326
|
+
return device(i);
|
327
|
+
}
|
328
|
+
|
309
329
|
inline at::MemoryFormat RubyArgs::memoryformat(int i) {
|
310
330
|
if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
|
311
331
|
throw std::runtime_error("memoryformat not supported yet");
|
data/ext/torch/utils.h
CHANGED
@@ -1,8 +1,15 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
+
#include <torch/torch.h>
|
4
|
+
|
3
5
|
#include <rice/rice.hpp>
|
4
6
|
#include <rice/stl.hpp>
|
5
7
|
|
8
|
+
static_assert(
|
9
|
+
TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 11,
|
10
|
+
"Incompatible LibTorch version"
|
11
|
+
);
|
12
|
+
|
6
13
|
// TODO find better place
|
7
14
|
inline void handle_error(torch::Error const & ex) {
|
8
15
|
throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.10.0
|
4
|
+
version: 0.10.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-04-12 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|