torch-rb 0.10.0 → 0.10.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 7884d0bd8ffd23e775b3a1ca59e536bc18cd79c6bbe2736a9802a2b1693d40de
4
- data.tar.gz: 10849c8a248a70eca7178ca3529fa7428fc3d8e5cfe2ca774b0be6cb5e75eb9b
3
+ metadata.gz: f4665eec43d85fbf02ce75f4b268dbf001bfad7e3ae1ecace0e9911b651e2cc2
4
+ data.tar.gz: d11ee1386ce7feeea68333de6c361d8737a7164cfd1626abce3a511deecb2963
5
5
  SHA512:
6
- metadata.gz: d81587eae00527e9d1e4a65b62a686dbdab27eed9401539faddf553d3a0730ad7394b01bac615bdd0474d1a6602cd0a3117e1d7cb3a3f1d5fbf8bd989fa39e59
7
- data.tar.gz: d581d07821f103ee69267bc8e6b15d5094d8b8ef004917af947e2570649e7fbe750c9893d40600b62ee970646654fc545a270e7a59adb764be2ba73f1fa18b62
6
+ metadata.gz: cf346bc03f36d4fc920151b0554c93c33a59d0fecd35c6f110dc862ad8b35c6b8641d306124505ddd012ce9d903363672e99772cc2bd7b981d962c9d00f08d3e
7
+ data.tar.gz: 265157846417fdc3c024e0f50d0b0a663d345ca423ad9233a7d0e722bf5134a8e15e1afce30248992155787484e80311e484c64c5787945b5f5a88625115479a
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.10.1 (2022-04-12)
2
+
3
+ - Fixed `dtype`, `device`, and `layout` for `new_*` and `like_*` methods
4
+
1
5
  ## 0.10.0 (2022-03-13)
2
6
 
3
7
  - Updated LibTorch to 1.11.0
@@ -293,7 +293,13 @@ def split_opt_params(params)
293
293
  end
294
294
 
295
295
  def generate_tensor_options(function, opt_params)
296
- code = "\n const auto options = TensorOptions()"
296
+ new_function = function.base_name.start_with?("new_")
297
+ like_function = function.base_name.end_with?("_like")
298
+
299
+ code = String.new("")
300
+ code << "\n auto self = _r.tensor(0);" if like_function
301
+ code << "\n const auto options = TensorOptions()"
302
+
297
303
  order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
298
304
  opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt|
299
305
  i = opt[:position]
@@ -304,12 +310,24 @@ def generate_tensor_options(function, opt_params)
304
310
  if function.base_name == "arange"
305
311
  "dtype(_r.scalartypeOptional(#{i}))"
306
312
  else
307
- "dtype(_r.scalartype(#{i}))"
313
+ if new_function || like_function
314
+ "dtype(_r.scalartypeWithDefault(#{i}, self.scalar_type()))"
315
+ else
316
+ "dtype(_r.scalartype(#{i}))"
317
+ end
308
318
  end
309
319
  when "device"
310
- "device(_r.device(#{i}))"
320
+ if new_function || like_function
321
+ "device(_r.deviceWithDefault(#{i}, self.device()))"
322
+ else
323
+ "device(_r.device(#{i}))"
324
+ end
311
325
  when "layout"
312
- "layout(_r.layoutOptional(#{i}))"
326
+ if new_function || like_function
327
+ "layout(_r.layoutWithDefault(#{i}, self.layout()))"
328
+ else
329
+ "layout(_r.layoutOptional(#{i}))"
330
+ end
313
331
  when "requires_grad"
314
332
  "requires_grad(_r.toBool(#{i}))"
315
333
  when "pin_memory"
@@ -88,18 +88,18 @@ struct RubyArgs {
88
88
  inline c10::optional<at::Generator> generator(int i);
89
89
  inline at::Storage storage(int i);
90
90
  inline at::ScalarType scalartype(int i);
91
- // inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
91
+ inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
92
92
  inline c10::optional<at::ScalarType> scalartypeOptional(int i);
93
93
  inline c10::optional<at::Scalar> scalarOptional(int i);
94
94
  inline c10::optional<int64_t> toInt64Optional(int i);
95
95
  inline c10::optional<bool> toBoolOptional(int i);
96
96
  inline c10::optional<double> toDoubleOptional(int i);
97
97
  inline c10::OptionalArray<double> doublelistOptional(int i);
98
- // inline at::Layout layout(int i);
99
- // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
98
+ inline at::Layout layout(int i);
99
+ inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
100
100
  inline c10::optional<at::Layout> layoutOptional(int i);
101
101
  inline at::Device device(int i);
102
- // inline at::Device deviceWithDefault(int i, const at::Device& default_device);
102
+ inline at::Device deviceWithDefault(int i, const at::Device& default_device);
103
103
  // inline c10::optional<at::Device> deviceOptional(int i);
104
104
  // inline at::Dimname dimname(int i);
105
105
  // inline std::vector<at::Dimname> dimnamelist(int i);
@@ -240,6 +240,11 @@ inline ScalarType RubyArgs::scalartype(int i) {
240
240
  return it->second;
241
241
  }
242
242
 
243
+ inline at::ScalarType RubyArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
244
+ if (NIL_P(args[i])) return default_scalartype;
245
+ return scalartype(i);
246
+ }
247
+
243
248
  inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
244
249
  if (NIL_P(args[i])) return c10::nullopt;
245
250
  return scalartype(i);
@@ -284,8 +289,8 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
284
289
  return res;
285
290
  }
286
291
 
287
- inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
288
- if (NIL_P(args[i])) return c10::nullopt;
292
+ inline at::Layout RubyArgs::layout(int i) {
293
+ if (NIL_P(args[i])) return signature.params[i].default_layout;
289
294
 
290
295
  static std::unordered_map<VALUE, Layout> layout_map = {
291
296
  {ID2SYM(rb_intern("strided")), Layout::Strided},
@@ -298,6 +303,16 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
298
303
  return it->second;
299
304
  }
300
305
 
306
+ inline at::Layout RubyArgs::layoutWithDefault(int i, at::Layout default_layout) {
307
+ if (NIL_P(args[i])) return default_layout;
308
+ return layout(i);
309
+ }
310
+
311
+ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
312
+ if (NIL_P(args[i])) return c10::nullopt;
313
+ return layout(i);
314
+ }
315
+
301
316
  inline at::Device RubyArgs::device(int i) {
302
317
  if (NIL_P(args[i])) {
303
318
  return at::Device("cpu");
@@ -306,6 +321,11 @@ inline at::Device RubyArgs::device(int i) {
306
321
  return at::Device(device_str);
307
322
  }
308
323
 
324
+ inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
325
+ if (NIL_P(args[i])) return default_device;
326
+ return device(i);
327
+ }
328
+
309
329
  inline at::MemoryFormat RubyArgs::memoryformat(int i) {
310
330
  if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
311
331
  throw std::runtime_error("memoryformat not supported yet");
data/ext/torch/utils.h CHANGED
@@ -1,8 +1,15 @@
1
1
  #pragma once
2
2
 
3
+ #include <torch/torch.h>
4
+
3
5
  #include <rice/rice.hpp>
4
6
  #include <rice/stl.hpp>
5
7
 
8
+ static_assert(
9
+ TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 11,
10
+ "Incompatible LibTorch version"
11
+ );
12
+
6
13
  // TODO find better place
7
14
  inline void handle_error(torch::Error const & ex) {
8
15
  throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.10.0"
2
+ VERSION = "0.10.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.10.0
4
+ version: 0.10.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-03-13 00:00:00.000000000 Z
11
+ date: 2022-04-12 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice