torch-rb 0.10.0 → 0.10.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 7884d0bd8ffd23e775b3a1ca59e536bc18cd79c6bbe2736a9802a2b1693d40de
4
- data.tar.gz: 10849c8a248a70eca7178ca3529fa7428fc3d8e5cfe2ca774b0be6cb5e75eb9b
3
+ metadata.gz: f4665eec43d85fbf02ce75f4b268dbf001bfad7e3ae1ecace0e9911b651e2cc2
4
+ data.tar.gz: d11ee1386ce7feeea68333de6c361d8737a7164cfd1626abce3a511deecb2963
5
5
  SHA512:
6
- metadata.gz: d81587eae00527e9d1e4a65b62a686dbdab27eed9401539faddf553d3a0730ad7394b01bac615bdd0474d1a6602cd0a3117e1d7cb3a3f1d5fbf8bd989fa39e59
7
- data.tar.gz: d581d07821f103ee69267bc8e6b15d5094d8b8ef004917af947e2570649e7fbe750c9893d40600b62ee970646654fc545a270e7a59adb764be2ba73f1fa18b62
6
+ metadata.gz: cf346bc03f36d4fc920151b0554c93c33a59d0fecd35c6f110dc862ad8b35c6b8641d306124505ddd012ce9d903363672e99772cc2bd7b981d962c9d00f08d3e
7
+ data.tar.gz: 265157846417fdc3c024e0f50d0b0a663d345ca423ad9233a7d0e722bf5134a8e15e1afce30248992155787484e80311e484c64c5787945b5f5a88625115479a
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.10.1 (2022-04-12)
2
+
3
+ - Fixed `dtype`, `device`, and `layout` for `new_*` and `*_like` methods
4
+
1
5
  ## 0.10.0 (2022-03-13)
2
6
 
3
7
  - Updated LibTorch to 1.11.0
@@ -293,7 +293,13 @@ def split_opt_params(params)
293
293
  end
294
294
 
295
295
  def generate_tensor_options(function, opt_params)
296
- code = "\n const auto options = TensorOptions()"
296
+ new_function = function.base_name.start_with?("new_")
297
+ like_function = function.base_name.end_with?("_like")
298
+
299
+ code = String.new("")
300
+ code << "\n auto self = _r.tensor(0);" if like_function
301
+ code << "\n const auto options = TensorOptions()"
302
+
297
303
  order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
298
304
  opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt|
299
305
  i = opt[:position]
@@ -304,12 +310,24 @@ def generate_tensor_options(function, opt_params)
304
310
  if function.base_name == "arange"
305
311
  "dtype(_r.scalartypeOptional(#{i}))"
306
312
  else
307
- "dtype(_r.scalartype(#{i}))"
313
+ if new_function || like_function
314
+ "dtype(_r.scalartypeWithDefault(#{i}, self.scalar_type()))"
315
+ else
316
+ "dtype(_r.scalartype(#{i}))"
317
+ end
308
318
  end
309
319
  when "device"
310
- "device(_r.device(#{i}))"
320
+ if new_function || like_function
321
+ "device(_r.deviceWithDefault(#{i}, self.device()))"
322
+ else
323
+ "device(_r.device(#{i}))"
324
+ end
311
325
  when "layout"
312
- "layout(_r.layoutOptional(#{i}))"
326
+ if new_function || like_function
327
+ "layout(_r.layoutWithDefault(#{i}, self.layout()))"
328
+ else
329
+ "layout(_r.layoutOptional(#{i}))"
330
+ end
313
331
  when "requires_grad"
314
332
  "requires_grad(_r.toBool(#{i}))"
315
333
  when "pin_memory"
@@ -88,18 +88,18 @@ struct RubyArgs {
88
88
  inline c10::optional<at::Generator> generator(int i);
89
89
  inline at::Storage storage(int i);
90
90
  inline at::ScalarType scalartype(int i);
91
- // inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
91
+ inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
92
92
  inline c10::optional<at::ScalarType> scalartypeOptional(int i);
93
93
  inline c10::optional<at::Scalar> scalarOptional(int i);
94
94
  inline c10::optional<int64_t> toInt64Optional(int i);
95
95
  inline c10::optional<bool> toBoolOptional(int i);
96
96
  inline c10::optional<double> toDoubleOptional(int i);
97
97
  inline c10::OptionalArray<double> doublelistOptional(int i);
98
- // inline at::Layout layout(int i);
99
- // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
98
+ inline at::Layout layout(int i);
99
+ inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
100
100
  inline c10::optional<at::Layout> layoutOptional(int i);
101
101
  inline at::Device device(int i);
102
- // inline at::Device deviceWithDefault(int i, const at::Device& default_device);
102
+ inline at::Device deviceWithDefault(int i, const at::Device& default_device);
103
103
  // inline c10::optional<at::Device> deviceOptional(int i);
104
104
  // inline at::Dimname dimname(int i);
105
105
  // inline std::vector<at::Dimname> dimnamelist(int i);
@@ -240,6 +240,11 @@ inline ScalarType RubyArgs::scalartype(int i) {
240
240
  return it->second;
241
241
  }
242
242
 
243
+ inline at::ScalarType RubyArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
244
+ if (NIL_P(args[i])) return default_scalartype;
245
+ return scalartype(i);
246
+ }
247
+
243
248
  inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
244
249
  if (NIL_P(args[i])) return c10::nullopt;
245
250
  return scalartype(i);
@@ -284,8 +289,8 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
284
289
  return res;
285
290
  }
286
291
 
287
- inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
288
- if (NIL_P(args[i])) return c10::nullopt;
292
+ inline at::Layout RubyArgs::layout(int i) {
293
+ if (NIL_P(args[i])) return signature.params[i].default_layout;
289
294
 
290
295
  static std::unordered_map<VALUE, Layout> layout_map = {
291
296
  {ID2SYM(rb_intern("strided")), Layout::Strided},
@@ -298,6 +303,16 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
298
303
  return it->second;
299
304
  }
300
305
 
306
+ inline at::Layout RubyArgs::layoutWithDefault(int i, at::Layout default_layout) {
307
+ if (NIL_P(args[i])) return default_layout;
308
+ return layout(i);
309
+ }
310
+
311
+ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
312
+ if (NIL_P(args[i])) return c10::nullopt;
313
+ return layout(i);
314
+ }
315
+
301
316
  inline at::Device RubyArgs::device(int i) {
302
317
  if (NIL_P(args[i])) {
303
318
  return at::Device("cpu");
@@ -306,6 +321,11 @@ inline at::Device RubyArgs::device(int i) {
306
321
  return at::Device(device_str);
307
322
  }
308
323
 
324
+ inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
325
+ if (NIL_P(args[i])) return default_device;
326
+ return device(i);
327
+ }
328
+
309
329
  inline at::MemoryFormat RubyArgs::memoryformat(int i) {
310
330
  if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
311
331
  throw std::runtime_error("memoryformat not supported yet");
data/ext/torch/utils.h CHANGED
@@ -1,8 +1,15 @@
1
1
  #pragma once
2
2
 
3
+ #include <torch/torch.h>
4
+
3
5
  #include <rice/rice.hpp>
4
6
  #include <rice/stl.hpp>
5
7
 
8
+ static_assert(
9
+ TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 11,
10
+ "Incompatible LibTorch version"
11
+ );
12
+
6
13
  // TODO find better place
7
14
  inline void handle_error(torch::Error const & ex) {
8
15
  throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.10.0"
2
+ VERSION = "0.10.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.10.0
4
+ version: 0.10.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-03-13 00:00:00.000000000 Z
11
+ date: 2022-04-12 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice