torch-rb 0.22.2 → 0.23.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -0
- data/README.md +3 -2
- data/codegen/native_functions.yaml +259 -103
- data/ext/torch/device.cpp +5 -2
- data/ext/torch/ruby_arg_parser.cpp +2 -1
- data/ext/torch/ruby_arg_parser.h +5 -2
- data/ext/torch/templates.h +4 -4
- data/ext/torch/tensor.cpp +1 -1
- data/ext/torch/utils.h +6 -1
- data/lib/torch/device.rb +0 -1
- data/lib/torch/tensor.rb +1 -6
- data/lib/torch/version.rb +1 -1
- metadata +4 -4
data/ext/torch/device.cpp
CHANGED
|
@@ -8,7 +8,8 @@
|
|
|
8
8
|
#include "utils.h"
|
|
9
9
|
|
|
10
10
|
void init_device(Rice::Module& m) {
|
|
11
|
-
Rice::define_class_under<torch::Device>(m, "Device")
|
|
11
|
+
auto rb_cDevice = Rice::define_class_under<torch::Device>(m, "Device");
|
|
12
|
+
rb_cDevice
|
|
12
13
|
.define_constructor(Rice::Constructor<torch::Device, const std::string&>())
|
|
13
14
|
.define_method(
|
|
14
15
|
"_index",
|
|
@@ -28,8 +29,10 @@ void init_device(Rice::Module& m) {
|
|
|
28
29
|
return s.str();
|
|
29
30
|
})
|
|
30
31
|
.define_method(
|
|
31
|
-
"
|
|
32
|
+
"to_s",
|
|
32
33
|
[](torch::Device& self) {
|
|
33
34
|
return self.str();
|
|
34
35
|
});
|
|
36
|
+
|
|
37
|
+
THPDeviceClass = rb_cDevice.value();
|
|
35
38
|
}
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
#include "ruby_arg_parser.h"
|
|
9
9
|
|
|
10
|
+
VALUE THPDeviceClass = Qnil;
|
|
10
11
|
VALUE THPGeneratorClass = Qnil;
|
|
11
12
|
VALUE THPVariableClass = Qnil;
|
|
12
13
|
|
|
@@ -257,7 +258,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool {
|
|
|
257
258
|
case ParameterType::LAYOUT: return SYMBOL_P(obj);
|
|
258
259
|
case ParameterType::MEMORY_FORMAT: return false; // return THPMemoryFormat_Check(obj);
|
|
259
260
|
case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
|
|
260
|
-
case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING)
|
|
261
|
+
case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING) || THPDevice_Check(obj);
|
|
261
262
|
case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
|
|
262
263
|
case ParameterType::SYM_INT: return is_int_or_symint(obj);
|
|
263
264
|
case ParameterType::SYM_INT_LIST: return is_int_or_symint_list(obj, size);
|
data/ext/torch/ruby_arg_parser.h
CHANGED
|
@@ -357,8 +357,11 @@ inline at::Device RubyArgs::device(int i) {
|
|
|
357
357
|
if (NIL_P(args[i])) {
|
|
358
358
|
return at::Device("cpu");
|
|
359
359
|
}
|
|
360
|
-
|
|
361
|
-
|
|
360
|
+
if (RB_TYPE_P(args[i], T_STRING)) {
|
|
361
|
+
const std::string &device_str = THPUtils_unpackString(args[i]);
|
|
362
|
+
return at::Device(device_str);
|
|
363
|
+
}
|
|
364
|
+
return Rice::detail::From_Ruby<at::Device>().convert(args[i]);
|
|
362
365
|
}
|
|
363
366
|
|
|
364
367
|
inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
|
data/ext/torch/templates.h
CHANGED
|
@@ -69,7 +69,7 @@ namespace Rice::detail {
|
|
|
69
69
|
|
|
70
70
|
explicit From_Ruby(Arg* arg) : arg_(arg) { }
|
|
71
71
|
|
|
72
|
-
|
|
72
|
+
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
73
73
|
|
|
74
74
|
c10::complex<T> convert(VALUE x) {
|
|
75
75
|
VALUE real = rb_funcall(x, rb_intern("real"), 0);
|
|
@@ -93,7 +93,7 @@ namespace Rice::detail {
|
|
|
93
93
|
|
|
94
94
|
explicit From_Ruby(Arg* arg) : arg_(arg) { }
|
|
95
95
|
|
|
96
|
-
|
|
96
|
+
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
97
97
|
|
|
98
98
|
FanModeType convert(VALUE x) {
|
|
99
99
|
auto s = String(x).str();
|
|
@@ -122,7 +122,7 @@ namespace Rice::detail {
|
|
|
122
122
|
|
|
123
123
|
explicit From_Ruby(Arg* arg) : arg_(arg) { }
|
|
124
124
|
|
|
125
|
-
|
|
125
|
+
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
126
126
|
|
|
127
127
|
NonlinearityType convert(VALUE x) {
|
|
128
128
|
auto s = String(x).str();
|
|
@@ -169,7 +169,7 @@ namespace Rice::detail {
|
|
|
169
169
|
|
|
170
170
|
explicit From_Ruby(Arg* arg) : arg_(arg) { }
|
|
171
171
|
|
|
172
|
-
|
|
172
|
+
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
173
173
|
|
|
174
174
|
Scalar convert(VALUE x) {
|
|
175
175
|
if (FIXNUM_P(x)) {
|
data/ext/torch/tensor.cpp
CHANGED
data/ext/torch/utils.h
CHANGED
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
#include <rice/stl.hpp>
|
|
9
9
|
|
|
10
10
|
static_assert(
|
|
11
|
-
TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR ==
|
|
11
|
+
TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 10,
|
|
12
12
|
"Incompatible LibTorch version"
|
|
13
13
|
);
|
|
14
14
|
|
|
@@ -20,6 +20,7 @@ inline void handle_global_error(const torch::Error& ex) {
|
|
|
20
20
|
|
|
21
21
|
// keep THP prefix for now to make it easier to compare code
|
|
22
22
|
|
|
23
|
+
extern VALUE THPDeviceClass;
|
|
23
24
|
extern VALUE THPGeneratorClass;
|
|
24
25
|
extern VALUE THPVariableClass;
|
|
25
26
|
|
|
@@ -48,6 +49,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
|
|
|
48
49
|
return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
|
|
49
50
|
}
|
|
50
51
|
|
|
52
|
+
inline bool THPDevice_Check(VALUE obj) {
|
|
53
|
+
return rb_obj_is_kind_of(obj, THPDeviceClass);
|
|
54
|
+
}
|
|
55
|
+
|
|
51
56
|
inline bool THPGenerator_Check(VALUE obj) {
|
|
52
57
|
return rb_obj_is_kind_of(obj, THPGeneratorClass);
|
|
53
58
|
}
|
data/lib/torch/device.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
|
@@ -115,7 +115,7 @@ module Torch
|
|
|
115
115
|
if numel != 1
|
|
116
116
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
|
117
117
|
end
|
|
118
|
-
to_a.first
|
|
118
|
+
to_a.flatten.first
|
|
119
119
|
end
|
|
120
120
|
|
|
121
121
|
def to_i
|
|
@@ -210,10 +210,5 @@ module Torch
|
|
|
210
210
|
raise TypeError, "#{self.class} can't be coerced into #{other.class}"
|
|
211
211
|
end
|
|
212
212
|
end
|
|
213
|
-
|
|
214
|
-
# TODO return Device instead of String in 0.19.0
|
|
215
|
-
def device
|
|
216
|
-
_device._str
|
|
217
|
-
end
|
|
218
213
|
end
|
|
219
214
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: torch-rb
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.
|
|
4
|
+
version: 0.23.0
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Andrew Kane
|
|
@@ -15,14 +15,14 @@ dependencies:
|
|
|
15
15
|
requirements:
|
|
16
16
|
- - ">="
|
|
17
17
|
- !ruby/object:Gem::Version
|
|
18
|
-
version: '4.
|
|
18
|
+
version: '4.8'
|
|
19
19
|
type: :runtime
|
|
20
20
|
prerelease: false
|
|
21
21
|
version_requirements: !ruby/object:Gem::Requirement
|
|
22
22
|
requirements:
|
|
23
23
|
- - ">="
|
|
24
24
|
- !ruby/object:Gem::Version
|
|
25
|
-
version: '4.
|
|
25
|
+
version: '4.8'
|
|
26
26
|
email: andrew@ankane.org
|
|
27
27
|
executables: []
|
|
28
28
|
extensions:
|
|
@@ -241,7 +241,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
|
241
241
|
- !ruby/object:Gem::Version
|
|
242
242
|
version: '0'
|
|
243
243
|
requirements: []
|
|
244
|
-
rubygems_version:
|
|
244
|
+
rubygems_version: 4.0.3
|
|
245
245
|
specification_version: 4
|
|
246
246
|
summary: Deep learning for Ruby, powered by LibTorch
|
|
247
247
|
test_files: []
|