torch-rb 0.17.1 → 0.19.0

Sign up to get free protection for your applications and to get access to all the features.
data/ext/torch/device.cpp CHANGED
@@ -8,7 +8,7 @@ void init_device(Rice::Module& m) {
8
8
  Rice::define_class_under<torch::Device>(m, "Device")
9
9
  .define_constructor(Rice::Constructor<torch::Device, const std::string&>())
10
10
  .define_method(
11
- "index",
11
+ "_index",
12
12
  [](torch::Device& self) {
13
13
  return self.index();
14
14
  })
@@ -23,5 +23,10 @@ void init_device(Rice::Module& m) {
23
23
  std::stringstream s;
24
24
  s << self.type();
25
25
  return s.str();
26
+ })
27
+ .define_method(
28
+ "_str",
29
+ [](torch::Device& self) {
30
+ return self.str();
26
31
  });
27
32
  }
data/ext/torch/ext.cpp CHANGED
@@ -31,6 +31,7 @@ void Init_ext()
31
31
 
32
32
  // keep this order
33
33
  init_torch(m);
34
+ init_device(m);
34
35
  init_tensor(m, rb_cTensor, rb_cTensorOptions);
35
36
  init_nn(m);
36
37
  init_fft(m);
@@ -39,7 +40,6 @@ void Init_ext()
39
40
 
40
41
  init_backends(m);
41
42
  init_cuda(m);
42
- init_device(m);
43
43
  init_generator(m, rb_cGenerator);
44
44
  init_ivalue(m, rb_cIValue);
45
45
  init_random(m);
data/ext/torch/tensor.cpp CHANGED
@@ -212,11 +212,9 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
212
212
  return s.str();
213
213
  })
214
214
  .define_method(
215
- "device",
215
+ "_device",
216
216
  [](Tensor& self) {
217
- std::stringstream s;
218
- s << self.device();
219
- return s.str();
217
+ return self.device();
220
218
  })
221
219
  .define_method(
222
220
  "_data_str",
data/ext/torch/torch.cpp CHANGED
@@ -9,19 +9,14 @@
9
9
  #include "utils.h"
10
10
 
11
11
  template<typename T>
12
- torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
12
+ torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
13
13
  std::vector<T> vec;
14
+ vec.reserve(a.size());
14
15
  for (long i = 0; i < a.size(); i++) {
15
16
  vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value()));
16
17
  }
17
18
 
18
- // hack for requires_grad error
19
- auto requires_grad = options.requires_grad();
20
- torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
21
- if (requires_grad) {
22
- t.set_requires_grad(true);
23
- }
24
-
19
+ torch::Tensor t = torch::tensor(vec, options);
25
20
  return t.reshape(size);
26
21
  }
27
22
 
@@ -46,12 +41,12 @@ void init_torch(Rice::Module& m) {
46
41
  // config
47
42
  .define_singleton_function(
48
43
  "show_config",
49
- [] {
44
+ []() {
50
45
  return torch::show_config();
51
46
  })
52
47
  .define_singleton_function(
53
48
  "parallel_info",
54
- [] {
49
+ []() {
55
50
  return torch::get_parallel_info();
56
51
  })
57
52
  // begin operations
@@ -74,13 +69,13 @@ void init_torch(Rice::Module& m) {
74
69
  })
75
70
  .define_singleton_function(
76
71
  "_from_blob",
77
- [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
72
+ [](Rice::String s, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
78
73
  void *data = const_cast<char *>(s.c_str());
79
74
  return torch::from_blob(data, size, options);
80
75
  })
81
76
  .define_singleton_function(
82
77
  "_tensor",
83
- [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
78
+ [](Rice::Array a, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
84
79
  auto dtype = options.dtype();
85
80
  if (dtype == torch::kByte) {
86
81
  return make_tensor<uint8_t>(a, size, options);
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 4,
9
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
@@ -0,0 +1,25 @@
1
+ module Torch
2
+ class Device
3
+ def index
4
+ index? ? _index : nil
5
+ end
6
+
7
+ def inspect
8
+ extra = ", index: #{index.inspect}" if index?
9
+ "device(type: #{type.inspect}#{extra})"
10
+ end
11
+ alias_method :to_s, :inspect
12
+
13
+ def ==(other)
14
+ eql?(other)
15
+ end
16
+
17
+ def eql?(other)
18
+ other.is_a?(Device) && other.type == type && other.index == index
19
+ end
20
+
21
+ def hash
22
+ [type, index].hash
23
+ end
24
+ end
25
+ end
data/lib/torch/tensor.rb CHANGED
@@ -24,6 +24,7 @@ module Torch
24
24
  alias_method :^, :logical_xor
25
25
  alias_method :<<, :__lshift__
26
26
  alias_method :>>, :__rshift__
27
+ alias_method :~, :bitwise_not
27
28
 
28
29
  def self.new(*args)
29
30
  FloatTensor.new(*args)
@@ -208,5 +209,10 @@ module Torch
208
209
  raise TypeError, "#{self.class} can't be coerced into #{other.class}"
209
210
  end
210
211
  end
212
+
213
+ # TODO return Device instead of String in 0.19.0
214
+ def device
215
+ _device._str
216
+ end
211
217
  end
212
218
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.17.1"
2
+ VERSION = "0.19.0"
3
3
  end
data/lib/torch.rb CHANGED
@@ -8,6 +8,7 @@ require "set"
8
8
  require "tmpdir"
9
9
 
10
10
  # modules
11
+ require_relative "torch/device"
11
12
  require_relative "torch/inspector"
12
13
  require_relative "torch/tensor"
13
14
  require_relative "torch/version"
@@ -382,7 +383,11 @@ module Torch
382
383
  alias_method :set_grad_enabled, :grad_enabled
383
384
 
384
385
  def device(str)
385
- Device.new(str)
386
+ if str.is_a?(Device)
387
+ str
388
+ else
389
+ Device.new(str)
390
+ end
386
391
  end
387
392
 
388
393
  def save(obj, f)
metadata CHANGED
@@ -1,14 +1,13 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.17.1
4
+ version: 0.19.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
9
8
  bindir: bin
10
9
  cert_chain: []
11
- date: 2024-08-19 00:00:00.000000000 Z
10
+ date: 2025-01-30 00:00:00.000000000 Z
12
11
  dependencies:
13
12
  - !ruby/object:Gem::Dependency
14
13
  name: rice
@@ -16,15 +15,14 @@ dependencies:
16
15
  requirements:
17
16
  - - ">="
18
17
  - !ruby/object:Gem::Version
19
- version: '4.1'
18
+ version: 4.3.3
20
19
  type: :runtime
21
20
  prerelease: false
22
21
  version_requirements: !ruby/object:Gem::Requirement
23
22
  requirements:
24
23
  - - ">="
25
24
  - !ruby/object:Gem::Version
26
- version: '4.1'
27
- description:
25
+ version: 4.3.3
28
26
  email: andrew@ankane.org
29
27
  executables: []
30
28
  extensions:
@@ -65,6 +63,7 @@ files:
65
63
  - ext/torch/wrap_outputs.h
66
64
  - lib/torch-rb.rb
67
65
  - lib/torch.rb
66
+ - lib/torch/device.rb
68
67
  - lib/torch/hub.rb
69
68
  - lib/torch/inspector.rb
70
69
  - lib/torch/nn/adaptive_avg_pool1d.rb
@@ -224,7 +223,6 @@ homepage: https://github.com/ankane/torch.rb
224
223
  licenses:
225
224
  - BSD-3-Clause
226
225
  metadata: {}
227
- post_install_message:
228
226
  rdoc_options: []
229
227
  require_paths:
230
228
  - lib
@@ -239,8 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
239
237
  - !ruby/object:Gem::Version
240
238
  version: '0'
241
239
  requirements: []
242
- rubygems_version: 3.5.11
243
- signing_key:
240
+ rubygems_version: 3.6.2
244
241
  specification_version: 4
245
242
  summary: Deep learning for Ruby, powered by LibTorch
246
243
  test_files: []