torch-rb 0.13.2 → 0.14.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +1 -0
- data/codegen/generate_functions.rb +6 -1
- data/codegen/native_functions.yaml +985 -516
- data/ext/torch/ruby_arg_parser.cpp +27 -2
- data/ext/torch/torch.cpp +10 -6
- data/ext/torch/utils.h +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +5 -2
- metadata +3 -3
@@ -401,7 +401,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|
401
401
|
if (str != "None") {
|
402
402
|
throw std::runtime_error("default value for Tensor must be none, got: " + str);
|
403
403
|
}
|
404
|
-
} else if (type_ == ParameterType::INT64) {
|
404
|
+
} else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
|
405
405
|
default_int = atol(str.c_str());
|
406
406
|
} else if (type_ == ParameterType::BOOL) {
|
407
407
|
default_bool = (str == "True" || str == "true");
|
@@ -417,7 +417,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|
417
417
|
default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) :
|
418
418
|
at::Scalar(atof(str.c_str()));
|
419
419
|
}
|
420
|
-
} else if (type_ == ParameterType::INT_LIST) {
|
420
|
+
} else if (type_ == ParameterType::INT_LIST || type_ == ParameterType::SYM_INT_LIST) {
|
421
421
|
if (str != "None") {
|
422
422
|
default_intlist = parse_intlist_args(str, size);
|
423
423
|
}
|
@@ -452,6 +452,31 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|
452
452
|
default_string = parse_string_literal(str);
|
453
453
|
}
|
454
454
|
}
|
455
|
+
// These types weren't handled here before. Adding a default error
|
456
|
+
// led to a lot of test failures so adding this skip for now.
|
457
|
+
// We should correctly handle these though because it might be causing
|
458
|
+
// silent failures.
|
459
|
+
else if (type_ == ParameterType::TENSOR_LIST) {
|
460
|
+
// throw std::runtime_error("Invalid Tensor List");
|
461
|
+
} else if (type_ == ParameterType::GENERATOR) {
|
462
|
+
// throw std::runtime_error("ParameterType::GENERATOR");
|
463
|
+
} else if (type_ == ParameterType::PYOBJECT) {
|
464
|
+
// throw std::runtime_error("ParameterType::PYOBJECT");
|
465
|
+
} else if (type_ == ParameterType::MEMORY_FORMAT) {
|
466
|
+
// throw std::runtime_error("ParameterType::MEMORY_FORMAT");
|
467
|
+
} else if (type_ == ParameterType::DIMNAME) {
|
468
|
+
// throw std::runtime_error("ParameterType::DIMNAME");
|
469
|
+
} else if (type_ == ParameterType::DIMNAME_LIST) {
|
470
|
+
// throw std::runtime_error("ParameterType::DIMNAME_LIST");
|
471
|
+
} else if (type_ == ParameterType::SCALAR_LIST) {
|
472
|
+
// throw std::runtime_error("ParameterType::SCALAR_LIST");
|
473
|
+
} else if (type_ == ParameterType::STORAGE) {
|
474
|
+
// throw std::runtime_error("ParameterType::STORAGE");
|
475
|
+
} else if (type_ == ParameterType::QSCHEME) {
|
476
|
+
// throw std::runtime_error("ParameterType::QSCHEME");
|
477
|
+
} else {
|
478
|
+
throw std::runtime_error("unknown parameter type");
|
479
|
+
}
|
455
480
|
}
|
456
481
|
|
457
482
|
FunctionSignature::FunctionSignature(const std::string& fmt, int index)
|
data/ext/torch/torch.cpp
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
#include <rice/rice.hpp>
|
4
4
|
|
5
|
+
#include <fstream>
|
6
|
+
|
5
7
|
#include "torch_functions.h"
|
6
8
|
#include "templates.h"
|
7
9
|
#include "utils.h"
|
@@ -57,16 +59,18 @@ void init_torch(Rice::Module& m) {
|
|
57
59
|
"_save",
|
58
60
|
[](const torch::IValue &value) {
|
59
61
|
auto v = torch::pickle_save(value);
|
60
|
-
|
61
|
-
return str;
|
62
|
+
return Object(rb_str_new(v.data(), v.size()));
|
62
63
|
})
|
63
64
|
.define_singleton_function(
|
64
65
|
"_load",
|
65
|
-
[](const std::string &
|
66
|
-
std::vector<char> v;
|
67
|
-
std::copy(s.begin(), s.end(), std::back_inserter(v));
|
66
|
+
[](const std::string &filename) {
|
68
67
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
69
|
-
|
68
|
+
std::ifstream input(filename, std::ios::binary);
|
69
|
+
std::vector<char> bytes(
|
70
|
+
(std::istreambuf_iterator<char>(input)),
|
71
|
+
(std::istreambuf_iterator<char>()));
|
72
|
+
input.close();
|
73
|
+
return torch::pickle_load(bytes);
|
70
74
|
})
|
71
75
|
.define_singleton_function(
|
72
76
|
"_from_blob",
|
data/ext/torch/utils.h
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -387,8 +387,11 @@ module Torch
|
|
387
387
|
File.binwrite(f, _save(to_ivalue(obj)))
|
388
388
|
end
|
389
389
|
|
390
|
-
def load(
|
391
|
-
|
390
|
+
def load(filename)
|
391
|
+
# keep backwards compatibility
|
392
|
+
File.open(filename, "rb") { |f| f.read(1) }
|
393
|
+
|
394
|
+
to_ruby(_load(filename))
|
392
395
|
end
|
393
396
|
|
394
397
|
def tensor(data, **options)
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.14.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-12-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -237,7 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
237
237
|
- !ruby/object:Gem::Version
|
238
238
|
version: '0'
|
239
239
|
requirements: []
|
240
|
-
rubygems_version: 3.
|
240
|
+
rubygems_version: 3.5.3
|
241
241
|
signing_key:
|
242
242
|
specification_version: 4
|
243
243
|
summary: Deep learning for Ruby, powered by LibTorch
|