torch-rb 0.14.1 → 0.15.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -6
- data/codegen/native_functions.yaml +357 -87
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.h +0 -23
- data/ext/torch/tensor.cpp +1 -0
- data/ext/torch/utils.h +1 -1
- data/lib/torch/inspector.rb +8 -3
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
data/ext/torch/extconf.rb
CHANGED
@@ -52,6 +52,9 @@ $INCFLAGS += " -I#{inc}"
|
|
52
52
|
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
|
53
53
|
|
54
54
|
$LDFLAGS += " -Wl,-rpath,#{lib}"
|
55
|
+
if RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i && Dir.exist?("/opt/homebrew/opt/libomp/lib")
|
56
|
+
$LDFLAGS += ",-rpath,/opt/homebrew/opt/libomp/lib"
|
57
|
+
end
|
55
58
|
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
|
56
59
|
|
57
60
|
# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
|
data/ext/torch/templates.h
CHANGED
@@ -169,27 +169,4 @@ namespace Rice::detail
|
|
169
169
|
}
|
170
170
|
}
|
171
171
|
};
|
172
|
-
|
173
|
-
template<typename T>
|
174
|
-
struct Type<torch::optional<T>>
|
175
|
-
{
|
176
|
-
static bool verify()
|
177
|
-
{
|
178
|
-
return true;
|
179
|
-
}
|
180
|
-
};
|
181
|
-
|
182
|
-
template<typename T>
|
183
|
-
class From_Ruby<torch::optional<T>>
|
184
|
-
{
|
185
|
-
public:
|
186
|
-
torch::optional<T> convert(VALUE x)
|
187
|
-
{
|
188
|
-
if (NIL_P(x)) {
|
189
|
-
return torch::nullopt;
|
190
|
-
} else {
|
191
|
-
return torch::optional<T>{From_Ruby<T>().convert(x)};
|
192
|
-
}
|
193
|
-
}
|
194
|
-
};
|
195
172
|
}
|
data/ext/torch/tensor.cpp
CHANGED
@@ -103,6 +103,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
103
103
|
|
104
104
|
rb_cTensor
|
105
105
|
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
|
106
|
+
.define_method("mps?", [](Tensor& self) { return self.is_mps(); })
|
106
107
|
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
|
107
108
|
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
|
108
109
|
.define_method("dim", [](Tensor& self) { return self.dim(); })
|
data/ext/torch/utils.h
CHANGED
data/lib/torch/inspector.rb
CHANGED
@@ -31,9 +31,9 @@ module Torch
|
|
31
31
|
return if nonzero_finite_vals.numel == 0
|
32
32
|
|
33
33
|
# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
|
34
|
-
nonzero_finite_abs = nonzero_finite_vals.abs
|
35
|
-
nonzero_finite_min = nonzero_finite_abs.min
|
36
|
-
nonzero_finite_max = nonzero_finite_abs.max
|
34
|
+
nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs)
|
35
|
+
nonzero_finite_min = tensor_totype(nonzero_finite_abs.min)
|
36
|
+
nonzero_finite_max = tensor_totype(nonzero_finite_abs.max)
|
37
37
|
|
38
38
|
nonzero_finite_vals.each do |value|
|
39
39
|
if value.item != value.item.ceil
|
@@ -107,6 +107,11 @@ module Torch
|
|
107
107
|
# Ruby throws error when negative, Python doesn't
|
108
108
|
" " * [@max_width - ret.size, 0].max + ret
|
109
109
|
end
|
110
|
+
|
111
|
+
def tensor_totype(t)
|
112
|
+
dtype = t.mps? ? :float : :double
|
113
|
+
t.to(dtype: dtype)
|
114
|
+
end
|
110
115
|
end
|
111
116
|
|
112
117
|
def inspect
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.14.1
|
4
|
+
version: 0.15.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2024-02-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|