torch-rb 0.14.0 → 0.15.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +3 -6
- data/codegen/native_functions.yaml +357 -87
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/ruby_arg_parser.cpp +27 -2
- data/ext/torch/templates.h +0 -23
- data/ext/torch/tensor.cpp +1 -0
- data/ext/torch/utils.h +1 -1
- data/lib/torch/inspector.rb +8 -3
- data/lib/torch/version.rb +1 -1
- metadata +3 -3
data/ext/torch/extconf.rb
CHANGED
@@ -52,6 +52,9 @@ $INCFLAGS += " -I#{inc}"
|
|
52
52
|
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
|
53
53
|
|
54
54
|
$LDFLAGS += " -Wl,-rpath,#{lib}"
|
55
|
+
if RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i && Dir.exist?("/opt/homebrew/opt/libomp/lib")
|
56
|
+
$LDFLAGS += ",-rpath,/opt/homebrew/opt/libomp/lib"
|
57
|
+
end
|
55
58
|
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
|
56
59
|
|
57
60
|
# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
|
@@ -401,7 +401,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|
401
401
|
if (str != "None") {
|
402
402
|
throw std::runtime_error("default value for Tensor must be none, got: " + str);
|
403
403
|
}
|
404
|
-
} else if (type_ == ParameterType::INT64) {
|
404
|
+
} else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
|
405
405
|
default_int = atol(str.c_str());
|
406
406
|
} else if (type_ == ParameterType::BOOL) {
|
407
407
|
default_bool = (str == "True" || str == "true");
|
@@ -417,7 +417,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|
417
417
|
default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) :
|
418
418
|
at::Scalar(atof(str.c_str()));
|
419
419
|
}
|
420
|
-
} else if (type_ == ParameterType::INT_LIST) {
|
420
|
+
} else if (type_ == ParameterType::INT_LIST || type_ == ParameterType::SYM_INT_LIST) {
|
421
421
|
if (str != "None") {
|
422
422
|
default_intlist = parse_intlist_args(str, size);
|
423
423
|
}
|
@@ -452,6 +452,31 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|
452
452
|
default_string = parse_string_literal(str);
|
453
453
|
}
|
454
454
|
}
|
455
|
+
// These types weren't handled here before. Adding a default error
|
456
|
+
// led to a lot of test failures so adding this skip for now.
|
457
|
+
// We should correctly handle these though because it might be causing
|
458
|
+
// silent failures.
|
459
|
+
else if (type_ == ParameterType::TENSOR_LIST) {
|
460
|
+
// throw std::runtime_error("Invalid Tensor List");
|
461
|
+
} else if (type_ == ParameterType::GENERATOR) {
|
462
|
+
// throw std::runtime_error("ParameterType::GENERATOR");
|
463
|
+
} else if (type_ == ParameterType::PYOBJECT) {
|
464
|
+
// throw std::runtime_error("ParameterType::PYOBJECT");
|
465
|
+
} else if (type_ == ParameterType::MEMORY_FORMAT) {
|
466
|
+
// throw std::runtime_error("ParameterType::MEMORY_FORMAT");
|
467
|
+
} else if (type_ == ParameterType::DIMNAME) {
|
468
|
+
// throw std::runtime_error("ParameterType::DIMNAME");
|
469
|
+
} else if (type_ == ParameterType::DIMNAME_LIST) {
|
470
|
+
// throw std::runtime_error("ParameterType::DIMNAME_LIST");
|
471
|
+
} else if (type_ == ParameterType::SCALAR_LIST) {
|
472
|
+
// throw std::runtime_error("ParameterType::SCALAR_LIST");
|
473
|
+
} else if (type_ == ParameterType::STORAGE) {
|
474
|
+
// throw std::runtime_error("ParameterType::STORAGE");
|
475
|
+
} else if (type_ == ParameterType::QSCHEME) {
|
476
|
+
// throw std::runtime_error("ParameterType::QSCHEME");
|
477
|
+
} else {
|
478
|
+
throw std::runtime_error("unknown parameter type");
|
479
|
+
}
|
455
480
|
}
|
456
481
|
|
457
482
|
FunctionSignature::FunctionSignature(const std::string& fmt, int index)
|
data/ext/torch/templates.h
CHANGED
@@ -169,27 +169,4 @@ namespace Rice::detail
|
|
169
169
|
}
|
170
170
|
}
|
171
171
|
};
|
172
|
-
|
173
|
-
// Rice type-verification specialization for torch::optional<T>:
// verify() unconditionally returns true for every T.
// NOTE(review): this specialization is deleted in this diff (lines marked "-");
// presumably Rice now supplies an equivalent for optionals — confirm.
template<typename T>
|
174
|
-
struct Type<torch::optional<T>>
|
175
|
-
{
|
176
|
-
static bool verify()
|
177
|
-
{
|
178
|
-
return true;
|
179
|
-
}
|
180
|
-
};
|
181
|
-
|
182
|
-
// Converts a Ruby VALUE to torch::optional<T>: Ruby nil becomes
// torch::nullopt; any other value is converted with From_Ruby<T> and
// wrapped in an optional.
// NOTE(review): this converter is deleted in this diff (lines marked "-");
// confirm what now handles optional arguments coming from Ruby.
template<typename T>
|
183
|
-
class From_Ruby<torch::optional<T>>
|
184
|
-
{
|
185
|
-
public:
|
186
|
-
torch::optional<T> convert(VALUE x)
|
187
|
-
{
|
188
|
-
if (NIL_P(x)) {
|
189
|
-
return torch::nullopt;
|
190
|
-
} else {
|
191
|
-
return torch::optional<T>{From_Ruby<T>().convert(x)};
|
192
|
-
}
|
193
|
-
}
|
194
|
-
};
|
195
172
|
}
|
data/ext/torch/tensor.cpp
CHANGED
@@ -103,6 +103,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
103
103
|
|
104
104
|
rb_cTensor
|
105
105
|
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
|
106
|
+
.define_method("mps?", [](Tensor& self) { return self.is_mps(); })
|
106
107
|
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
|
107
108
|
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
|
108
109
|
.define_method("dim", [](Tensor& self) { return self.dim(); })
|
data/ext/torch/utils.h
CHANGED
data/lib/torch/inspector.rb
CHANGED
@@ -31,9 +31,9 @@ module Torch
|
|
31
31
|
return if nonzero_finite_vals.numel == 0
|
32
32
|
|
33
33
|
# Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
|
34
|
-
nonzero_finite_abs = nonzero_finite_vals.abs
|
35
|
-
nonzero_finite_min = nonzero_finite_abs.min
|
36
|
-
nonzero_finite_max = nonzero_finite_abs.max
|
34
|
+
nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs)
|
35
|
+
nonzero_finite_min = tensor_totype(nonzero_finite_abs.min)
|
36
|
+
nonzero_finite_max = tensor_totype(nonzero_finite_abs.max)
|
37
37
|
|
38
38
|
nonzero_finite_vals.each do |value|
|
39
39
|
if value.item != value.item.ceil
|
@@ -107,6 +107,11 @@ module Torch
|
|
107
107
|
# Ruby throws error when negative, Python doesn't
|
108
108
|
" " * [@max_width - ret.size, 0].max + ret
|
109
109
|
end
|
110
|
+
|
111
|
+
# Converts tensor t to a floating-point dtype before the formatter computes
# the abs/min/max of its nonzero finite values: :float when t is an MPS
# tensor (t.mps?), :double otherwise.
# NOTE(review): the MPS special case presumably exists because the MPS backend
# does not support float64 — confirm against the PyTorch MPS documentation.
def tensor_totype(t)
|
112
|
+
dtype = t.mps? ? :float : :double
|
113
|
+
t.to(dtype: dtype)
|
114
|
+
end
|
110
115
|
end
|
111
116
|
|
112
117
|
def inspect
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.14.0
|
4
|
+
version: 0.15.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2024-02-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -237,7 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
237
237
|
- !ruby/object:Gem::Version
|
238
238
|
version: '0'
|
239
239
|
requirements: []
|
240
|
-
rubygems_version: 3.
|
240
|
+
rubygems_version: 3.5.3
|
241
241
|
signing_key:
|
242
242
|
specification_version: 4
|
243
243
|
summary: Deep learning for Ruby, powered by LibTorch
|