torch-rb 0.5.3 → 0.6.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -1
- data/README.md +5 -3
- data/codegen/generate_functions.rb +7 -5
- data/codegen/native_functions.yaml +2355 -1396
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +21 -0
- data/ext/torch/ext.cpp +17 -622
- data/ext/torch/extconf.rb +0 -1
- data/ext/torch/ivalue.cpp +134 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +1 -1
- data/ext/torch/ruby_arg_parser.h +21 -5
- data/ext/torch/templates.h +2 -2
- data/ext/torch/tensor.cpp +307 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +86 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -1
- data/lib/torch.rb +9 -10
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +6 -0
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +4 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +11 -88
data/ext/torch/torch.cpp
ADDED
@@ -0,0 +1,86 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Module.hpp>
|
4
|
+
|
5
|
+
#include "torch_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_torch(Rice::Module& m) {
|
10
|
+
m.add_handler<torch::Error>(handle_error);
|
11
|
+
add_torch_functions(m);
|
12
|
+
m.define_singleton_method(
|
13
|
+
"grad_enabled?",
|
14
|
+
*[]() {
|
15
|
+
return torch::GradMode::is_enabled();
|
16
|
+
})
|
17
|
+
.define_singleton_method(
|
18
|
+
"_set_grad_enabled",
|
19
|
+
*[](bool enabled) {
|
20
|
+
torch::GradMode::set_enabled(enabled);
|
21
|
+
})
|
22
|
+
.define_singleton_method(
|
23
|
+
"manual_seed",
|
24
|
+
*[](uint64_t seed) {
|
25
|
+
return torch::manual_seed(seed);
|
26
|
+
})
|
27
|
+
// config
|
28
|
+
.define_singleton_method(
|
29
|
+
"show_config",
|
30
|
+
*[] {
|
31
|
+
return torch::show_config();
|
32
|
+
})
|
33
|
+
.define_singleton_method(
|
34
|
+
"parallel_info",
|
35
|
+
*[] {
|
36
|
+
return torch::get_parallel_info();
|
37
|
+
})
|
38
|
+
// begin operations
|
39
|
+
.define_singleton_method(
|
40
|
+
"_save",
|
41
|
+
*[](const torch::IValue &value) {
|
42
|
+
auto v = torch::pickle_save(value);
|
43
|
+
std::string str(v.begin(), v.end());
|
44
|
+
return str;
|
45
|
+
})
|
46
|
+
.define_singleton_method(
|
47
|
+
"_load",
|
48
|
+
*[](const std::string &s) {
|
49
|
+
std::vector<char> v;
|
50
|
+
std::copy(s.begin(), s.end(), std::back_inserter(v));
|
51
|
+
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
52
|
+
return torch::pickle_load(v);
|
53
|
+
})
|
54
|
+
.define_singleton_method(
|
55
|
+
"_from_blob",
|
56
|
+
*[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
57
|
+
void *data = const_cast<char *>(s.c_str());
|
58
|
+
return torch::from_blob(data, size, options);
|
59
|
+
})
|
60
|
+
.define_singleton_method(
|
61
|
+
"_tensor",
|
62
|
+
*[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
63
|
+
auto dtype = options.dtype();
|
64
|
+
torch::Tensor t;
|
65
|
+
if (dtype == torch::kBool) {
|
66
|
+
std::vector<uint8_t> vec;
|
67
|
+
for (long i = 0; i < a.size(); i++) {
|
68
|
+
vec.push_back(from_ruby<bool>(a[i]));
|
69
|
+
}
|
70
|
+
t = torch::tensor(vec, options);
|
71
|
+
} else {
|
72
|
+
std::vector<float> vec;
|
73
|
+
for (long i = 0; i < a.size(); i++) {
|
74
|
+
vec.push_back(from_ruby<float>(a[i]));
|
75
|
+
}
|
76
|
+
// hack for requires_grad error
|
77
|
+
if (options.requires_grad()) {
|
78
|
+
t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
79
|
+
t.set_requires_grad(true);
|
80
|
+
} else {
|
81
|
+
t = torch::tensor(vec, options);
|
82
|
+
}
|
83
|
+
}
|
84
|
+
return t.reshape(size);
|
85
|
+
});
|
86
|
+
}
|
data/ext/torch/torch_functions.h
CHANGED
data/ext/torch/utils.h
CHANGED
@@ -1,13 +1,20 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
+
#include <rice/Exception.hpp>
|
3
4
|
#include <rice/Symbol.hpp>
|
4
5
|
|
6
|
+
// TODO find better place
|
7
|
+
// Rice handler for torch::Error: re-raises it in Ruby as a RuntimeError
// whose message is the torch error text without the C++ backtrace.
inline void handle_error(torch::Error const & ex)
{
  // Rice::Exception builds its message printf-style from (fmt, ...). Pass the
  // torch message as an argument, never as the format string, so a '%' inside
  // it cannot be interpreted as a conversion specifier.
  throw Rice::Exception(rb_eRuntimeError, "%s", ex.what_without_backtrace());
}
|
11
|
+
|
5
12
|
// keep THP prefix for now to make it easier to compare code
|
6
13
|
|
7
14
|
extern VALUE THPVariableClass;
|
8
15
|
|
9
16
|
// Interns `str` as a Ruby Symbol and hands back the raw VALUE handle.
inline VALUE THPUtils_internSymbol(const std::string& str) {
  const Rice::Symbol interned(str);
  return interned.value();
}
|
12
19
|
|
13
20
|
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
data/lib/torch.rb
CHANGED
@@ -337,25 +337,24 @@ module Torch
|
|
337
337
|
}
|
338
338
|
end
|
339
339
|
|
340
|
-
# Evaluates the block with autograd recording switched off, then restores the
# previous mode. Returns whatever the block returns.
def no_grad(&block)
  grad_enabled(false, &block)
end

# Evaluates the block with autograd recording switched on, then restores the
# previous mode. Returns whatever the block returns.
def enable_grad(&block)
  grad_enabled(true, &block)
end

# Evaluates the block with autograd recording set to +value+. Afterwards it
# restores the mode that was active on entry, even if the block raises.
# Returns the block's value; a block is required (the body yields).
def grad_enabled(value)
  mode_on_entry = grad_enabled?
  begin
    _set_grad_enabled(value)
    yield
  ensure
    _set_grad_enabled(mode_on_entry)
  end
end
alias_method :set_grad_enabled, :grad_enabled
|
359
358
|
|
360
359
|
def device(str)
|
361
360
|
Device.new(str)
|
data/lib/torch/nn/linear.rb
CHANGED
data/lib/torch/nn/module.rb
CHANGED
@@ -286,6 +286,12 @@ module Torch
|
|
286
286
|
named_buffers[name]
|
287
287
|
elsif named_modules.key?(name)
|
288
288
|
named_modules[name]
|
289
|
+
elsif method.end_with?("=") && named_modules.key?(method[0..-2])
|
290
|
+
if instance_variable_defined?("@#{method[0..-2]}")
|
291
|
+
instance_variable_set("@#{method[0..-2]}", *args)
|
292
|
+
else
|
293
|
+
raise NotImplementedYet
|
294
|
+
end
|
289
295
|
else
|
290
296
|
super
|
291
297
|
end
|
data/lib/torch/nn/parameter.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -135,6 +135,10 @@ module Torch
|
|
135
135
|
Torch.ones_like(Torch.empty(*size), **options)
|
136
136
|
end
|
137
137
|
|
138
|
+
# Setter form (`tensor.requires_grad = flag`); forwards the flag to the
# native _requires_grad!.
def requires_grad=(flag)
  _requires_grad!(flag)
end

# Bang form; forwards the flag to the native _requires_grad!, passing true
# when called without an argument.
def requires_grad!(flag = true)
  _requires_grad!(flag)
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.6.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-
|
11
|
+
date: 2021-03-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -24,92 +24,8 @@ dependencies:
|
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
26
|
version: '2.2'
|
27
|
-
- !ruby/object:Gem::Dependency
|
28
|
-
name: bundler
|
29
|
-
requirement: !ruby/object:Gem::Requirement
|
30
|
-
requirements:
|
31
|
-
- - ">="
|
32
|
-
- !ruby/object:Gem::Version
|
33
|
-
version: '0'
|
34
|
-
type: :development
|
35
|
-
prerelease: false
|
36
|
-
version_requirements: !ruby/object:Gem::Requirement
|
37
|
-
requirements:
|
38
|
-
- - ">="
|
39
|
-
- !ruby/object:Gem::Version
|
40
|
-
version: '0'
|
41
|
-
- !ruby/object:Gem::Dependency
|
42
|
-
name: rake
|
43
|
-
requirement: !ruby/object:Gem::Requirement
|
44
|
-
requirements:
|
45
|
-
- - ">="
|
46
|
-
- !ruby/object:Gem::Version
|
47
|
-
version: '0'
|
48
|
-
type: :development
|
49
|
-
prerelease: false
|
50
|
-
version_requirements: !ruby/object:Gem::Requirement
|
51
|
-
requirements:
|
52
|
-
- - ">="
|
53
|
-
- !ruby/object:Gem::Version
|
54
|
-
version: '0'
|
55
|
-
- !ruby/object:Gem::Dependency
|
56
|
-
name: rake-compiler
|
57
|
-
requirement: !ruby/object:Gem::Requirement
|
58
|
-
requirements:
|
59
|
-
- - ">="
|
60
|
-
- !ruby/object:Gem::Version
|
61
|
-
version: '0'
|
62
|
-
type: :development
|
63
|
-
prerelease: false
|
64
|
-
version_requirements: !ruby/object:Gem::Requirement
|
65
|
-
requirements:
|
66
|
-
- - ">="
|
67
|
-
- !ruby/object:Gem::Version
|
68
|
-
version: '0'
|
69
|
-
- !ruby/object:Gem::Dependency
|
70
|
-
name: minitest
|
71
|
-
requirement: !ruby/object:Gem::Requirement
|
72
|
-
requirements:
|
73
|
-
- - ">="
|
74
|
-
- !ruby/object:Gem::Version
|
75
|
-
version: '5'
|
76
|
-
type: :development
|
77
|
-
prerelease: false
|
78
|
-
version_requirements: !ruby/object:Gem::Requirement
|
79
|
-
requirements:
|
80
|
-
- - ">="
|
81
|
-
- !ruby/object:Gem::Version
|
82
|
-
version: '5'
|
83
|
-
- !ruby/object:Gem::Dependency
|
84
|
-
name: numo-narray
|
85
|
-
requirement: !ruby/object:Gem::Requirement
|
86
|
-
requirements:
|
87
|
-
- - ">="
|
88
|
-
- !ruby/object:Gem::Version
|
89
|
-
version: '0'
|
90
|
-
type: :development
|
91
|
-
prerelease: false
|
92
|
-
version_requirements: !ruby/object:Gem::Requirement
|
93
|
-
requirements:
|
94
|
-
- - ">="
|
95
|
-
- !ruby/object:Gem::Version
|
96
|
-
version: '0'
|
97
|
-
- !ruby/object:Gem::Dependency
|
98
|
-
name: torchvision
|
99
|
-
requirement: !ruby/object:Gem::Requirement
|
100
|
-
requirements:
|
101
|
-
- - ">="
|
102
|
-
- !ruby/object:Gem::Version
|
103
|
-
version: 0.1.1
|
104
|
-
type: :development
|
105
|
-
prerelease: false
|
106
|
-
version_requirements: !ruby/object:Gem::Requirement
|
107
|
-
requirements:
|
108
|
-
- - ">="
|
109
|
-
- !ruby/object:Gem::Version
|
110
|
-
version: 0.1.1
|
111
27
|
description:
|
112
|
-
email: andrew@
|
28
|
+
email: andrew@ankane.org
|
113
29
|
executables: []
|
114
30
|
extensions:
|
115
31
|
- ext/torch/extconf.rb
|
@@ -121,13 +37,20 @@ files:
|
|
121
37
|
- codegen/function.rb
|
122
38
|
- codegen/generate_functions.rb
|
123
39
|
- codegen/native_functions.yaml
|
40
|
+
- ext/torch/cuda.cpp
|
41
|
+
- ext/torch/device.cpp
|
124
42
|
- ext/torch/ext.cpp
|
125
43
|
- ext/torch/extconf.rb
|
44
|
+
- ext/torch/ivalue.cpp
|
45
|
+
- ext/torch/nn.cpp
|
126
46
|
- ext/torch/nn_functions.h
|
47
|
+
- ext/torch/random.cpp
|
127
48
|
- ext/torch/ruby_arg_parser.cpp
|
128
49
|
- ext/torch/ruby_arg_parser.h
|
129
50
|
- ext/torch/templates.h
|
51
|
+
- ext/torch/tensor.cpp
|
130
52
|
- ext/torch/tensor_functions.h
|
53
|
+
- ext/torch/torch.cpp
|
131
54
|
- ext/torch/torch_functions.h
|
132
55
|
- ext/torch/utils.h
|
133
56
|
- ext/torch/wrap_outputs.h
|
@@ -282,7 +205,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
282
205
|
requirements:
|
283
206
|
- - ">="
|
284
207
|
- !ruby/object:Gem::Version
|
285
|
-
version: '2.
|
208
|
+
version: '2.6'
|
286
209
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
287
210
|
requirements:
|
288
211
|
- - ">="
|