torch-rb 0.22.2 → 0.23.0

This diff shows the contents of publicly available package versions that were released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
data/ext/torch/device.cpp CHANGED
@@ -8,7 +8,8 @@
8
8
  #include "utils.h"
9
9
 
10
10
  void init_device(Rice::Module& m) {
11
- Rice::define_class_under<torch::Device>(m, "Device")
11
+ auto rb_cDevice = Rice::define_class_under<torch::Device>(m, "Device");
12
+ rb_cDevice
12
13
  .define_constructor(Rice::Constructor<torch::Device, const std::string&>())
13
14
  .define_method(
14
15
  "_index",
@@ -28,8 +29,10 @@ void init_device(Rice::Module& m) {
28
29
  return s.str();
29
30
  })
30
31
  .define_method(
31
- "_str",
32
+ "to_s",
32
33
  [](torch::Device& self) {
33
34
  return self.str();
34
35
  });
36
+
37
+ THPDeviceClass = rb_cDevice.value();
35
38
  }
@@ -7,6 +7,7 @@
7
7
 
8
8
  #include "ruby_arg_parser.h"
9
9
 
10
+ VALUE THPDeviceClass = Qnil;
10
11
  VALUE THPGeneratorClass = Qnil;
11
12
  VALUE THPVariableClass = Qnil;
12
13
 
@@ -257,7 +258,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool {
257
258
  case ParameterType::LAYOUT: return SYMBOL_P(obj);
258
259
  case ParameterType::MEMORY_FORMAT: return false; // return THPMemoryFormat_Check(obj);
259
260
  case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
260
- case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
261
+ case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING) || THPDevice_Check(obj);
261
262
  case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
262
263
  case ParameterType::SYM_INT: return is_int_or_symint(obj);
263
264
  case ParameterType::SYM_INT_LIST: return is_int_or_symint_list(obj, size);
@@ -357,8 +357,11 @@ inline at::Device RubyArgs::device(int i) {
357
357
  if (NIL_P(args[i])) {
358
358
  return at::Device("cpu");
359
359
  }
360
- const std::string &device_str = THPUtils_unpackString(args[i]);
361
- return at::Device(device_str);
360
+ if (RB_TYPE_P(args[i], T_STRING)) {
361
+ const std::string &device_str = THPUtils_unpackString(args[i]);
362
+ return at::Device(device_str);
363
+ }
364
+ return Rice::detail::From_Ruby<at::Device>().convert(args[i]);
362
365
  }
363
366
 
364
367
  inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
@@ -69,7 +69,7 @@ namespace Rice::detail {
69
69
 
70
70
  explicit From_Ruby(Arg* arg) : arg_(arg) { }
71
71
 
72
- Convertible is_convertible(VALUE value) { return Convertible::Cast; }
72
+ double is_convertible(VALUE value) { return Convertible::Exact; }
73
73
 
74
74
  c10::complex<T> convert(VALUE x) {
75
75
  VALUE real = rb_funcall(x, rb_intern("real"), 0);
@@ -93,7 +93,7 @@ namespace Rice::detail {
93
93
 
94
94
  explicit From_Ruby(Arg* arg) : arg_(arg) { }
95
95
 
96
- Convertible is_convertible(VALUE value) { return Convertible::Cast; }
96
+ double is_convertible(VALUE value) { return Convertible::Exact; }
97
97
 
98
98
  FanModeType convert(VALUE x) {
99
99
  auto s = String(x).str();
@@ -122,7 +122,7 @@ namespace Rice::detail {
122
122
 
123
123
  explicit From_Ruby(Arg* arg) : arg_(arg) { }
124
124
 
125
- Convertible is_convertible(VALUE value) { return Convertible::Cast; }
125
+ double is_convertible(VALUE value) { return Convertible::Exact; }
126
126
 
127
127
  NonlinearityType convert(VALUE x) {
128
128
  auto s = String(x).str();
@@ -169,7 +169,7 @@ namespace Rice::detail {
169
169
 
170
170
  explicit From_Ruby(Arg* arg) : arg_(arg) { }
171
171
 
172
- Convertible is_convertible(VALUE value) { return Convertible::Cast; }
172
+ double is_convertible(VALUE value) { return Convertible::Exact; }
173
173
 
174
174
  Scalar convert(VALUE x) {
175
175
  if (FIXNUM_P(x)) {
data/ext/torch/tensor.cpp CHANGED
@@ -215,7 +215,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
215
215
  return s.str();
216
216
  })
217
217
  .define_method(
218
- "_device",
218
+ "device",
219
219
  [](Tensor& self) {
220
220
  return self.device();
221
221
  })
data/ext/torch/utils.h CHANGED
@@ -8,7 +8,7 @@
8
8
  #include <rice/stl.hpp>
9
9
 
10
10
  static_assert(
11
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 9,
11
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 10,
12
12
  "Incompatible LibTorch version"
13
13
  );
14
14
 
@@ -20,6 +20,7 @@ inline void handle_global_error(const torch::Error& ex) {
20
20
 
21
21
  // keep THP prefix for now to make it easier to compare code
22
22
 
23
+ extern VALUE THPDeviceClass;
23
24
  extern VALUE THPGeneratorClass;
24
25
  extern VALUE THPVariableClass;
25
26
 
@@ -48,6 +49,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
48
49
  return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
49
50
  }
50
51
 
52
+ inline bool THPDevice_Check(VALUE obj) {
53
+ return rb_obj_is_kind_of(obj, THPDeviceClass);
54
+ }
55
+
51
56
  inline bool THPGenerator_Check(VALUE obj) {
52
57
  return rb_obj_is_kind_of(obj, THPGeneratorClass);
53
58
  }
data/lib/torch/device.rb CHANGED
@@ -8,7 +8,6 @@ module Torch
8
8
  extra = ", index: #{index.inspect}" if index?
9
9
  "device(type: #{type.inspect}#{extra})"
10
10
  end
11
- alias_method :to_s, :inspect
12
11
 
13
12
  def ==(other)
14
13
  eql?(other)
data/lib/torch/tensor.rb CHANGED
@@ -115,7 +115,7 @@ module Torch
115
115
  if numel != 1
116
116
  raise Error, "only one element tensors can be converted to Ruby scalars"
117
117
  end
118
- to_a.first
118
+ to_a.flatten.first
119
119
  end
120
120
 
121
121
  def to_i
@@ -210,10 +210,5 @@ module Torch
210
210
  raise TypeError, "#{self.class} can't be coerced into #{other.class}"
211
211
  end
212
212
  end
213
-
214
- # TODO return Device instead of String in 0.19.0
215
- def device
216
- _device._str
217
- end
218
213
  end
219
214
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.22.2"
2
+ VERSION = "0.23.0"
3
3
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.22.2
4
+ version: 0.23.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
@@ -15,14 +15,14 @@ dependencies:
15
15
  requirements:
16
16
  - - ">="
17
17
  - !ruby/object:Gem::Version
18
- version: '4.7'
18
+ version: '4.8'
19
19
  type: :runtime
20
20
  prerelease: false
21
21
  version_requirements: !ruby/object:Gem::Requirement
22
22
  requirements:
23
23
  - - ">="
24
24
  - !ruby/object:Gem::Version
25
- version: '4.7'
25
+ version: '4.8'
26
26
  email: andrew@ankane.org
27
27
  executables: []
28
28
  extensions:
@@ -241,7 +241,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
241
241
  - !ruby/object:Gem::Version
242
242
  version: '0'
243
243
  requirements: []
244
- rubygems_version: 3.6.9
244
+ rubygems_version: 4.0.3
245
245
  specification_version: 4
246
246
  summary: Deep learning for Ruby, powered by LibTorch
247
247
  test_files: []