torch-rb 0.14.1 → 0.15.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/torch/extconf.rb CHANGED
@@ -52,6 +52,9 @@ $INCFLAGS += " -I#{inc}"
52
52
  $INCFLAGS += " -I#{inc}/torch/csrc/api/include"
53
53
 
54
54
  $LDFLAGS += " -Wl,-rpath,#{lib}"
55
+ if RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i && Dir.exist?("/opt/homebrew/opt/libomp/lib")
56
+ $LDFLAGS += ",-rpath,/opt/homebrew/opt/libomp/lib"
57
+ end
55
58
  $LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
56
59
 
57
60
  # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
@@ -169,27 +169,4 @@ namespace Rice::detail
169
169
  }
170
170
  }
171
171
  };
172
-
173
- template<typename T>
174
- struct Type<torch::optional<T>>
175
- {
176
- static bool verify()
177
- {
178
- return true;
179
- }
180
- };
181
-
182
- template<typename T>
183
- class From_Ruby<torch::optional<T>>
184
- {
185
- public:
186
- torch::optional<T> convert(VALUE x)
187
- {
188
- if (NIL_P(x)) {
189
- return torch::nullopt;
190
- } else {
191
- return torch::optional<T>{From_Ruby<T>().convert(x)};
192
- }
193
- }
194
- };
195
172
  }
data/ext/torch/tensor.cpp CHANGED
@@ -103,6 +103,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
103
103
 
104
104
  rb_cTensor
105
105
  .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
106
+ .define_method("mps?", [](Tensor& self) { return self.is_mps(); })
106
107
  .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
107
108
  .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
108
109
  .define_method("dim", [](Tensor& self) { return self.dim(); })
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 1,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 2,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -31,9 +31,9 @@ module Torch
31
31
  return if nonzero_finite_vals.numel == 0
32
32
 
33
33
  # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
34
- nonzero_finite_abs = nonzero_finite_vals.abs.double
35
- nonzero_finite_min = nonzero_finite_abs.min.double
36
- nonzero_finite_max = nonzero_finite_abs.max.double
34
+ nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs)
35
+ nonzero_finite_min = tensor_totype(nonzero_finite_abs.min)
36
+ nonzero_finite_max = tensor_totype(nonzero_finite_abs.max)
37
37
 
38
38
  nonzero_finite_vals.each do |value|
39
39
  if value.item != value.item.ceil
@@ -107,6 +107,11 @@ module Torch
107
107
  # Ruby throws error when negative, Python doesn't
108
108
  " " * [@max_width - ret.size, 0].max + ret
109
109
  end
110
+
111
+ def tensor_totype(t)
112
+ dtype = t.mps? ? :float : :double
113
+ t.to(dtype: dtype)
114
+ end
110
115
  end
111
116
 
112
117
  def inspect
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.14.1"
2
+ VERSION = "0.15.0"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.14.1
4
+ version: 0.15.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-12-27 00:00:00.000000000 Z
11
+ date: 2024-02-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice