torch-rb 0.13.2 → 0.14.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -401,7 +401,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
401
401
  if (str != "None") {
402
402
  throw std::runtime_error("default value for Tensor must be none, got: " + str);
403
403
  }
404
- } else if (type_ == ParameterType::INT64) {
404
+ } else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
405
405
  default_int = atol(str.c_str());
406
406
  } else if (type_ == ParameterType::BOOL) {
407
407
  default_bool = (str == "True" || str == "true");
@@ -417,7 +417,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
417
417
  default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) :
418
418
  at::Scalar(atof(str.c_str()));
419
419
  }
420
- } else if (type_ == ParameterType::INT_LIST) {
420
+ } else if (type_ == ParameterType::INT_LIST || type_ == ParameterType::SYM_INT_LIST) {
421
421
  if (str != "None") {
422
422
  default_intlist = parse_intlist_args(str, size);
423
423
  }
@@ -452,6 +452,31 @@ void FunctionParameter::set_default_str(const std::string& str) {
452
452
  default_string = parse_string_literal(str);
453
453
  }
454
454
  }
455
+ // These types weren't handled here before. Adding a default error
456
+ // led to a lot of test failures so adding this skip for now.
457
+ // We should correctly handle these though because it might be causing
458
+ // silent failures.
459
+ else if (type_ == ParameterType::TENSOR_LIST) {
460
+ // throw std::runtime_error("Invalid Tensor List");
461
+ } else if (type_ == ParameterType::GENERATOR) {
462
+ // throw std::runtime_error("ParameterType::GENERATOR");
463
+ } else if (type_ == ParameterType::PYOBJECT) {
464
+ // throw std::runtime_error("ParameterType::PYOBJECT");
465
+ } else if (type_ == ParameterType::MEMORY_FORMAT) {
466
+ // throw std::runtime_error("ParameterType::MEMORY_FORMAT");
467
+ } else if (type_ == ParameterType::DIMNAME) {
468
+ // throw std::runtime_error("ParameterType::DIMNAME");
469
+ } else if (type_ == ParameterType::DIMNAME_LIST) {
470
+ // throw std::runtime_error("ParameterType::DIMNAME_LIST");
471
+ } else if (type_ == ParameterType::SCALAR_LIST) {
472
+ // throw std::runtime_error("ParameterType::SCALAR_LIST");
473
+ } else if (type_ == ParameterType::STORAGE) {
474
+ // throw std::runtime_error("ParameterType::STORAGE");
475
+ } else if (type_ == ParameterType::QSCHEME) {
476
+ // throw std::runtime_error("ParameterType::QSCHEME");
477
+ } else {
478
+ throw std::runtime_error("unknown parameter type");
479
+ }
455
480
  }
456
481
 
457
482
  FunctionSignature::FunctionSignature(const std::string& fmt, int index)
data/ext/torch/torch.cpp CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  #include <rice/rice.hpp>
4
4
 
5
+ #include <fstream>
6
+
5
7
  #include "torch_functions.h"
6
8
  #include "templates.h"
7
9
  #include "utils.h"
@@ -57,16 +59,18 @@ void init_torch(Rice::Module& m) {
57
59
  "_save",
58
60
  [](const torch::IValue &value) {
59
61
  auto v = torch::pickle_save(value);
60
- std::string str(v.begin(), v.end());
61
- return str;
62
+ return Object(rb_str_new(v.data(), v.size()));
62
63
  })
63
64
  .define_singleton_function(
64
65
  "_load",
65
- [](const std::string &s) {
66
- std::vector<char> v;
67
- std::copy(s.begin(), s.end(), std::back_inserter(v));
66
+ [](const std::string &filename) {
68
67
  // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
69
- return torch::pickle_load(v);
68
+ std::ifstream input(filename, std::ios::binary);
69
+ std::vector<char> bytes(
70
+ (std::istreambuf_iterator<char>(input)),
71
+ (std::istreambuf_iterator<char>()));
72
+ input.close();
73
+ return torch::pickle_load(bytes);
70
74
  })
71
75
  .define_singleton_function(
72
76
  "_from_blob",
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 1,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.13.2"
2
+ VERSION = "0.14.1"
3
3
  end
data/lib/torch.rb CHANGED
@@ -387,8 +387,11 @@ module Torch
387
387
  File.binwrite(f, _save(to_ivalue(obj)))
388
388
  end
389
389
 
390
- def load(f)
391
- to_ruby(_load(File.binread(f)))
390
+ def load(filename)
391
+ # keep backwards compatibility
392
+ File.open(filename, "rb") { |f| f.read(1) }
393
+
394
+ to_ruby(_load(filename))
392
395
  end
393
396
 
394
397
  def tensor(data, **options)
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.13.2
4
+ version: 0.14.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-05-12 00:00:00.000000000 Z
11
+ date: 2023-12-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -237,7 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
237
237
  - !ruby/object:Gem::Version
238
238
  version: '0'
239
239
  requirements: []
240
- rubygems_version: 3.4.10
240
+ rubygems_version: 3.5.3
241
241
  signing_key:
242
242
  specification_version: 4
243
243
  summary: Deep learning for Ruby, powered by LibTorch