torch-rb 0.17.1 → 0.19.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/README.md +15 -18
- data/codegen/generate_functions.rb +3 -1
- data/codegen/native_functions.yaml +351 -178
- data/ext/torch/device.cpp +6 -1
- data/ext/torch/ext.cpp +1 -1
- data/ext/torch/tensor.cpp +2 -4
- data/ext/torch/torch.cpp +7 -12
- data/ext/torch/utils.h +1 -1
- data/lib/torch/device.rb +25 -0
- data/lib/torch/tensor.rb +6 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +6 -1
- metadata +6 -9
data/ext/torch/device.cpp
CHANGED
@@ -8,7 +8,7 @@ void init_device(Rice::Module& m) {
|
|
8
8
|
Rice::define_class_under<torch::Device>(m, "Device")
|
9
9
|
.define_constructor(Rice::Constructor<torch::Device, const std::string&>())
|
10
10
|
.define_method(
|
11
|
-
"index",
|
11
|
+
"_index",
|
12
12
|
[](torch::Device& self) {
|
13
13
|
return self.index();
|
14
14
|
})
|
@@ -23,5 +23,10 @@ void init_device(Rice::Module& m) {
|
|
23
23
|
std::stringstream s;
|
24
24
|
s << self.type();
|
25
25
|
return s.str();
|
26
|
+
})
|
27
|
+
.define_method(
|
28
|
+
"_str",
|
29
|
+
[](torch::Device& self) {
|
30
|
+
return self.str();
|
26
31
|
});
|
27
32
|
}
|
data/ext/torch/ext.cpp
CHANGED
@@ -31,6 +31,7 @@ void Init_ext()
|
|
31
31
|
|
32
32
|
// keep this order
|
33
33
|
init_torch(m);
|
34
|
+
init_device(m);
|
34
35
|
init_tensor(m, rb_cTensor, rb_cTensorOptions);
|
35
36
|
init_nn(m);
|
36
37
|
init_fft(m);
|
@@ -39,7 +40,6 @@ void Init_ext()
|
|
39
40
|
|
40
41
|
init_backends(m);
|
41
42
|
init_cuda(m);
|
42
|
-
init_device(m);
|
43
43
|
init_generator(m, rb_cGenerator);
|
44
44
|
init_ivalue(m, rb_cIValue);
|
45
45
|
init_random(m);
|
data/ext/torch/tensor.cpp
CHANGED
@@ -212,11 +212,9 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
212
212
|
return s.str();
|
213
213
|
})
|
214
214
|
.define_method(
|
215
|
-
"device",
|
215
|
+
"_device",
|
216
216
|
[](Tensor& self) {
|
217
|
-
std::stringstream s;
218
|
-
s << self.device();
|
219
|
-
return s.str();
|
217
|
+
return self.device();
|
220
218
|
})
|
221
219
|
.define_method(
|
222
220
|
"_data_str",
|
data/ext/torch/torch.cpp
CHANGED
@@ -9,19 +9,14 @@
|
|
9
9
|
#include "utils.h"
|
10
10
|
|
11
11
|
template<typename T>
|
12
|
-
torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
12
|
+
torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
|
13
13
|
std::vector<T> vec;
|
14
|
+
vec.reserve(a.size());
|
14
15
|
for (long i = 0; i < a.size(); i++) {
|
15
16
|
vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value()));
|
16
17
|
}
|
17
18
|
|
18
|
-
|
19
|
-
auto requires_grad = options.requires_grad();
|
20
|
-
torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
21
|
-
if (requires_grad) {
|
22
|
-
t.set_requires_grad(true);
|
23
|
-
}
|
24
|
-
|
19
|
+
torch::Tensor t = torch::tensor(vec, options);
|
25
20
|
return t.reshape(size);
|
26
21
|
}
|
27
22
|
|
@@ -46,12 +41,12 @@ void init_torch(Rice::Module& m) {
|
|
46
41
|
// config
|
47
42
|
.define_singleton_function(
|
48
43
|
"show_config",
|
49
|
-
[] {
|
44
|
+
[]() {
|
50
45
|
return torch::show_config();
|
51
46
|
})
|
52
47
|
.define_singleton_function(
|
53
48
|
"parallel_info",
|
54
|
-
[] {
|
49
|
+
[]() {
|
55
50
|
return torch::get_parallel_info();
|
56
51
|
})
|
57
52
|
// begin operations
|
@@ -74,13 +69,13 @@ void init_torch(Rice::Module& m) {
|
|
74
69
|
})
|
75
70
|
.define_singleton_function(
|
76
71
|
"_from_blob",
|
77
|
-
[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
72
|
+
[](Rice::String s, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
|
78
73
|
void *data = const_cast<char *>(s.c_str());
|
79
74
|
return torch::from_blob(data, size, options);
|
80
75
|
})
|
81
76
|
.define_singleton_function(
|
82
77
|
"_tensor",
|
83
|
-
[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
78
|
+
[](Rice::Array a, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
|
84
79
|
auto dtype = options.dtype();
|
85
80
|
if (dtype == torch::kByte) {
|
86
81
|
return make_tensor<uint8_t>(a, size, options);
|
data/ext/torch/utils.h
CHANGED
data/lib/torch/device.rb
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
|
2
|
+
class Device
|
3
|
+
def index
|
4
|
+
index? ? _index : nil
|
5
|
+
end
|
6
|
+
|
7
|
+
def inspect
|
8
|
+
extra = ", index: #{index.inspect}" if index?
|
9
|
+
"device(type: #{type.inspect}#{extra})"
|
10
|
+
end
|
11
|
+
alias_method :to_s, :inspect
|
12
|
+
|
13
|
+
def ==(other)
|
14
|
+
eql?(other)
|
15
|
+
end
|
16
|
+
|
17
|
+
def eql?(other)
|
18
|
+
other.is_a?(Device) && other.type == type && other.index == index
|
19
|
+
end
|
20
|
+
|
21
|
+
def hash
|
22
|
+
[type, index].hash
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -24,6 +24,7 @@ module Torch
|
|
24
24
|
alias_method :^, :logical_xor
|
25
25
|
alias_method :<<, :__lshift__
|
26
26
|
alias_method :>>, :__rshift__
|
27
|
+
alias_method :~, :bitwise_not
|
27
28
|
|
28
29
|
def self.new(*args)
|
29
30
|
FloatTensor.new(*args)
|
@@ -208,5 +209,10 @@ module Torch
|
|
208
209
|
raise TypeError, "#{self.class} can't be coerced into #{other.class}"
|
209
210
|
end
|
210
211
|
end
|
212
|
+
|
213
|
+
# TODO return Device instead of String in 0.19.0
|
214
|
+
def device
|
215
|
+
_device._str
|
216
|
+
end
|
211
217
|
end
|
212
218
|
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -8,6 +8,7 @@ require "set"
|
|
8
8
|
require "tmpdir"
|
9
9
|
|
10
10
|
# modules
|
11
|
+
require_relative "torch/device"
|
11
12
|
require_relative "torch/inspector"
|
12
13
|
require_relative "torch/tensor"
|
13
14
|
require_relative "torch/version"
|
@@ -382,7 +383,11 @@ module Torch
|
|
382
383
|
alias_method :set_grad_enabled, :grad_enabled
|
383
384
|
|
384
385
|
def device(str)
|
385
|
-
|
386
|
+
if str.is_a?(Device)
|
387
|
+
str
|
388
|
+
else
|
389
|
+
Device.new(str)
|
390
|
+
end
|
386
391
|
end
|
387
392
|
|
388
393
|
def save(obj, f)
|
metadata
CHANGED
@@ -1,14 +1,13 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.17.1
|
4
|
+
version: 0.19.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
|
-
autorequire:
|
9
8
|
bindir: bin
|
10
9
|
cert_chain: []
|
11
|
-
date:
|
10
|
+
date: 2025-01-30 00:00:00.000000000 Z
|
12
11
|
dependencies:
|
13
12
|
- !ruby/object:Gem::Dependency
|
14
13
|
name: rice
|
@@ -16,15 +15,14 @@ dependencies:
|
|
16
15
|
requirements:
|
17
16
|
- - ">="
|
18
17
|
- !ruby/object:Gem::Version
|
19
|
-
version:
|
18
|
+
version: 4.3.3
|
20
19
|
type: :runtime
|
21
20
|
prerelease: false
|
22
21
|
version_requirements: !ruby/object:Gem::Requirement
|
23
22
|
requirements:
|
24
23
|
- - ">="
|
25
24
|
- !ruby/object:Gem::Version
|
26
|
-
version:
|
27
|
-
description:
|
25
|
+
version: 4.3.3
|
28
26
|
email: andrew@ankane.org
|
29
27
|
executables: []
|
30
28
|
extensions:
|
@@ -65,6 +63,7 @@ files:
|
|
65
63
|
- ext/torch/wrap_outputs.h
|
66
64
|
- lib/torch-rb.rb
|
67
65
|
- lib/torch.rb
|
66
|
+
- lib/torch/device.rb
|
68
67
|
- lib/torch/hub.rb
|
69
68
|
- lib/torch/inspector.rb
|
70
69
|
- lib/torch/nn/adaptive_avg_pool1d.rb
|
@@ -224,7 +223,6 @@ homepage: https://github.com/ankane/torch.rb
|
|
224
223
|
licenses:
|
225
224
|
- BSD-3-Clause
|
226
225
|
metadata: {}
|
227
|
-
post_install_message:
|
228
226
|
rdoc_options: []
|
229
227
|
require_paths:
|
230
228
|
- lib
|
@@ -239,8 +237,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
239
237
|
- !ruby/object:Gem::Version
|
240
238
|
version: '0'
|
241
239
|
requirements: []
|
242
|
-
rubygems_version: 3.
|
243
|
-
signing_key:
|
240
|
+
rubygems_version: 3.6.2
|
244
241
|
specification_version: 4
|
245
242
|
summary: Deep learning for Ruby, powered by LibTorch
|
246
243
|
test_files: []
|