torch-rb 0.14.0 → 0.15.0

Sign up to get free protection for your applications and to get access to all the features.
data/ext/torch/extconf.rb CHANGED
@@ -52,6 +52,9 @@ $INCFLAGS += " -I#{inc}"
52
52
  $INCFLAGS += " -I#{inc}/torch/csrc/api/include"
53
53
 
54
54
  $LDFLAGS += " -Wl,-rpath,#{lib}"
55
+ if RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i && Dir.exist?("/opt/homebrew/opt/libomp/lib")
56
+ $LDFLAGS += ",-rpath,/opt/homebrew/opt/libomp/lib"
57
+ end
55
58
  $LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
56
59
 
57
60
  # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
@@ -401,7 +401,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
401
401
  if (str != "None") {
402
402
  throw std::runtime_error("default value for Tensor must be none, got: " + str);
403
403
  }
404
- } else if (type_ == ParameterType::INT64) {
404
+ } else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
405
405
  default_int = atol(str.c_str());
406
406
  } else if (type_ == ParameterType::BOOL) {
407
407
  default_bool = (str == "True" || str == "true");
@@ -417,7 +417,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
417
417
  default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) :
418
418
  at::Scalar(atof(str.c_str()));
419
419
  }
420
- } else if (type_ == ParameterType::INT_LIST) {
420
+ } else if (type_ == ParameterType::INT_LIST || type_ == ParameterType::SYM_INT_LIST) {
421
421
  if (str != "None") {
422
422
  default_intlist = parse_intlist_args(str, size);
423
423
  }
@@ -452,6 +452,31 @@ void FunctionParameter::set_default_str(const std::string& str) {
452
452
  default_string = parse_string_literal(str);
453
453
  }
454
454
  }
455
+ // These types weren't handled here before. Adding a default error
456
+ // led to a lot of test failures so adding this skip for now.
457
+ // We should correctly handle these though because it might be causing
458
+ // silent failures.
459
+ else if (type_ == ParameterType::TENSOR_LIST) {
460
+ // throw std::runtime_error("Invalid Tensor List");
461
+ } else if (type_ == ParameterType::GENERATOR) {
462
+ // throw std::runtime_error("ParameterType::GENERATOR");
463
+ } else if (type_ == ParameterType::PYOBJECT) {
464
+ // throw std::runtime_error("ParameterType::PYOBJECT");
465
+ } else if (type_ == ParameterType::MEMORY_FORMAT) {
466
+ // throw std::runtime_error("ParameterType::MEMORY_FORMAT");
467
+ } else if (type_ == ParameterType::DIMNAME) {
468
+ // throw std::runtime_error("ParameterType::DIMNAME");
469
+ } else if (type_ == ParameterType::DIMNAME_LIST) {
470
+ // throw std::runtime_error("ParameterType::DIMNAME_LIST");
471
+ } else if (type_ == ParameterType::SCALAR_LIST) {
472
+ // throw std::runtime_error("ParameterType::SCALAR_LIST");
473
+ } else if (type_ == ParameterType::STORAGE) {
474
+ // throw std::runtime_error("ParameterType::STORAGE");
475
+ } else if (type_ == ParameterType::QSCHEME) {
476
+ // throw std::runtime_error("ParameterType::QSCHEME");
477
+ } else {
478
+ throw std::runtime_error("unknown parameter type");
479
+ }
455
480
  }
456
481
 
457
482
  FunctionSignature::FunctionSignature(const std::string& fmt, int index)
@@ -169,27 +169,4 @@ namespace Rice::detail
169
169
  }
170
170
  }
171
171
  };
172
-
173
- template<typename T>
174
- struct Type<torch::optional<T>>
175
- {
176
- static bool verify()
177
- {
178
- return true;
179
- }
180
- };
181
-
182
- template<typename T>
183
- class From_Ruby<torch::optional<T>>
184
- {
185
- public:
186
- torch::optional<T> convert(VALUE x)
187
- {
188
- if (NIL_P(x)) {
189
- return torch::nullopt;
190
- } else {
191
- return torch::optional<T>{From_Ruby<T>().convert(x)};
192
- }
193
- }
194
- };
195
172
  }
data/ext/torch/tensor.cpp CHANGED
@@ -103,6 +103,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
103
103
 
104
104
  rb_cTensor
105
105
  .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
106
+ .define_method("mps?", [](Tensor& self) { return self.is_mps(); })
106
107
  .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
107
108
  .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
108
109
  .define_method("dim", [](Tensor& self) { return self.dim(); })
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 1,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 2,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -31,9 +31,9 @@ module Torch
31
31
  return if nonzero_finite_vals.numel == 0
32
32
 
33
33
  # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
34
- nonzero_finite_abs = nonzero_finite_vals.abs.double
35
- nonzero_finite_min = nonzero_finite_abs.min.double
36
- nonzero_finite_max = nonzero_finite_abs.max.double
34
+ nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs)
35
+ nonzero_finite_min = tensor_totype(nonzero_finite_abs.min)
36
+ nonzero_finite_max = tensor_totype(nonzero_finite_abs.max)
37
37
 
38
38
  nonzero_finite_vals.each do |value|
39
39
  if value.item != value.item.ceil
@@ -107,6 +107,11 @@ module Torch
107
107
  # Ruby throws error when negative, Python doesn't
108
108
  " " * [@max_width - ret.size, 0].max + ret
109
109
  end
110
+
111
# Casts tensor +t+ to a floating-point dtype for the formatter's
# min/max/ratio arithmetic.
#
# MPS tensors are cast to :float (float32); every other tensor is cast to
# :double (float64). NOTE(review): the MPS special case is presumably there
# because the MPS backend lacks float64 support — confirm against LibTorch.
#
# @param t [Torch::Tensor] tensor to convert (must respond to #mps? and #to)
# @return [Torch::Tensor] the converted tensor
def tensor_totype(t)
  if t.mps?
    t.to(dtype: :float)
  else
    t.to(dtype: :double)
  end
end
110
115
  end
111
116
 
112
117
  def inspect
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.14.0"
2
+ VERSION = "0.15.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.14.0
4
+ version: 0.15.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-11-09 00:00:00.000000000 Z
11
+ date: 2024-02-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -237,7 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
237
237
  - !ruby/object:Gem::Version
238
238
  version: '0'
239
239
  requirements: []
240
- rubygems_version: 3.4.10
240
+ rubygems_version: 3.5.3
241
241
  signing_key:
242
242
  specification_version: 4
243
243
  summary: Deep learning for Ruby, powered by LibTorch