torch-rb 0.13.2 → 0.14.0

Sign up to get free protection for your applications and to get access to all the features.
data/ext/torch/torch.cpp CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  #include <rice/rice.hpp>
4
4
 
5
+ #include <fstream>
6
+
5
7
  #include "torch_functions.h"
6
8
  #include "templates.h"
7
9
  #include "utils.h"
@@ -57,16 +59,18 @@ void init_torch(Rice::Module& m) {
57
59
  "_save",
58
60
  [](const torch::IValue &value) {
59
61
  auto v = torch::pickle_save(value);
60
- std::string str(v.begin(), v.end());
61
- return str;
62
+ return Object(rb_str_new(v.data(), v.size()));
62
63
  })
63
64
  .define_singleton_function(
64
65
  "_load",
65
- [](const std::string &s) {
66
- std::vector<char> v;
67
- std::copy(s.begin(), s.end(), std::back_inserter(v));
66
+ [](const std::string &filename) {
68
67
  // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
69
- return torch::pickle_load(v);
68
+ std::ifstream input(filename, std::ios::binary);
69
+ std::vector<char> bytes(
70
+ (std::istreambuf_iterator<char>(input)),
71
+ (std::istreambuf_iterator<char>()));
72
+ input.close();
73
+ return torch::pickle_load(bytes);
70
74
  })
71
75
  .define_singleton_function(
72
76
  "_from_blob",
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 1,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.13.2"
2
+ VERSION = "0.14.0"
3
3
  end
data/lib/torch.rb CHANGED
@@ -387,8 +387,11 @@ module Torch
387
387
  File.binwrite(f, _save(to_ivalue(obj)))
388
388
  end
389
389
 
390
- def load(f)
391
- to_ruby(_load(File.binread(f)))
390
+ def load(filename)
391
+ # keep backwards compatibility
392
+ File.open(filename, "rb") { |f| f.read(1) }
393
+
394
+ to_ruby(_load(filename))
392
395
  end
393
396
 
394
397
  def tensor(data, **options)
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.13.2
4
+ version: 0.14.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-05-12 00:00:00.000000000 Z
11
+ date: 2023-11-09 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice