torch-rb 0.17.1 → 0.19.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/torch/device.cpp CHANGED
@@ -8,7 +8,7 @@ void init_device(Rice::Module& m) {
   Rice::define_class_under<torch::Device>(m, "Device")
     .define_constructor(Rice::Constructor<torch::Device, const std::string&>())
     .define_method(
-      "index",
+      "_index",
       [](torch::Device& self) {
        return self.index();
       })
@@ -23,5 +23,10 @@ void init_device(Rice::Module& m) {
        std::stringstream s;
        s << self.type();
        return s.str();
+      })
+    .define_method(
+      "_str",
+      [](torch::Device& self) {
+        return self.str();
       });
 }
data/ext/torch/ext.cpp CHANGED
@@ -31,6 +31,7 @@ void Init_ext()
 
   // keep this order
   init_torch(m);
+  init_device(m);
   init_tensor(m, rb_cTensor, rb_cTensorOptions);
   init_nn(m);
   init_fft(m);
@@ -39,7 +40,6 @@ void Init_ext()
 
   init_backends(m);
   init_cuda(m);
-  init_device(m);
   init_generator(m, rb_cGenerator);
   init_ivalue(m, rb_cIValue);
   init_random(m);
data/ext/torch/tensor.cpp CHANGED
@@ -212,11 +212,9 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
        return s.str();
       })
     .define_method(
-      "device",
+      "_device",
       [](Tensor& self) {
-        std::stringstream s;
-        s << self.device();
-        return s.str();
+        return self.device();
       })
     .define_method(
       "_data_str",
data/ext/torch/torch.cpp CHANGED
@@ -9,19 +9,14 @@
 #include "utils.h"
 
 template<typename T>
-torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
   std::vector<T> vec;
+  vec.reserve(a.size());
   for (long i = 0; i < a.size(); i++) {
     vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value()));
   }
 
-  // hack for requires_grad error
-  auto requires_grad = options.requires_grad();
-  torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
-  if (requires_grad) {
-    t.set_requires_grad(true);
-  }
-
+  torch::Tensor t = torch::tensor(vec, options);
   return t.reshape(size);
 }
 
@@ -46,12 +41,12 @@ void init_torch(Rice::Module& m) {
     // config
     .define_singleton_function(
       "show_config",
-      [] {
+      []() {
        return torch::show_config();
       })
     .define_singleton_function(
       "parallel_info",
-      [] {
+      []() {
        return torch::get_parallel_info();
       })
     // begin operations
@@ -74,13 +69,13 @@ void init_torch(Rice::Module& m) {
       })
     .define_singleton_function(
       "_from_blob",
-      [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
+      [](Rice::String s, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
        void *data = const_cast<char *>(s.c_str());
        return torch::from_blob(data, size, options);
       })
     .define_singleton_function(
       "_tensor",
-      [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+      [](Rice::Array a, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
        auto dtype = options.dtype();
        if (dtype == torch::kByte) {
          return make_tensor<uint8_t>(a, size, options);
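Note on the make_tensor change above: requires_grad now flows through TensorOptions directly rather than being stripped out and re-applied afterwards. A minimal Ruby sketch of the code path that exercises this (not part of the diff; return values are illustrative):

  # Torch.tensor eventually reaches the _tensor binding shown above
  x = Torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad: true)
  x.requires_grad  # => true, set via TensorOptions instead of set_requires_grad
  x.shape          # => [2, 2], produced by the reshape(size) call in make_tensor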
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
 #include <rice/stl.hpp>
 
 static_assert(
-  TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 4,
+  TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6,
   "Incompatible LibTorch version"
 );
 
data/lib/torch/device.rb ADDED
@@ -0,0 +1,25 @@
+module Torch
+  class Device
+    def index
+      index? ? _index : nil
+    end
+
+    def inspect
+      extra = ", index: #{index.inspect}" if index?
+      "device(type: #{type.inspect}#{extra})"
+    end
+    alias_method :to_s, :inspect
+
+    def ==(other)
+      eql?(other)
+    end
+
+    def eql?(other)
+      other.is_a?(Device) && other.type == type && other.index == index
+    end
+
+    def hash
+      [type, index].hash
+    end
+  end
+end
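A short usage sketch of the Device wrapper added above (not part of the diff; return values are illustrative and assume the existing type and index? bindings):

  cpu = Torch::Device.new("cpu")
  cpu.index                          # => nil, no index component
  gpu = Torch::Device.new("cuda:0")
  gpu.index                          # => 0
  gpu.inspect                        # => device(type: "cuda", index: 0)
  gpu == Torch::Device.new("cuda:0") # => true, value equality via eql?/hash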
data/lib/torch/tensor.rb CHANGED
@@ -24,6 +24,7 @@ module Torch
     alias_method :^, :logical_xor
     alias_method :<<, :__lshift__
     alias_method :>>, :__rshift__
+    alias_method :~, :bitwise_not
 
     def self.new(*args)
       FloatTensor.new(*args)
@@ -208,5 +209,10 @@ module Torch
        raise TypeError, "#{self.class} can't be coerced into #{other.class}"
       end
     end
+
+    # TODO return Device instead of String in 0.19.0
+    def device
+      _device._str
+    end
   end
 end
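A rough sketch of the two Tensor additions above (not part of the diff; values are illustrative):

  x = Torch.tensor([0, 1, 2])
  (~x).to_a   # => [-1, -2, -3], since ~ is now an alias for bitwise_not
  x.device    # => "cpu", still a String for now, backed by _device._str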
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.17.1"
+  VERSION = "0.19.0"
 end
data/lib/torch.rb CHANGED
@@ -8,6 +8,7 @@ require "set"
 require "tmpdir"
 
 # modules
+require_relative "torch/device"
 require_relative "torch/inspector"
 require_relative "torch/tensor"
 require_relative "torch/version"
@@ -382,7 +383,11 @@ module Torch
     alias_method :set_grad_enabled, :grad_enabled
 
     def device(str)
-      Device.new(str)
+      if str.is_a?(Device)
+        str
+      else
+        Device.new(str)
+      end
     end
 
     def save(obj, f)
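Illustrative sketch of the more permissive Torch.device helper above (not part of the diff):

  dev = Torch::Device.new("cuda:0")
  Torch.device("cuda:0")  # builds a new Device from a String, as before
  Torch.device(dev)       # now returns the Device argument unchanged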
metadata CHANGED
@@ -1,14 +1,13 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.17.1
+  version: 0.19.0
 platform: ruby
 authors:
 - Andrew Kane
-autorequire:
 bindir: bin
 cert_chain: []
-date: 2024-08-19 00:00:00.000000000 Z
+date: 2025-01-30 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -16,15 +15,14 @@ dependencies:
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: '4.1'
+        version: 4.3.3
   type: :runtime
   prerelease: false
   version_requirements: !ruby/object:Gem::Requirement
     requirements:
     - - ">="
       - !ruby/object:Gem::Version
-        version: '4.1'
-description:
+        version: 4.3.3
 email: andrew@ankane.org
 executables: []
 extensions:
@@ -65,6 +63,7 @@ files:
 - ext/torch/wrap_outputs.h
 - lib/torch-rb.rb
 - lib/torch.rb
+- lib/torch/device.rb
 - lib/torch/hub.rb
 - lib/torch/inspector.rb
 - lib/torch/nn/adaptive_avg_pool1d.rb
@@ -224,7 +223,6 @@ homepage: https://github.com/ankane/torch.rb
 licenses:
 - BSD-3-Clause
 metadata: {}
-post_install_message:
 rdoc_options: []
 require_paths:
 - lib
@@ -239,8 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
     - !ruby/object:Gem::Version
       version: '0'
 requirements: []
-rubygems_version: 3.5.11
-signing_key:
+rubygems_version: 3.6.2
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch
 test_files: []