torch-rb 0.5.3 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -1
- data/README.md +5 -3
- data/codegen/generate_functions.rb +7 -5
- data/codegen/native_functions.yaml +2355 -1396
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +21 -0
- data/ext/torch/ext.cpp +17 -622
- data/ext/torch/extconf.rb +0 -1
- data/ext/torch/ivalue.cpp +134 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +1 -1
- data/ext/torch/ruby_arg_parser.h +21 -5
- data/ext/torch/templates.h +2 -2
- data/ext/torch/tensor.cpp +307 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +86 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -1
- data/lib/torch.rb +9 -10
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +6 -0
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +4 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +11 -88
data/ext/torch/extconf.rb
CHANGED
@@ -11,7 +11,6 @@ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
|
11
11
|
|
12
12
|
# check omp first
|
13
13
|
if have_library("omp") || have_library("gomp")
|
14
|
-
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
15
14
|
$CXXFLAGS += " -Xclang" if apple_clang
|
16
15
|
$CXXFLAGS += " -fopenmp"
|
17
16
|
end
|
data/ext/torch/ivalue.cpp
ADDED
@@ -0,0 +1,134 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Array.hpp>
|
4
|
+
#include <rice/Constructor.hpp>
|
5
|
+
#include <rice/Hash.hpp>
|
6
|
+
#include <rice/Module.hpp>
|
7
|
+
#include <rice/String.hpp>
|
8
|
+
|
9
|
+
#include "utils.h"
|
10
|
+
|
11
|
+
// Registers Torch::IValue, a Ruby wrapper around torch::IValue.
// The class exposes three groups of methods:
//   * `xxx?` predicates bound directly to IValue::isXxx(),
//   * `to_xxx` converters that unbox the value (lists and generic dicts are
//     converted element-wise into Ruby Arrays / Hashes of IValues),
//   * `from_xxx` singleton factories that box Ruby values.
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
void init_ivalue(Rice::Module& m) {
  Rice::define_class_under<torch::IValue>(m, "IValue")
    .add_handler<torch::Error>(handle_error)
    .define_constructor(Rice::Constructor<torch::IValue>())
    // --- type predicates ---
    .define_method("bool?", &torch::IValue::isBool)
    .define_method("bool_list?", &torch::IValue::isBoolList)
    .define_method("capsule?", &torch::IValue::isCapsule)
    .define_method("custom_class?", &torch::IValue::isCustomClass)
    .define_method("device?", &torch::IValue::isDevice)
    .define_method("double?", &torch::IValue::isDouble)
    .define_method("double_list?", &torch::IValue::isDoubleList)
    .define_method("future?", &torch::IValue::isFuture)
    // .define_method("generator?", &torch::IValue::isGenerator)
    .define_method("generic_dict?", &torch::IValue::isGenericDict)
    .define_method("list?", &torch::IValue::isList)
    .define_method("int?", &torch::IValue::isInt)
    .define_method("int_list?", &torch::IValue::isIntList)
    .define_method("module?", &torch::IValue::isModule)
    .define_method("none?", &torch::IValue::isNone)
    .define_method("object?", &torch::IValue::isObject)
    .define_method("ptr_type?", &torch::IValue::isPtrType)
    .define_method("py_object?", &torch::IValue::isPyObject)
    .define_method("r_ref?", &torch::IValue::isRRef)
    .define_method("scalar?", &torch::IValue::isScalar)
    .define_method("string?", &torch::IValue::isString)
    .define_method("tensor?", &torch::IValue::isTensor)
    .define_method("tensor_list?", &torch::IValue::isTensorList)
    .define_method("tuple?", &torch::IValue::isTuple)
    // --- converters to Ruby ---
    .define_method(
      "to_bool",
      *[](torch::IValue& self) {
        return self.toBool();
      })
    .define_method(
      "to_double",
      *[](torch::IValue& self) {
        return self.toDouble();
      })
    .define_method(
      "to_int",
      *[](torch::IValue& self) {
        return self.toInt();
      })
    .define_method(
      "to_list",
      *[](torch::IValue& self) {
        // each element is copied into its own IValue and wrapped for Ruby
        auto elements = self.toListRef();
        Rice::Array result;
        for (size_t i = 0; i < elements.size(); i++) {
          result.push(to_ruby<torch::IValue>(torch::IValue{elements[i]}));
        }
        return result;
      })
    .define_method(
      "to_string_ref",
      *[](torch::IValue& self) {
        return self.toStringRef();
      })
    .define_method(
      "to_tensor",
      *[](torch::IValue& self) {
        return self.toTensor();
      })
    .define_method(
      "to_generic_dict",
      *[](torch::IValue& self) {
        // keys and values both stay boxed as Torch::IValue objects
        auto source = self.toGenericDict();
        Rice::Hash result;
        for (auto& item : source) {
          Rice::Object key = to_ruby<torch::IValue>(torch::IValue{item.key()});
          Rice::Object val = to_ruby<torch::IValue>(torch::IValue{item.value()});
          result[key] = val;
        }
        return result;
      })
    // --- factories from Ruby ---
    .define_singleton_method(
      "from_tensor",
      *[](torch::Tensor& v) {
        return torch::IValue(v);
      })
    // TODO create specialized list types?
    .define_singleton_method(
      "from_list",
      *[](Rice::Array items) {
        // element type is AnyType; every entry must already convert to IValue
        c10::impl::GenericList list(c10::AnyType::get());
        for (auto entry : items) {
          list.push_back(from_ruby<torch::IValue>(entry));
        }
        return torch::IValue(list);
      })
    .define_singleton_method(
      "from_string",
      *[](Rice::String v) {
        return torch::IValue(v.str());
      })
    .define_singleton_method(
      "from_int",
      *[](int64_t v) {
        return torch::IValue(v);
      })
    .define_singleton_method(
      "from_double",
      *[](double v) {
        return torch::IValue(v);
      })
    .define_singleton_method(
      "from_bool",
      *[](bool v) {
        return torch::IValue(v);
      })
    // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
    // createGenericDict and toIValue
    .define_singleton_method(
      "from_dict",
      *[](Rice::Hash source) {
        // both key and value types are AnyType, mirroring createGenericDict
        c10::impl::GenericDict dict(c10::AnyType::get(), c10::AnyType::get());
        dict.reserve(source.size());
        for (auto entry : source) {
          dict.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Rice::Object) entry.second));
        }
        return torch::IValue(std::move(dict));
      });
}
|
data/ext/torch/nn.cpp
ADDED
@@ -0,0 +1,114 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Module.hpp>
|
4
|
+
|
5
|
+
#include "nn_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// need to make a distinction between parameters and tensors
|
10
|
+
// Distinct C++ type backing Torch::NN::Parameter. It adds no state of its
// own; it exists so Rice can register a separate Ruby class (a subclass of
// Torch::Tensor) for parameters.
class Parameter: public torch::autograd::Variable {
 public:
  // Takes over `t`. Moving (instead of copying the named rvalue reference)
  // avoids an extra TensorImpl refcount increment/decrement pair.
  explicit Parameter(Tensor&& t) : torch::autograd::Variable(std::move(t)) { }
};
|
14
|
+
|
15
|
+
// Sets up the Torch::NN namespace:
//   * Torch::NN             - generated nn functions (add_nn_functions),
//   * Torch::NN::Init       - wrappers over torch::nn::init; the `xxx!` methods
//                             call the matching in-place torch::nn::init::xxx_
//                             and return its result,
//   * Torch::NN::Parameter  - Ruby subclass of Torch::Tensor.
void init_nn(Rice::Module& m) {
  auto rb_mNN = Rice::define_module_under(m, "NN");
  rb_mNN.add_handler<torch::Error>(handle_error);
  add_nn_functions(rb_mNN);

  auto rb_mInit = Rice::define_module_under(rb_mNN, "Init");
  rb_mInit.add_handler<torch::Error>(handle_error);
  rb_mInit
    .define_singleton_method(
      "_calculate_gain",
      *[](NonlinearityType kind, double param) {
        return torch::nn::init::calculate_gain(kind, param);
      })
    .define_singleton_method(
      "_uniform!",
      *[](Tensor t, double low, double high) {
        return torch::nn::init::uniform_(t, low, high);
      })
    .define_singleton_method(
      "_normal!",
      *[](Tensor t, double mean, double std) {
        return torch::nn::init::normal_(t, mean, std);
      })
    .define_singleton_method(
      "_constant!",
      *[](Tensor t, Scalar fill) {
        return torch::nn::init::constant_(t, fill);
      })
    .define_singleton_method(
      "_ones!",
      *[](Tensor t) {
        return torch::nn::init::ones_(t);
      })
    .define_singleton_method(
      "_zeros!",
      *[](Tensor t) {
        return torch::nn::init::zeros_(t);
      })
    .define_singleton_method(
      "_eye!",
      *[](Tensor t) {
        return torch::nn::init::eye_(t);
      })
    .define_singleton_method(
      "_dirac!",
      *[](Tensor t) {
        return torch::nn::init::dirac_(t);
      })
    .define_singleton_method(
      "_xavier_uniform!",
      *[](Tensor t, double gain) {
        return torch::nn::init::xavier_uniform_(t, gain);
      })
    .define_singleton_method(
      "_xavier_normal!",
      *[](Tensor t, double gain) {
        return torch::nn::init::xavier_normal_(t, gain);
      })
    .define_singleton_method(
      "_kaiming_uniform!",
      *[](Tensor t, double a, FanModeType fan_mode, NonlinearityType kind) {
        return torch::nn::init::kaiming_uniform_(t, a, fan_mode, kind);
      })
    .define_singleton_method(
      "_kaiming_normal!",
      *[](Tensor t, double a, FanModeType fan_mode, NonlinearityType kind) {
        return torch::nn::init::kaiming_normal_(t, a, fan_mode, kind);
      })
    .define_singleton_method(
      "_orthogonal!",
      *[](Tensor t, double gain) {
        return torch::nn::init::orthogonal_(t, gain);
      })
    .define_singleton_method(
      "_sparse!",
      *[](Tensor t, double sparsity, double std) {
        return torch::nn::init::sparse_(t, sparsity, std);
      });

  Rice::define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
    .add_handler<torch::Error>(handle_error)
    .define_method(
      "grad",
      *[](Parameter& self) {
        // an undefined gradient is reported to Ruby as nil
        auto current = self.grad();
        return current.defined() ? to_ruby<torch::Tensor>(current) : Nil;
      })
    .define_method(
      "grad=",
      *[](Parameter& self, torch::Tensor& new_grad) {
        self.mutable_grad() = new_grad;
      })
    .define_singleton_method(
      "_make_subclass",
      *[](Tensor& source, bool requires_grad) {
        // detach from the source graph, allow later metadata changes, then
        // wrap the result as a Parameter
        auto detached = source.detach();
        detached.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
        auto var = detached.set_requires_grad(requires_grad);
        return Parameter(std::move(var));
      });
}
|
data/ext/torch/random.cpp
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Module.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
// Defines Torch::Random with seed helpers that operate on libtorch's default
// CPU generator.
//   initial_seed -> current seed of the default CPU generator
//   seed         -> reseeds the default CPU generator non-deterministically and
//                   returns the new seed
void init_random(Rice::Module& m) {
  Rice::define_module_under(m, "Random")
    .add_handler<torch::Error>(handle_error)
    .define_singleton_method(
      "initial_seed",
      *[]() {
        return at::detail::getDefaultCPUGenerator().current_seed();
      })
    .define_singleton_method(
      "seed",
      *[]() {
        // TODO set for CUDA when available
        auto generator = at::detail::getDefaultCPUGenerator();
        // The default generator is shared process-wide. Hold its mutex while
        // mutating it, as PyTorch's own torch.seed() binding does
        // (see ATen "Note [Acquire lock when using random generators]").
        std::lock_guard<std::mutex> lock(generator.mutex());
        return generator.seed();
      });
}
|
data/ext/torch/ruby_arg_parser.cpp
CHANGED
@@ -487,7 +487,7 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
487
487
|
|
488
488
|
VALUE missing = Qundef;
|
489
489
|
|
490
|
-
bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs,
|
490
|
+
bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], // NOLINT
|
491
491
|
bool raise_exception) {
|
492
492
|
auto nargs = NIL_P(args) ? 0 : RARRAY_LEN(args);
|
493
493
|
ssize_t remaining_kwargs = NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs);
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
#pragma once
|
4
4
|
|
5
|
+
#include <sstream>
|
6
|
+
|
5
7
|
#include <torch/torch.h>
|
6
8
|
#include <rice/Exception.hpp>
|
7
9
|
|
@@ -46,7 +48,7 @@ struct FunctionParameter {
|
|
46
48
|
struct FunctionSignature {
|
47
49
|
explicit FunctionSignature(const std::string& fmt, int index);
|
48
50
|
|
49
|
-
bool parse(VALUE self, VALUE args, VALUE kwargs,
|
51
|
+
bool parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], bool raise_exception);
|
50
52
|
|
51
53
|
std::string toString() const;
|
52
54
|
|
@@ -63,13 +65,13 @@ struct FunctionSignature {
|
|
63
65
|
};
|
64
66
|
|
65
67
|
struct RubyArgs {
|
66
|
-
RubyArgs(const FunctionSignature& signature,
|
68
|
+
RubyArgs(const FunctionSignature& signature, VALUE* args)
|
67
69
|
: signature(signature)
|
68
70
|
, args(args)
|
69
71
|
, idx(signature.index) {}
|
70
72
|
|
71
73
|
const FunctionSignature& signature;
|
72
|
-
|
74
|
+
VALUE* args;
|
73
75
|
int idx;
|
74
76
|
|
75
77
|
inline at::Tensor tensor(int i);
|
@@ -328,6 +330,12 @@ inline bool RubyArgs::isNone(int i) {
|
|
328
330
|
return NIL_P(args[i]);
|
329
331
|
}
|
330
332
|
|
333
|
+
template<int N>
|
334
|
+
struct ParsedArgs {
|
335
|
+
ParsedArgs() : args() { }
|
336
|
+
VALUE args[N];
|
337
|
+
};
|
338
|
+
|
331
339
|
struct RubyArgParser {
|
332
340
|
std::vector<FunctionSignature> signatures_;
|
333
341
|
std::string function_name;
|
@@ -356,7 +364,15 @@ struct RubyArgParser {
|
|
356
364
|
});
|
357
365
|
}
|
358
366
|
|
359
|
-
|
367
|
+
template<int N>
|
368
|
+
inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
|
369
|
+
if (N < max_args) {
|
370
|
+
rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", (int)max_args, N);
|
371
|
+
}
|
372
|
+
return raw_parse(self, argc, argv, dst.args);
|
373
|
+
}
|
374
|
+
|
375
|
+
inline RubyArgs raw_parse(VALUE self, int argc, VALUE* argv, VALUE parsed_args[]) {
|
360
376
|
VALUE args, kwargs;
|
361
377
|
rb_scan_args(argc, argv, "*:", &args, &kwargs);
|
362
378
|
|
@@ -378,7 +394,7 @@ struct RubyArgParser {
|
|
378
394
|
rb_raise(rb_eArgError, "No matching signatures");
|
379
395
|
}
|
380
396
|
|
381
|
-
void print_error(VALUE self, VALUE args, VALUE kwargs,
|
397
|
+
void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
|
382
398
|
ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
|
383
399
|
std::vector<int> plausible_idxs;
|
384
400
|
ssize_t i = 0;
|
data/ext/torch/templates.h
CHANGED
@@ -44,7 +44,7 @@ std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
|
|
44
44
|
{
|
45
45
|
Array a = Array(x);
|
46
46
|
std::vector<int64_t> vec(a.size());
|
47
|
-
for (
|
47
|
+
for (long i = 0; i < a.size(); i++) {
|
48
48
|
vec[i] = from_ruby<int64_t>(a[i]);
|
49
49
|
}
|
50
50
|
return vec;
|
@@ -56,7 +56,7 @@ std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
|
56
56
|
{
|
57
57
|
Array a = Array(x);
|
58
58
|
std::vector<Tensor> vec(a.size());
|
59
|
-
for (
|
59
|
+
for (long i = 0; i < a.size(); i++) {
|
60
60
|
vec[i] = from_ruby<Tensor>(a[i]);
|
61
61
|
}
|
62
62
|
return vec;
|
data/ext/torch/tensor.cpp
ADDED
@@ -0,0 +1,307 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Constructor.hpp>
|
4
|
+
#include <rice/Module.hpp>
|
5
|
+
|
6
|
+
#include "tensor_functions.h"
|
7
|
+
#include "ruby_arg_parser.h"
|
8
|
+
#include "templates.h"
|
9
|
+
#include "utils.h"
|
10
|
+
|
11
|
+
using namespace Rice;
|
12
|
+
using torch::indexing::TensorIndex;
|
13
|
+
|
14
|
+
Class rb_cTensor;
|
15
|
+
|
16
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
17
|
+
Object obj;
|
18
|
+
|
19
|
+
std::vector<TensorIndex> indices;
|
20
|
+
indices.reserve(a.size());
|
21
|
+
|
22
|
+
for (long i = 0; i < a.size(); i++) {
|
23
|
+
obj = a[i];
|
24
|
+
|
25
|
+
if (obj.is_instance_of(rb_cInteger)) {
|
26
|
+
indices.push_back(from_ruby<int64_t>(obj));
|
27
|
+
} else if (obj.is_instance_of(rb_cRange)) {
|
28
|
+
torch::optional<int64_t> start_index = torch::nullopt;
|
29
|
+
torch::optional<int64_t> stop_index = torch::nullopt;
|
30
|
+
|
31
|
+
Object begin = obj.call("begin");
|
32
|
+
if (!begin.is_nil()) {
|
33
|
+
start_index = from_ruby<int64_t>(begin);
|
34
|
+
}
|
35
|
+
|
36
|
+
Object end = obj.call("end");
|
37
|
+
if (!end.is_nil()) {
|
38
|
+
stop_index = from_ruby<int64_t>(end);
|
39
|
+
}
|
40
|
+
|
41
|
+
Object exclude_end = obj.call("exclude_end?");
|
42
|
+
if (stop_index.has_value() && !exclude_end) {
|
43
|
+
if (stop_index.value() == -1) {
|
44
|
+
stop_index = torch::nullopt;
|
45
|
+
} else {
|
46
|
+
stop_index = stop_index.value() + 1;
|
47
|
+
}
|
48
|
+
}
|
49
|
+
|
50
|
+
indices.push_back(torch::indexing::Slice(start_index, stop_index));
|
51
|
+
} else if (obj.is_instance_of(rb_cTensor)) {
|
52
|
+
indices.push_back(from_ruby<Tensor>(obj));
|
53
|
+
} else if (obj.is_nil()) {
|
54
|
+
indices.push_back(torch::indexing::None);
|
55
|
+
} else if (obj == True || obj == False) {
|
56
|
+
indices.push_back(from_ruby<bool>(obj));
|
57
|
+
} else {
|
58
|
+
throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
|
59
|
+
}
|
60
|
+
}
|
61
|
+
return indices;
|
62
|
+
}
|
63
|
+
|
64
|
+
// hack (removes inputs argument)
|
65
|
+
// https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
|
66
|
+
// TODO add support for inputs argument
|
67
|
+
// _backward
|
68
|
+
// Ruby entry point for Tensor#backward(gradient = nil, retain_graph = nil,
// create_graph = false). The native `inputs` argument is not exposed; an
// empty tensor list is always passed for it.
static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
{
  HANDLE_TH_ERRORS
  Tensor& self = from_ruby<Tensor&>(self_);
  static RubyArgParser parser({
    "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
  });
  ParsedArgs<4> parsed_args;
  auto r = parser.parse(self_, argc, argv, parsed_args);
  // native: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
  auto run_backward = [](const Tensor & tensor, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
    // in future, release GVL
    tensor._backward(inputs, gradient, retain_graph, create_graph);
  };
  run_backward(self, {}, r.optionalTensor(0), r.toBoolOptional(1), r.toBool(2));
  RETURN_NIL
  END_HANDLE_TH_ERRORS
}
|
86
|
+
|
87
|
+
void init_tensor(Rice::Module& m) {
|
88
|
+
rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
|
89
|
+
rb_cTensor.add_handler<torch::Error>(handle_error);
|
90
|
+
add_tensor_functions(rb_cTensor);
|
91
|
+
THPVariableClass = rb_cTensor.value();
|
92
|
+
|
93
|
+
rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
|
94
|
+
|
95
|
+
rb_cTensor
|
96
|
+
.define_method("cuda?", &torch::Tensor::is_cuda)
|
97
|
+
.define_method("sparse?", &torch::Tensor::is_sparse)
|
98
|
+
.define_method("quantized?", &torch::Tensor::is_quantized)
|
99
|
+
.define_method("dim", &torch::Tensor::dim)
|
100
|
+
.define_method("numel", &torch::Tensor::numel)
|
101
|
+
.define_method("element_size", &torch::Tensor::element_size)
|
102
|
+
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
103
|
+
.define_method(
|
104
|
+
"_size",
|
105
|
+
*[](Tensor& self, int64_t dim) {
|
106
|
+
return self.size(dim);
|
107
|
+
})
|
108
|
+
.define_method(
|
109
|
+
"_stride",
|
110
|
+
*[](Tensor& self, int64_t dim) {
|
111
|
+
return self.stride(dim);
|
112
|
+
})
|
113
|
+
// in C++ for performance
|
114
|
+
.define_method(
|
115
|
+
"shape",
|
116
|
+
*[](Tensor& self) {
|
117
|
+
Array a;
|
118
|
+
for (auto &size : self.sizes()) {
|
119
|
+
a.push(size);
|
120
|
+
}
|
121
|
+
return a;
|
122
|
+
})
|
123
|
+
.define_method(
|
124
|
+
"_strides",
|
125
|
+
*[](Tensor& self) {
|
126
|
+
Array a;
|
127
|
+
for (auto &stride : self.strides()) {
|
128
|
+
a.push(stride);
|
129
|
+
}
|
130
|
+
return a;
|
131
|
+
})
|
132
|
+
.define_method(
|
133
|
+
"_index",
|
134
|
+
*[](Tensor& self, Array indices) {
|
135
|
+
auto vec = index_vector(indices);
|
136
|
+
return self.index(vec);
|
137
|
+
})
|
138
|
+
.define_method(
|
139
|
+
"_index_put_custom",
|
140
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
141
|
+
auto vec = index_vector(indices);
|
142
|
+
return self.index_put_(vec, value);
|
143
|
+
})
|
144
|
+
.define_method(
|
145
|
+
"contiguous?",
|
146
|
+
*[](Tensor& self) {
|
147
|
+
return self.is_contiguous();
|
148
|
+
})
|
149
|
+
.define_method(
|
150
|
+
"_requires_grad!",
|
151
|
+
*[](Tensor& self, bool requires_grad) {
|
152
|
+
return self.set_requires_grad(requires_grad);
|
153
|
+
})
|
154
|
+
.define_method(
|
155
|
+
"grad",
|
156
|
+
*[](Tensor& self) {
|
157
|
+
auto grad = self.grad();
|
158
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
159
|
+
})
|
160
|
+
.define_method(
|
161
|
+
"grad=",
|
162
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
163
|
+
self.mutable_grad() = grad;
|
164
|
+
})
|
165
|
+
.define_method(
|
166
|
+
"_dtype",
|
167
|
+
*[](Tensor& self) {
|
168
|
+
return (int) at::typeMetaToScalarType(self.dtype());
|
169
|
+
})
|
170
|
+
.define_method(
|
171
|
+
"_type",
|
172
|
+
*[](Tensor& self, int dtype) {
|
173
|
+
return self.toType((torch::ScalarType) dtype);
|
174
|
+
})
|
175
|
+
.define_method(
|
176
|
+
"_layout",
|
177
|
+
*[](Tensor& self) {
|
178
|
+
std::stringstream s;
|
179
|
+
s << self.layout();
|
180
|
+
return s.str();
|
181
|
+
})
|
182
|
+
.define_method(
|
183
|
+
"device",
|
184
|
+
*[](Tensor& self) {
|
185
|
+
std::stringstream s;
|
186
|
+
s << self.device();
|
187
|
+
return s.str();
|
188
|
+
})
|
189
|
+
.define_method(
|
190
|
+
"_data_str",
|
191
|
+
*[](Tensor& self) {
|
192
|
+
Tensor tensor = self;
|
193
|
+
|
194
|
+
// move to CPU to get data
|
195
|
+
if (tensor.device().type() != torch::kCPU) {
|
196
|
+
torch::Device device("cpu");
|
197
|
+
tensor = tensor.to(device);
|
198
|
+
}
|
199
|
+
|
200
|
+
if (!tensor.is_contiguous()) {
|
201
|
+
tensor = tensor.contiguous();
|
202
|
+
}
|
203
|
+
|
204
|
+
auto data_ptr = (const char *) tensor.data_ptr();
|
205
|
+
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
206
|
+
})
|
207
|
+
// for TorchVision
|
208
|
+
.define_method(
|
209
|
+
"_data_ptr",
|
210
|
+
*[](Tensor& self) {
|
211
|
+
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
212
|
+
})
|
213
|
+
// TODO figure out a better way to do this
|
214
|
+
.define_method(
|
215
|
+
"_flat_data",
|
216
|
+
*[](Tensor& self) {
|
217
|
+
Tensor tensor = self;
|
218
|
+
|
219
|
+
// move to CPU to get data
|
220
|
+
if (tensor.device().type() != torch::kCPU) {
|
221
|
+
torch::Device device("cpu");
|
222
|
+
tensor = tensor.to(device);
|
223
|
+
}
|
224
|
+
|
225
|
+
Array a;
|
226
|
+
auto dtype = tensor.dtype();
|
227
|
+
|
228
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
229
|
+
|
230
|
+
// TODO DRY if someone knows C++
|
231
|
+
if (dtype == torch::kByte) {
|
232
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
233
|
+
a.push(view[i].item().to<uint8_t>());
|
234
|
+
}
|
235
|
+
} else if (dtype == torch::kChar) {
|
236
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
237
|
+
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
238
|
+
}
|
239
|
+
} else if (dtype == torch::kShort) {
|
240
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
241
|
+
a.push(view[i].item().to<int16_t>());
|
242
|
+
}
|
243
|
+
} else if (dtype == torch::kInt) {
|
244
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
245
|
+
a.push(view[i].item().to<int32_t>());
|
246
|
+
}
|
247
|
+
} else if (dtype == torch::kLong) {
|
248
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
249
|
+
a.push(view[i].item().to<int64_t>());
|
250
|
+
}
|
251
|
+
} else if (dtype == torch::kFloat) {
|
252
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
253
|
+
a.push(view[i].item().to<float>());
|
254
|
+
}
|
255
|
+
} else if (dtype == torch::kDouble) {
|
256
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
257
|
+
a.push(view[i].item().to<double>());
|
258
|
+
}
|
259
|
+
} else if (dtype == torch::kBool) {
|
260
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
261
|
+
a.push(view[i].item().to<bool>() ? True : False);
|
262
|
+
}
|
263
|
+
} else {
|
264
|
+
throw std::runtime_error("Unsupported type");
|
265
|
+
}
|
266
|
+
return a;
|
267
|
+
})
|
268
|
+
.define_method(
|
269
|
+
"_to",
|
270
|
+
*[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
271
|
+
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
272
|
+
});
|
273
|
+
|
274
|
+
Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
|
275
|
+
.add_handler<torch::Error>(handle_error)
|
276
|
+
.define_constructor(Rice::Constructor<torch::TensorOptions>())
|
277
|
+
.define_method(
|
278
|
+
"dtype",
|
279
|
+
*[](torch::TensorOptions& self, int dtype) {
|
280
|
+
return self.dtype((torch::ScalarType) dtype);
|
281
|
+
})
|
282
|
+
.define_method(
|
283
|
+
"layout",
|
284
|
+
*[](torch::TensorOptions& self, const std::string& layout) {
|
285
|
+
torch::Layout l;
|
286
|
+
if (layout == "strided") {
|
287
|
+
l = torch::kStrided;
|
288
|
+
} else if (layout == "sparse") {
|
289
|
+
l = torch::kSparse;
|
290
|
+
throw std::runtime_error("Sparse layout not supported yet");
|
291
|
+
} else {
|
292
|
+
throw std::runtime_error("Unsupported layout: " + layout);
|
293
|
+
}
|
294
|
+
return self.layout(l);
|
295
|
+
})
|
296
|
+
.define_method(
|
297
|
+
"device",
|
298
|
+
*[](torch::TensorOptions& self, const std::string& device) {
|
299
|
+
torch::Device d(device);
|
300
|
+
return self.device(d);
|
301
|
+
})
|
302
|
+
.define_method(
|
303
|
+
"requires_grad",
|
304
|
+
*[](torch::TensorOptions& self, bool requires_grad) {
|
305
|
+
return self.requires_grad(requires_grad);
|
306
|
+
});
|
307
|
+
}
|