torch-rb 0.13.2 → 0.14.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -401,7 +401,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
401
401
  if (str != "None") {
402
402
  throw std::runtime_error("default value for Tensor must be none, got: " + str);
403
403
  }
404
- } else if (type_ == ParameterType::INT64) {
404
+ } else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
405
405
  default_int = atol(str.c_str());
406
406
  } else if (type_ == ParameterType::BOOL) {
407
407
  default_bool = (str == "True" || str == "true");
@@ -417,7 +417,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
417
417
  default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) :
418
418
  at::Scalar(atof(str.c_str()));
419
419
  }
420
- } else if (type_ == ParameterType::INT_LIST) {
420
+ } else if (type_ == ParameterType::INT_LIST || type_ == ParameterType::SYM_INT_LIST) {
421
421
  if (str != "None") {
422
422
  default_intlist = parse_intlist_args(str, size);
423
423
  }
@@ -452,6 +452,31 @@ void FunctionParameter::set_default_str(const std::string& str) {
452
452
  default_string = parse_string_literal(str);
453
453
  }
454
454
  }
455
+ // These types weren't handled here before. Adding a default error
456
+ // led to a lot of test failures so adding this skip for now.
457
+ // We should correctly handle these though because it might be causing
458
+ // silent failures.
459
+ else if (type_ == ParameterType::TENSOR_LIST) {
460
+ // throw std::runtime_error("Invalid Tensor List");
461
+ } else if (type_ == ParameterType::GENERATOR) {
462
+ // throw std::runtime_error("ParameterType::GENERATOR");
463
+ } else if (type_ == ParameterType::PYOBJECT) {
464
+ // throw std::runtime_error("ParameterType::PYOBJECT");
465
+ } else if (type_ == ParameterType::MEMORY_FORMAT) {
466
+ // throw std::runtime_error("ParameterType::MEMORY_FORMAT");
467
+ } else if (type_ == ParameterType::DIMNAME) {
468
+ // throw std::runtime_error("ParameterType::DIMNAME");
469
+ } else if (type_ == ParameterType::DIMNAME_LIST) {
470
+ // throw std::runtime_error("ParameterType::DIMNAME_LIST");
471
+ } else if (type_ == ParameterType::SCALAR_LIST) {
472
+ // throw std::runtime_error("ParameterType::SCALAR_LIST");
473
+ } else if (type_ == ParameterType::STORAGE) {
474
+ // throw std::runtime_error("ParameterType::STORAGE");
475
+ } else if (type_ == ParameterType::QSCHEME) {
476
+ // throw std::runtime_error("ParameterType::QSCHEME");
477
+ } else {
478
+ throw std::runtime_error("unknown parameter type");
479
+ }
455
480
  }
456
481
 
457
482
  FunctionSignature::FunctionSignature(const std::string& fmt, int index)
data/ext/torch/torch.cpp CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  #include <rice/rice.hpp>
4
4
 
5
+ #include <fstream>
6
+
5
7
  #include "torch_functions.h"
6
8
  #include "templates.h"
7
9
  #include "utils.h"
@@ -57,16 +59,18 @@ void init_torch(Rice::Module& m) {
57
59
  "_save",
58
60
  [](const torch::IValue &value) {
59
61
  auto v = torch::pickle_save(value);
60
- std::string str(v.begin(), v.end());
61
- return str;
62
+ return Object(rb_str_new(v.data(), v.size()));
62
63
  })
63
64
  .define_singleton_function(
64
65
  "_load",
65
- [](const std::string &s) {
66
- std::vector<char> v;
67
- std::copy(s.begin(), s.end(), std::back_inserter(v));
66
+ [](const std::string &filename) {
68
67
  // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
69
- return torch::pickle_load(v);
68
+ std::ifstream input(filename, std::ios::binary);
69
+ std::vector<char> bytes(
70
+ (std::istreambuf_iterator<char>(input)),
71
+ (std::istreambuf_iterator<char>()));
72
+ input.close();
73
+ return torch::pickle_load(bytes);
70
74
  })
71
75
  .define_singleton_function(
72
76
  "_from_blob",
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 1,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.13.2"
2
+ VERSION = "0.14.1"
3
3
  end
data/lib/torch.rb CHANGED
@@ -387,8 +387,11 @@ module Torch
387
387
  File.binwrite(f, _save(to_ivalue(obj)))
388
388
  end
389
389
 
390
- def load(f)
391
- to_ruby(_load(File.binread(f)))
390
+ def load(filename)
391
+ # keep backwards compatibility
392
+ File.open(filename, "rb") { |f| f.read(1) }
393
+
394
+ to_ruby(_load(filename))
392
395
  end
393
396
 
394
397
  def tensor(data, **options)
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.13.2
4
+ version: 0.14.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-05-12 00:00:00.000000000 Z
11
+ date: 2023-12-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -237,7 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
237
237
  - !ruby/object:Gem::Version
238
238
  version: '0'
239
239
  requirements: []
240
- rubygems_version: 3.4.10
240
+ rubygems_version: 3.5.3
241
241
  signing_key:
242
242
  specification_version: 4
243
243
  summary: Deep learning for Ruby, powered by LibTorch