torch-rb 0.5.3 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -1
- data/README.md +5 -3
- data/codegen/generate_functions.rb +7 -5
- data/codegen/native_functions.yaml +2355 -1396
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +21 -0
- data/ext/torch/ext.cpp +17 -622
- data/ext/torch/extconf.rb +0 -1
- data/ext/torch/ivalue.cpp +134 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +1 -1
- data/ext/torch/ruby_arg_parser.h +21 -5
- data/ext/torch/templates.h +2 -2
- data/ext/torch/tensor.cpp +307 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +86 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -1
- data/lib/torch.rb +9 -10
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +6 -0
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +4 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +11 -88
data/ext/torch/torch.cpp
ADDED
@@ -0,0 +1,86 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Module.hpp>
|
4
|
+
|
5
|
+
#include "torch_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// Registers the top-level Torch singleton methods on module `m`: grad-mode
// control, seeding, build/config info, pickle-based save/load, and tensor
// construction from Ruby data.
void init_torch(Rice::Module& m) {
  // Any torch::Error escaping a binding on this module is translated by handle_error.
  m.add_handler<torch::Error>(handle_error);
  add_torch_functions(m);
  m.define_singleton_method(
      "grad_enabled?",
      *[]() {
        return torch::GradMode::is_enabled();
      })
    .define_singleton_method(
      "_set_grad_enabled",
      *[](bool enabled) {
        torch::GradMode::set_enabled(enabled);
      })
    .define_singleton_method(
      "manual_seed",
      *[](uint64_t seed) {
        return torch::manual_seed(seed);
      })
    // config
    .define_singleton_method(
      "show_config",
      *[] {
        return torch::show_config();
      })
    .define_singleton_method(
      "parallel_info",
      *[] {
        return torch::get_parallel_info();
      })
    // begin operations
    .define_singleton_method(
      "_save",
      *[](const torch::IValue &value) {
        // Serialize with torch's pickle format and return the raw bytes as a
        // std::string (which may legitimately contain NUL bytes).
        auto v = torch::pickle_save(value);
        return std::string(v.begin(), v.end());
      })
    .define_singleton_method(
      "_load",
      *[](const std::string &s) {
        // pickle_load wants a byte vector, so copy the string's bytes into one.
        std::vector<char> v(s.begin(), s.end());
        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
        return torch::pickle_load(v);
      })
    .define_singleton_method(
      "_from_blob",
      *[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
        // torch::from_blob neither copies nor takes ownership of `data`: the
        // returned tensor aliases the string's buffer. NOTE(review): the Ruby
        // caller is assumed to keep that string alive (and unmodified) for the
        // tensor's lifetime.
        void *data = const_cast<char *>(s.c_str());
        return torch::from_blob(data, size, options);
      })
    .define_singleton_method(
      "_tensor",
      *[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
        // `a` supplies the elements in flat order; the result is reshaped to
        // `size` at the end. NOTE(review): flattening presumably happens on the
        // Ruby side.
        auto dtype = options.dtype();
        torch::Tensor t;
        if (dtype == torch::kBool) {
          // uint8_t rather than bool, presumably because std::vector<bool> is
          // bit-packed and cannot back a contiguous ArrayRef.
          std::vector<uint8_t> vec;
          vec.reserve(a.size());
          for (long i = 0; i < a.size(); i++) {
            vec.push_back(from_ruby<bool>(a[i]));
          }
          t = torch::tensor(vec, options);
        } else {
          // Collect through double, not float. A float intermediate rounds every
          // element to 24 bits of mantissa, which corrupts float64 tensors
          // (e.g. 0.1) and integer dtypes above 2**24. torch::tensor converts
          // the double data to the dtype requested in `options`, or to the
          // default dtype (exactly as it did for a float source) when none is set.
          std::vector<double> vec;
          vec.reserve(a.size());
          for (long i = 0; i < a.size(); i++) {
            vec.push_back(from_ruby<double>(a[i]));
          }
          // hack for requires_grad error
          if (options.requires_grad()) {
            t = torch::tensor(vec, options.requires_grad(c10::nullopt));
            t.set_requires_grad(true);
          } else {
            t = torch::tensor(vec, options);
          }
        }
        return t.reshape(size);
      });
}
|
data/ext/torch/torch_functions.h
CHANGED
data/ext/torch/utils.h
CHANGED
@@ -1,13 +1,20 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
+
#include <rice/Exception.hpp>
|
3
4
|
#include <rice/Symbol.hpp>
|
4
5
|
|
6
|
+
// TODO find better place
|
7
|
+
// Rethrows a libtorch error as a Ruby RuntimeError carrying the error message
// without the C++ backtrace (installed via add_handler<torch::Error> in init_torch).
//
// Rice::Exception's constructor is printf-style (format string + varargs), so
// the torch message is passed as an argument to "%s" rather than as the format
// itself; otherwise any '%' in the message would be read as a conversion
// specifier, which is undefined behaviour.
inline void handle_error(torch::Error const & ex)
{
  throw Rice::Exception(rb_eRuntimeError, "%s", ex.what_without_backtrace());
}
|
11
|
+
|
5
12
|
// keep THP prefix for now to make it easier to compare code
|
6
13
|
|
7
14
|
extern VALUE THPVariableClass;
|
8
15
|
|
9
16
|
// Interns `str` as a Ruby Symbol and hands back the underlying Ruby VALUE.
inline VALUE THPUtils_internSymbol(const std::string& str) {
  const Rice::Symbol symbol(str);
  return symbol.value();
}
|
12
19
|
|
13
20
|
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
data/lib/torch.rb
CHANGED
@@ -337,25 +337,24 @@ module Torch
|
|
337
337
|
}
|
338
338
|
end
|
339
339
|
|
340
|
-
# Runs the given block with autograd recording switched off.
def no_grad(&block)
  grad_enabled(false, &block)
end

# Runs the given block with autograd recording switched on.
def enable_grad(&block)
  grad_enabled(true, &block)
end

# Sets grad mode to `value` for the duration of the block and restores the
# mode that was active beforehand, even if the block raises. Returns the
# block's value. A block is required: without one, `yield` raises
# LocalJumpError (the previous mode is still restored by `ensure`).
def grad_enabled(value)
  was_enabled = grad_enabled?
  begin
    _set_grad_enabled(value)
    yield
  ensure
    _set_grad_enabled(was_enabled)
  end
end
alias_method :set_grad_enabled, :grad_enabled
|
359
358
|
|
360
359
|
def device(str)
|
361
360
|
Device.new(str)
|
data/lib/torch/nn/linear.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -286,6 +286,12 @@ module Torch
|
|
286
286
|
named_buffers[name]
|
287
287
|
elsif named_modules.key?(name)
|
288
288
|
named_modules[name]
|
289
|
+
elsif method.end_with?("=") && named_modules.key?(method[0..-2])
|
290
|
+
if instance_variable_defined?("@#{method[0..-2]}")
|
291
|
+
instance_variable_set("@#{method[0..-2]}", *args)
|
292
|
+
else
|
293
|
+
raise NotImplementedYet
|
294
|
+
end
|
289
295
|
else
|
290
296
|
super
|
291
297
|
end
|
data/lib/torch/nn/parameter.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -135,6 +135,10 @@ module Torch
|
|
135
135
|
Torch.ones_like(Torch.empty(*size), **options)
|
136
136
|
end
|
137
137
|
|
138
|
+
# Writer form (`tensor.requires_grad = flag`): forwards `flag` to the native
# `_requires_grad!` binding. With assignment syntax, Ruby makes the
# expression evaluate to `flag` itself.
def requires_grad=(flag)
  _requires_grad!(flag)
end

# Bang form: turns gradient tracking on by default, or sets it to `flag`.
# Returns whatever the native `_requires_grad!` binding returns.
def requires_grad!(flag = true)
  _requires_grad!(flag)
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.5.3
|
4
|
+
version: 0.6.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-
|
11
|
+
date: 2021-03-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -24,92 +24,8 @@ dependencies:
|
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
26
|
version: '2.2'
|
27
|
-
- !ruby/object:Gem::Dependency
|
28
|
-
name: bundler
|
29
|
-
requirement: !ruby/object:Gem::Requirement
|
30
|
-
requirements:
|
31
|
-
- - ">="
|
32
|
-
- !ruby/object:Gem::Version
|
33
|
-
version: '0'
|
34
|
-
type: :development
|
35
|
-
prerelease: false
|
36
|
-
version_requirements: !ruby/object:Gem::Requirement
|
37
|
-
requirements:
|
38
|
-
- - ">="
|
39
|
-
- !ruby/object:Gem::Version
|
40
|
-
version: '0'
|
41
|
-
- !ruby/object:Gem::Dependency
|
42
|
-
name: rake
|
43
|
-
requirement: !ruby/object:Gem::Requirement
|
44
|
-
requirements:
|
45
|
-
- - ">="
|
46
|
-
- !ruby/object:Gem::Version
|
47
|
-
version: '0'
|
48
|
-
type: :development
|
49
|
-
prerelease: false
|
50
|
-
version_requirements: !ruby/object:Gem::Requirement
|
51
|
-
requirements:
|
52
|
-
- - ">="
|
53
|
-
- !ruby/object:Gem::Version
|
54
|
-
version: '0'
|
55
|
-
- !ruby/object:Gem::Dependency
|
56
|
-
name: rake-compiler
|
57
|
-
requirement: !ruby/object:Gem::Requirement
|
58
|
-
requirements:
|
59
|
-
- - ">="
|
60
|
-
- !ruby/object:Gem::Version
|
61
|
-
version: '0'
|
62
|
-
type: :development
|
63
|
-
prerelease: false
|
64
|
-
version_requirements: !ruby/object:Gem::Requirement
|
65
|
-
requirements:
|
66
|
-
- - ">="
|
67
|
-
- !ruby/object:Gem::Version
|
68
|
-
version: '0'
|
69
|
-
- !ruby/object:Gem::Dependency
|
70
|
-
name: minitest
|
71
|
-
requirement: !ruby/object:Gem::Requirement
|
72
|
-
requirements:
|
73
|
-
- - ">="
|
74
|
-
- !ruby/object:Gem::Version
|
75
|
-
version: '5'
|
76
|
-
type: :development
|
77
|
-
prerelease: false
|
78
|
-
version_requirements: !ruby/object:Gem::Requirement
|
79
|
-
requirements:
|
80
|
-
- - ">="
|
81
|
-
- !ruby/object:Gem::Version
|
82
|
-
version: '5'
|
83
|
-
- !ruby/object:Gem::Dependency
|
84
|
-
name: numo-narray
|
85
|
-
requirement: !ruby/object:Gem::Requirement
|
86
|
-
requirements:
|
87
|
-
- - ">="
|
88
|
-
- !ruby/object:Gem::Version
|
89
|
-
version: '0'
|
90
|
-
type: :development
|
91
|
-
prerelease: false
|
92
|
-
version_requirements: !ruby/object:Gem::Requirement
|
93
|
-
requirements:
|
94
|
-
- - ">="
|
95
|
-
- !ruby/object:Gem::Version
|
96
|
-
version: '0'
|
97
|
-
- !ruby/object:Gem::Dependency
|
98
|
-
name: torchvision
|
99
|
-
requirement: !ruby/object:Gem::Requirement
|
100
|
-
requirements:
|
101
|
-
- - ">="
|
102
|
-
- !ruby/object:Gem::Version
|
103
|
-
version: 0.1.1
|
104
|
-
type: :development
|
105
|
-
prerelease: false
|
106
|
-
version_requirements: !ruby/object:Gem::Requirement
|
107
|
-
requirements:
|
108
|
-
- - ">="
|
109
|
-
- !ruby/object:Gem::Version
|
110
|
-
version: 0.1.1
|
111
27
|
description:
|
112
|
-
email: andrew@
|
28
|
+
email: andrew@ankane.org
|
113
29
|
executables: []
|
114
30
|
extensions:
|
115
31
|
- ext/torch/extconf.rb
|
@@ -121,13 +37,20 @@ files:
|
|
121
37
|
- codegen/function.rb
|
122
38
|
- codegen/generate_functions.rb
|
123
39
|
- codegen/native_functions.yaml
|
40
|
+
- ext/torch/cuda.cpp
|
41
|
+
- ext/torch/device.cpp
|
124
42
|
- ext/torch/ext.cpp
|
125
43
|
- ext/torch/extconf.rb
|
44
|
+
- ext/torch/ivalue.cpp
|
45
|
+
- ext/torch/nn.cpp
|
126
46
|
- ext/torch/nn_functions.h
|
47
|
+
- ext/torch/random.cpp
|
127
48
|
- ext/torch/ruby_arg_parser.cpp
|
128
49
|
- ext/torch/ruby_arg_parser.h
|
129
50
|
- ext/torch/templates.h
|
51
|
+
- ext/torch/tensor.cpp
|
130
52
|
- ext/torch/tensor_functions.h
|
53
|
+
- ext/torch/torch.cpp
|
131
54
|
- ext/torch/torch_functions.h
|
132
55
|
- ext/torch/utils.h
|
133
56
|
- ext/torch/wrap_outputs.h
|
@@ -282,7 +205,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
282
205
|
requirements:
|
283
206
|
- - ">="
|
284
207
|
- !ruby/object:Gem::Version
|
285
|
-
version: '2.
|
208
|
+
version: '2.6'
|
286
209
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
287
210
|
requirements:
|
288
211
|
- - ">="
|