torch-rb 0.5.3 → 0.6.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -1
- data/README.md +5 -3
- data/codegen/generate_functions.rb +7 -5
- data/codegen/native_functions.yaml +2355 -1396
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +21 -0
- data/ext/torch/ext.cpp +17 -622
- data/ext/torch/extconf.rb +0 -1
- data/ext/torch/ivalue.cpp +134 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +1 -1
- data/ext/torch/ruby_arg_parser.h +21 -5
- data/ext/torch/templates.h +2 -2
- data/ext/torch/tensor.cpp +307 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +86 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -1
- data/lib/torch.rb +9 -10
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +6 -0
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +4 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +11 -88
data/ext/torch/extconf.rb
CHANGED
@@ -11,7 +11,6 @@ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
|
11
11
|
|
12
12
|
# check omp first
|
13
13
|
if have_library("omp") || have_library("gomp")
|
14
|
-
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
15
14
|
$CXXFLAGS += " -Xclang" if apple_clang
|
16
15
|
$CXXFLAGS += " -fopenmp"
|
17
16
|
end
|
@@ -0,0 +1,134 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Array.hpp>
|
4
|
+
#include <rice/Constructor.hpp>
|
5
|
+
#include <rice/Hash.hpp>
|
6
|
+
#include <rice/Module.hpp>
|
7
|
+
#include <rice/String.hpp>
|
8
|
+
|
9
|
+
#include "utils.h"
|
10
|
+
|
11
|
+
// Defines the IValue class under module `m` as a wrapper around torch::IValue:
// type predicates ("bool?", "tensor?", ...), to_* readers and from_* singleton
// constructors.
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
void init_ivalue(Rice::Module& m) {
  auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue");

  // handle_error is registered for torch::Error before any method is defined
  rb_cIValue.add_handler<torch::Error>(handle_error);
  rb_cIValue.define_constructor(Rice::Constructor<torch::IValue>());

  // type predicates, bound directly to the torch::IValue member functions
  rb_cIValue
    .define_method("bool?", &torch::IValue::isBool)
    .define_method("bool_list?", &torch::IValue::isBoolList)
    .define_method("capsule?", &torch::IValue::isCapsule)
    .define_method("custom_class?", &torch::IValue::isCustomClass)
    .define_method("device?", &torch::IValue::isDevice)
    .define_method("double?", &torch::IValue::isDouble)
    .define_method("double_list?", &torch::IValue::isDoubleList)
    .define_method("future?", &torch::IValue::isFuture)
    // .define_method("generator?", &torch::IValue::isGenerator)
    .define_method("generic_dict?", &torch::IValue::isGenericDict)
    .define_method("list?", &torch::IValue::isList)
    .define_method("int?", &torch::IValue::isInt)
    .define_method("int_list?", &torch::IValue::isIntList)
    .define_method("module?", &torch::IValue::isModule)
    .define_method("none?", &torch::IValue::isNone)
    .define_method("object?", &torch::IValue::isObject)
    .define_method("ptr_type?", &torch::IValue::isPtrType)
    .define_method("py_object?", &torch::IValue::isPyObject)
    .define_method("r_ref?", &torch::IValue::isRRef)
    .define_method("scalar?", &torch::IValue::isScalar)
    .define_method("string?", &torch::IValue::isString)
    .define_method("tensor?", &torch::IValue::isTensor)
    .define_method("tensor_list?", &torch::IValue::isTensorList)
    .define_method("tuple?", &torch::IValue::isTuple);

  // readers: unwrap the held value
  rb_cIValue
    .define_method("to_bool", *[](torch::IValue& self) { return self.toBool(); })
    .define_method("to_double", *[](torch::IValue& self) { return self.toDouble(); })
    .define_method("to_int", *[](torch::IValue& self) { return self.toInt(); })
    .define_method(
      "to_list",
      *[](torch::IValue& self) {
        // every element is copied into its own IValue and wrapped for Ruby
        Rice::Array result;
        for (auto& item : self.toListRef()) {
          result.push(to_ruby<torch::IValue>(torch::IValue{item}));
        }
        return result;
      })
    .define_method("to_string_ref", *[](torch::IValue& self) { return self.toStringRef(); })
    .define_method("to_tensor", *[](torch::IValue& self) { return self.toTensor(); })
    .define_method(
      "to_generic_dict",
      *[](torch::IValue& self) {
        // keys and values are both wrapped as IValue objects
        auto entries = self.toGenericDict();
        Rice::Hash result;
        for (auto& entry : entries) {
          result[to_ruby<torch::IValue>(torch::IValue{entry.key()})] = to_ruby<torch::IValue>(torch::IValue{entry.value()});
        }
        return result;
      });

  // singleton constructors: build an IValue from a Ruby-side value
  rb_cIValue
    .define_singleton_method("from_tensor", *[](torch::Tensor& v) { return torch::IValue(v); })
    // TODO create specialized list types?
    .define_singleton_method(
      "from_list",
      *[](Rice::Array obj) {
        c10::impl::GenericList items(c10::AnyType::get());
        for (auto element : obj) {
          items.push_back(from_ruby<torch::IValue>(element));
        }
        return torch::IValue(items);
      })
    .define_singleton_method("from_string", *[](Rice::String v) { return torch::IValue(v.str()); })
    .define_singleton_method("from_int", *[](int64_t v) { return torch::IValue(v); })
    .define_singleton_method("from_double", *[](double v) { return torch::IValue(v); })
    .define_singleton_method("from_bool", *[](bool v) { return torch::IValue(v); })
    // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
    // createGenericDict and toIValue
    .define_singleton_method(
      "from_dict",
      *[](Rice::Hash obj) {
        // key and value types are both AnyType
        c10::impl::GenericDict elems(c10::AnyType::get(), c10::AnyType::get());
        elems.reserve(obj.size());
        for (auto pair : obj) {
          elems.insert(from_ruby<torch::IValue>(pair.first), from_ruby<torch::IValue>((Rice::Object) pair.second));
        }
        return torch::IValue(std::move(elems));
      });
}
|
data/ext/torch/nn.cpp
ADDED
@@ -0,0 +1,114 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Module.hpp>
|
4
|
+
|
5
|
+
#include "nn_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// need to make a distinction between parameters and tensors
|
10
|
+
// Distinct C++ type for parameters so they can be registered as their own
// Ruby class, separate from plain tensors.
class Parameter: public torch::autograd::Variable {
  public:
    // `t` is a named rvalue reference and therefore an lvalue in this scope;
    // std::move is required for the base class to be move-constructed
    // instead of copy-constructed.
    Parameter(Tensor&& t) : torch::autograd::Variable(std::move(t)) { }
};
|
14
|
+
|
15
|
+
// Defines the NN module under `m`, its Init submodule (wrappers around
// torch::nn::init) and the NN::Parameter class.
void init_nn(Rice::Module& m) {
  auto rb_mNN = Rice::define_module_under(m, "NN");
  rb_mNN.add_handler<torch::Error>(handle_error);
  add_nn_functions(rb_mNN);

  // NN::Init: each singleton method forwards to the torch::nn::init function
  // named in its body
  auto rb_mInit = Rice::define_module_under(rb_mNN, "Init");
  rb_mInit.add_handler<torch::Error>(handle_error);
  rb_mInit
    .define_singleton_method("_calculate_gain", *[](NonlinearityType nonlinearity, double param) { return torch::nn::init::calculate_gain(nonlinearity, param); })
    .define_singleton_method("_uniform!", *[](Tensor t, double low, double high) { return torch::nn::init::uniform_(t, low, high); })
    .define_singleton_method("_normal!", *[](Tensor t, double mean, double stddev) { return torch::nn::init::normal_(t, mean, stddev); })
    .define_singleton_method("_constant!", *[](Tensor t, Scalar value) { return torch::nn::init::constant_(t, value); })
    .define_singleton_method("_ones!", *[](Tensor t) { return torch::nn::init::ones_(t); })
    .define_singleton_method("_zeros!", *[](Tensor t) { return torch::nn::init::zeros_(t); })
    .define_singleton_method("_eye!", *[](Tensor t) { return torch::nn::init::eye_(t); })
    .define_singleton_method("_dirac!", *[](Tensor t) { return torch::nn::init::dirac_(t); })
    .define_singleton_method("_xavier_uniform!", *[](Tensor t, double gain) { return torch::nn::init::xavier_uniform_(t, gain); })
    .define_singleton_method("_xavier_normal!", *[](Tensor t, double gain) { return torch::nn::init::xavier_normal_(t, gain); })
    .define_singleton_method("_kaiming_uniform!", *[](Tensor t, double a, FanModeType mode, NonlinearityType nonlinearity) { return torch::nn::init::kaiming_uniform_(t, a, mode, nonlinearity); })
    .define_singleton_method("_kaiming_normal!", *[](Tensor t, double a, FanModeType mode, NonlinearityType nonlinearity) { return torch::nn::init::kaiming_normal_(t, a, mode, nonlinearity); })
    .define_singleton_method("_orthogonal!", *[](Tensor t, double gain) { return torch::nn::init::orthogonal_(t, gain); })
    .define_singleton_method("_sparse!", *[](Tensor t, double sparsity, double stddev) { return torch::nn::init::sparse_(t, sparsity, stddev); });

  // NN::Parameter, registered with torch::Tensor as its parent type
  auto rb_cParameter = Rice::define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter");
  rb_cParameter.add_handler<torch::Error>(handle_error);
  rb_cParameter
    .define_method(
      "grad",
      *[](Parameter& self) {
        // an undefined gradient is returned as nil
        auto current = self.grad();
        return current.defined() ? to_ruby<torch::Tensor>(current) : Nil;
      })
    .define_method(
      "grad=",
      *[](Parameter& self, torch::Tensor& grad) {
        self.mutable_grad() = grad;
      })
    .define_singleton_method(
      "_make_subclass",
      *[](Tensor& rd, bool requires_grad) {
        // work on a detached tensor, allow metadata changes on it, then set
        // the requested requires_grad flag before wrapping it
        auto detached = rd.detach();
        detached.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
        auto wrapped = detached.set_requires_grad(requires_grad);
        return Parameter(std::move(wrapped));
      });
}
|
data/ext/torch/nn_functions.h
CHANGED
@@ -0,0 +1,22 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Module.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
// Defines the Random module under `m` with two singleton methods that act on
// the default CPU generator.
void init_random(Rice::Module& m) {
  auto rb_mRandom = Rice::define_module_under(m, "Random");
  rb_mRandom.add_handler<torch::Error>(handle_error);

  rb_mRandom
    .define_singleton_method(
      "initial_seed",
      *[]() {
        // result of current_seed() on the default CPU generator
        return at::detail::getDefaultCPUGenerator().current_seed();
      })
    .define_singleton_method(
      "seed",
      *[]() {
        // TODO set for CUDA when available
        auto cpu_generator = at::detail::getDefaultCPUGenerator();
        return cpu_generator.seed();
      });
}
|
@@ -487,7 +487,7 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
487
487
|
|
488
488
|
VALUE missing = Qundef;
|
489
489
|
|
490
|
-
bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs,
|
490
|
+
bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], // NOLINT
|
491
491
|
bool raise_exception) {
|
492
492
|
auto nargs = NIL_P(args) ? 0 : RARRAY_LEN(args);
|
493
493
|
ssize_t remaining_kwargs = NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs);
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
#pragma once
|
4
4
|
|
5
|
+
#include <sstream>
|
6
|
+
|
5
7
|
#include <torch/torch.h>
|
6
8
|
#include <rice/Exception.hpp>
|
7
9
|
|
@@ -46,7 +48,7 @@ struct FunctionParameter {
|
|
46
48
|
struct FunctionSignature {
|
47
49
|
explicit FunctionSignature(const std::string& fmt, int index);
|
48
50
|
|
49
|
-
bool parse(VALUE self, VALUE args, VALUE kwargs,
|
51
|
+
bool parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], bool raise_exception);
|
50
52
|
|
51
53
|
std::string toString() const;
|
52
54
|
|
@@ -63,13 +65,13 @@ struct FunctionSignature {
|
|
63
65
|
};
|
64
66
|
|
65
67
|
struct RubyArgs {
|
66
|
-
RubyArgs(const FunctionSignature& signature,
|
68
|
+
RubyArgs(const FunctionSignature& signature, VALUE* args)
|
67
69
|
: signature(signature)
|
68
70
|
, args(args)
|
69
71
|
, idx(signature.index) {}
|
70
72
|
|
71
73
|
const FunctionSignature& signature;
|
72
|
-
|
74
|
+
VALUE* args;
|
73
75
|
int idx;
|
74
76
|
|
75
77
|
inline at::Tensor tensor(int i);
|
@@ -328,6 +330,12 @@ inline bool RubyArgs::isNone(int i) {
|
|
328
330
|
return NIL_P(args[i]);
|
329
331
|
}
|
330
332
|
|
333
|
+
template<int N>
|
334
|
+
struct ParsedArgs {
|
335
|
+
ParsedArgs() : args() { }
|
336
|
+
VALUE args[N];
|
337
|
+
};
|
338
|
+
|
331
339
|
struct RubyArgParser {
|
332
340
|
std::vector<FunctionSignature> signatures_;
|
333
341
|
std::string function_name;
|
@@ -356,7 +364,15 @@ struct RubyArgParser {
|
|
356
364
|
});
|
357
365
|
}
|
358
366
|
|
359
|
-
|
367
|
+
template<int N>
|
368
|
+
inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
|
369
|
+
if (N < max_args) {
|
370
|
+
rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", (int)max_args, N);
|
371
|
+
}
|
372
|
+
return raw_parse(self, argc, argv, dst.args);
|
373
|
+
}
|
374
|
+
|
375
|
+
inline RubyArgs raw_parse(VALUE self, int argc, VALUE* argv, VALUE parsed_args[]) {
|
360
376
|
VALUE args, kwargs;
|
361
377
|
rb_scan_args(argc, argv, "*:", &args, &kwargs);
|
362
378
|
|
@@ -378,7 +394,7 @@ struct RubyArgParser {
|
|
378
394
|
rb_raise(rb_eArgError, "No matching signatures");
|
379
395
|
}
|
380
396
|
|
381
|
-
void print_error(VALUE self, VALUE args, VALUE kwargs,
|
397
|
+
void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
|
382
398
|
ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
|
383
399
|
std::vector<int> plausible_idxs;
|
384
400
|
ssize_t i = 0;
|
data/ext/torch/templates.h
CHANGED
@@ -44,7 +44,7 @@ std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
|
|
44
44
|
{
|
45
45
|
Array a = Array(x);
|
46
46
|
std::vector<int64_t> vec(a.size());
|
47
|
-
for (
|
47
|
+
for (long i = 0; i < a.size(); i++) {
|
48
48
|
vec[i] = from_ruby<int64_t>(a[i]);
|
49
49
|
}
|
50
50
|
return vec;
|
@@ -56,7 +56,7 @@ std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
|
56
56
|
{
|
57
57
|
Array a = Array(x);
|
58
58
|
std::vector<Tensor> vec(a.size());
|
59
|
-
for (
|
59
|
+
for (long i = 0; i < a.size(); i++) {
|
60
60
|
vec[i] = from_ruby<Tensor>(a[i]);
|
61
61
|
}
|
62
62
|
return vec;
|
@@ -0,0 +1,307 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Constructor.hpp>
|
4
|
+
#include <rice/Module.hpp>
|
5
|
+
|
6
|
+
#include "tensor_functions.h"
|
7
|
+
#include "ruby_arg_parser.h"
|
8
|
+
#include "templates.h"
|
9
|
+
#include "utils.h"
|
10
|
+
|
11
|
+
using namespace Rice;
|
12
|
+
using torch::indexing::TensorIndex;
|
13
|
+
|
14
|
+
Class rb_cTensor;
|
15
|
+
|
16
|
+
// Translates a Ruby array of index objects into a list of TensorIndex values.
// Accepted element kinds: Integer, Range, Tensor, nil and true/false.
// Any other element raises ArgumentError naming the element's class.
std::vector<TensorIndex> index_vector(Array a) {
  std::vector<TensorIndex> result;
  result.reserve(a.size());

  Object item;
  for (long pos = 0; pos < a.size(); pos++) {
    item = a[pos];

    if (item.is_instance_of(rb_cInteger)) {
      result.push_back(from_ruby<int64_t>(item));
      continue;
    }

    if (item.is_instance_of(rb_cRange)) {
      // a nil bound leaves that side of the slice open
      torch::optional<int64_t> start = torch::nullopt;
      torch::optional<int64_t> stop = torch::nullopt;

      Object first = item.call("begin");
      if (!first.is_nil()) {
        start = from_ruby<int64_t>(first);
      }

      Object last = item.call("end");
      if (!last.is_nil()) {
        stop = from_ruby<int64_t>(last);
      }

      // Slice gets an exclusive stop: an inclusive range ending at -1 becomes
      // an open end, any other inclusive end is bumped by one
      Object exclude_end = item.call("exclude_end?");
      if (stop.has_value() && !exclude_end) {
        if (stop.value() == -1) {
          stop = torch::nullopt;
        } else {
          stop = stop.value() + 1;
        }
      }

      result.push_back(torch::indexing::Slice(start, stop));
      continue;
    }

    if (item.is_instance_of(rb_cTensor)) {
      result.push_back(from_ruby<Tensor>(item));
      continue;
    }

    if (item.is_nil()) {
      result.push_back(torch::indexing::None);
      continue;
    }

    if (item == True || item == False) {
      result.push_back(from_ruby<bool>(item));
      continue;
    }

    throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(item));
  }

  return result;
}
|
63
|
+
|
64
|
+
// hack (removes inputs argument)
|
65
|
+
// https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
|
66
|
+
// TODO add support for inputs argument
|
67
|
+
// _backward
|
68
|
+
// Ruby-facing "backward" method (registered with arity -1). Parses the
// optional gradient / retain_graph / create_graph arguments and calls
// Tensor::_backward with an empty inputs list.
static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
{
  HANDLE_TH_ERRORS
  Tensor& self = from_ruby<Tensor&>(self_);
  static RubyArgParser parser({
    "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
  });
  ParsedArgs<4> parsed_args;
  auto parsed = parser.parse(self_, argc, argv, parsed_args);

  // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
  auto run_backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
    // in future, release GVL
    self._backward(inputs, gradient, retain_graph, create_graph);
  };

  // argument positions follow the signature string given to the parser
  auto gradient = parsed.optionalTensor(0);
  auto retain_graph = parsed.toBoolOptional(1);
  auto create_graph = parsed.toBool(2);
  run_backward(self, {}, gradient, retain_graph, create_graph);
  RETURN_NIL
  END_HANDLE_TH_ERRORS
}
|
86
|
+
|
87
|
+
void init_tensor(Rice::Module& m) {
|
88
|
+
rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
|
89
|
+
rb_cTensor.add_handler<torch::Error>(handle_error);
|
90
|
+
add_tensor_functions(rb_cTensor);
|
91
|
+
THPVariableClass = rb_cTensor.value();
|
92
|
+
|
93
|
+
rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
|
94
|
+
|
95
|
+
rb_cTensor
|
96
|
+
.define_method("cuda?", &torch::Tensor::is_cuda)
|
97
|
+
.define_method("sparse?", &torch::Tensor::is_sparse)
|
98
|
+
.define_method("quantized?", &torch::Tensor::is_quantized)
|
99
|
+
.define_method("dim", &torch::Tensor::dim)
|
100
|
+
.define_method("numel", &torch::Tensor::numel)
|
101
|
+
.define_method("element_size", &torch::Tensor::element_size)
|
102
|
+
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
103
|
+
.define_method(
|
104
|
+
"_size",
|
105
|
+
*[](Tensor& self, int64_t dim) {
|
106
|
+
return self.size(dim);
|
107
|
+
})
|
108
|
+
.define_method(
|
109
|
+
"_stride",
|
110
|
+
*[](Tensor& self, int64_t dim) {
|
111
|
+
return self.stride(dim);
|
112
|
+
})
|
113
|
+
// in C++ for performance
|
114
|
+
.define_method(
|
115
|
+
"shape",
|
116
|
+
*[](Tensor& self) {
|
117
|
+
Array a;
|
118
|
+
for (auto &size : self.sizes()) {
|
119
|
+
a.push(size);
|
120
|
+
}
|
121
|
+
return a;
|
122
|
+
})
|
123
|
+
.define_method(
|
124
|
+
"_strides",
|
125
|
+
*[](Tensor& self) {
|
126
|
+
Array a;
|
127
|
+
for (auto &stride : self.strides()) {
|
128
|
+
a.push(stride);
|
129
|
+
}
|
130
|
+
return a;
|
131
|
+
})
|
132
|
+
.define_method(
|
133
|
+
"_index",
|
134
|
+
*[](Tensor& self, Array indices) {
|
135
|
+
auto vec = index_vector(indices);
|
136
|
+
return self.index(vec);
|
137
|
+
})
|
138
|
+
.define_method(
|
139
|
+
"_index_put_custom",
|
140
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
141
|
+
auto vec = index_vector(indices);
|
142
|
+
return self.index_put_(vec, value);
|
143
|
+
})
|
144
|
+
.define_method(
|
145
|
+
"contiguous?",
|
146
|
+
*[](Tensor& self) {
|
147
|
+
return self.is_contiguous();
|
148
|
+
})
|
149
|
+
.define_method(
|
150
|
+
"_requires_grad!",
|
151
|
+
*[](Tensor& self, bool requires_grad) {
|
152
|
+
return self.set_requires_grad(requires_grad);
|
153
|
+
})
|
154
|
+
.define_method(
|
155
|
+
"grad",
|
156
|
+
*[](Tensor& self) {
|
157
|
+
auto grad = self.grad();
|
158
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
159
|
+
})
|
160
|
+
.define_method(
|
161
|
+
"grad=",
|
162
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
163
|
+
self.mutable_grad() = grad;
|
164
|
+
})
|
165
|
+
.define_method(
|
166
|
+
"_dtype",
|
167
|
+
*[](Tensor& self) {
|
168
|
+
return (int) at::typeMetaToScalarType(self.dtype());
|
169
|
+
})
|
170
|
+
.define_method(
|
171
|
+
"_type",
|
172
|
+
*[](Tensor& self, int dtype) {
|
173
|
+
return self.toType((torch::ScalarType) dtype);
|
174
|
+
})
|
175
|
+
.define_method(
|
176
|
+
"_layout",
|
177
|
+
*[](Tensor& self) {
|
178
|
+
std::stringstream s;
|
179
|
+
s << self.layout();
|
180
|
+
return s.str();
|
181
|
+
})
|
182
|
+
.define_method(
|
183
|
+
"device",
|
184
|
+
*[](Tensor& self) {
|
185
|
+
std::stringstream s;
|
186
|
+
s << self.device();
|
187
|
+
return s.str();
|
188
|
+
})
|
189
|
+
.define_method(
|
190
|
+
"_data_str",
|
191
|
+
*[](Tensor& self) {
|
192
|
+
Tensor tensor = self;
|
193
|
+
|
194
|
+
// move to CPU to get data
|
195
|
+
if (tensor.device().type() != torch::kCPU) {
|
196
|
+
torch::Device device("cpu");
|
197
|
+
tensor = tensor.to(device);
|
198
|
+
}
|
199
|
+
|
200
|
+
if (!tensor.is_contiguous()) {
|
201
|
+
tensor = tensor.contiguous();
|
202
|
+
}
|
203
|
+
|
204
|
+
auto data_ptr = (const char *) tensor.data_ptr();
|
205
|
+
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
206
|
+
})
|
207
|
+
// for TorchVision
|
208
|
+
.define_method(
|
209
|
+
"_data_ptr",
|
210
|
+
*[](Tensor& self) {
|
211
|
+
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
212
|
+
})
|
213
|
+
// TODO figure out a better way to do this
|
214
|
+
.define_method(
|
215
|
+
"_flat_data",
|
216
|
+
*[](Tensor& self) {
|
217
|
+
Tensor tensor = self;
|
218
|
+
|
219
|
+
// move to CPU to get data
|
220
|
+
if (tensor.device().type() != torch::kCPU) {
|
221
|
+
torch::Device device("cpu");
|
222
|
+
tensor = tensor.to(device);
|
223
|
+
}
|
224
|
+
|
225
|
+
Array a;
|
226
|
+
auto dtype = tensor.dtype();
|
227
|
+
|
228
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
229
|
+
|
230
|
+
// TODO DRY if someone knows C++
|
231
|
+
if (dtype == torch::kByte) {
|
232
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
233
|
+
a.push(view[i].item().to<uint8_t>());
|
234
|
+
}
|
235
|
+
} else if (dtype == torch::kChar) {
|
236
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
237
|
+
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
238
|
+
}
|
239
|
+
} else if (dtype == torch::kShort) {
|
240
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
241
|
+
a.push(view[i].item().to<int16_t>());
|
242
|
+
}
|
243
|
+
} else if (dtype == torch::kInt) {
|
244
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
245
|
+
a.push(view[i].item().to<int32_t>());
|
246
|
+
}
|
247
|
+
} else if (dtype == torch::kLong) {
|
248
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
249
|
+
a.push(view[i].item().to<int64_t>());
|
250
|
+
}
|
251
|
+
} else if (dtype == torch::kFloat) {
|
252
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
253
|
+
a.push(view[i].item().to<float>());
|
254
|
+
}
|
255
|
+
} else if (dtype == torch::kDouble) {
|
256
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
257
|
+
a.push(view[i].item().to<double>());
|
258
|
+
}
|
259
|
+
} else if (dtype == torch::kBool) {
|
260
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
261
|
+
a.push(view[i].item().to<bool>() ? True : False);
|
262
|
+
}
|
263
|
+
} else {
|
264
|
+
throw std::runtime_error("Unsupported type");
|
265
|
+
}
|
266
|
+
return a;
|
267
|
+
})
|
268
|
+
.define_method(
|
269
|
+
"_to",
|
270
|
+
*[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
271
|
+
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
272
|
+
});
|
273
|
+
|
274
|
+
Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
|
275
|
+
.add_handler<torch::Error>(handle_error)
|
276
|
+
.define_constructor(Rice::Constructor<torch::TensorOptions>())
|
277
|
+
.define_method(
|
278
|
+
"dtype",
|
279
|
+
*[](torch::TensorOptions& self, int dtype) {
|
280
|
+
return self.dtype((torch::ScalarType) dtype);
|
281
|
+
})
|
282
|
+
.define_method(
|
283
|
+
"layout",
|
284
|
+
*[](torch::TensorOptions& self, const std::string& layout) {
|
285
|
+
torch::Layout l;
|
286
|
+
if (layout == "strided") {
|
287
|
+
l = torch::kStrided;
|
288
|
+
} else if (layout == "sparse") {
|
289
|
+
l = torch::kSparse;
|
290
|
+
throw std::runtime_error("Sparse layout not supported yet");
|
291
|
+
} else {
|
292
|
+
throw std::runtime_error("Unsupported layout: " + layout);
|
293
|
+
}
|
294
|
+
return self.layout(l);
|
295
|
+
})
|
296
|
+
.define_method(
|
297
|
+
"device",
|
298
|
+
*[](torch::TensorOptions& self, const std::string& device) {
|
299
|
+
torch::Device d(device);
|
300
|
+
return self.device(d);
|
301
|
+
})
|
302
|
+
.define_method(
|
303
|
+
"requires_grad",
|
304
|
+
*[](torch::TensorOptions& self, bool requires_grad) {
|
305
|
+
return self.requires_grad(requires_grad);
|
306
|
+
});
|
307
|
+
}
|