torch-rb 0.1.3 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -0
- data/README.md +5 -2
- data/ext/torch/ext.cpp +130 -555
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +244 -0
- data/lib/torch.rb +209 -171
- data/lib/torch/inspector.rb +23 -19
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +110 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6491 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +10 -20
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +3 -3
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropoutnd.rb +2 -2
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +379 -32
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +186 -35
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +51 -44
- data/lib/torch/version.rb +1 -1
- metadata +98 -6
- data/lib/torch/ext.bundle +0 -0
data/ext/torch/extconf.rb
CHANGED
@@ -10,6 +10,9 @@ $CXXFLAGS << " -std=c++11"
|
|
10
10
|
# silence ruby/intern.h warning
|
11
11
|
$CXXFLAGS << " -Wno-deprecated-register"
|
12
12
|
|
13
|
+
# silence torch warnings
|
14
|
+
$CXXFLAGS << " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
|
15
|
+
|
13
16
|
inc, lib = dir_config("torch")
|
14
17
|
|
15
18
|
inc ||= "/usr/local/include"
|
@@ -22,4 +25,10 @@ $LDFLAGS << " -Wl,-rpath,#{lib}"
|
|
22
25
|
$LDFLAGS << " -L#{lib}"
|
23
26
|
$LDFLAGS << " -ltorch -lc10"
|
24
27
|
|
28
|
+
# generate C++ functions
|
29
|
+
puts "Generating C++ functions..."
|
30
|
+
require_relative "../../lib/torch/native/generator"
|
31
|
+
Torch::Native::Generator.generate_cpp_functions
|
32
|
+
|
33
|
+
# create makefile
|
25
34
|
create_makefile("torch/ext")
|
@@ -0,0 +1,55 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
#include <rice/Object.hpp>
|
3
|
+
#include "templates.hpp"
|
4
|
+
|
5
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
6
|
+
Array a;
|
7
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
8
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
9
|
+
return Object(a);
|
10
|
+
}
|
11
|
+
|
12
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
13
|
+
Array a;
|
14
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
15
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
16
|
+
a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
|
17
|
+
return Object(a);
|
18
|
+
}
|
19
|
+
|
20
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
21
|
+
Array a;
|
22
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
23
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
24
|
+
a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
|
25
|
+
a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
|
26
|
+
return Object(a);
|
27
|
+
}
|
28
|
+
|
29
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
30
|
+
Array a;
|
31
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
32
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
33
|
+
a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
|
34
|
+
a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
|
35
|
+
a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
|
36
|
+
return Object(a);
|
37
|
+
}
|
38
|
+
|
39
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
40
|
+
Array a;
|
41
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
42
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
43
|
+
a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
|
44
|
+
a.push(to_ruby<int64_t>(std::get<3>(x)));
|
45
|
+
return Object(a);
|
46
|
+
}
|
47
|
+
|
48
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
49
|
+
Array a;
|
50
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
51
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
52
|
+
a.push(to_ruby<double>(std::get<2>(x)));
|
53
|
+
a.push(to_ruby<int64_t>(std::get<3>(x)));
|
54
|
+
return Object(a);
|
55
|
+
}
|
@@ -0,0 +1,244 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#ifdef isfinite
|
4
|
+
#undef isfinite
|
5
|
+
#endif
|
6
|
+
|
7
|
+
#include <rice/Array.hpp>
|
8
|
+
#include <rice/Object.hpp>
|
9
|
+
|
10
|
+
using namespace Rice;
|
11
|
+
|
12
|
+
// need to wrap torch::IntArrayRef() since
|
13
|
+
// it doesn't own underlying data
|
14
|
+
class IntArrayRef {
|
15
|
+
std::vector<int64_t> vec;
|
16
|
+
public:
|
17
|
+
IntArrayRef(Object o) {
|
18
|
+
Array a = Array(o);
|
19
|
+
for (size_t i = 0; i < a.size(); i++) {
|
20
|
+
vec.push_back(from_ruby<int64_t>(a[i]));
|
21
|
+
}
|
22
|
+
}
|
23
|
+
operator torch::IntArrayRef() {
|
24
|
+
return torch::IntArrayRef(vec);
|
25
|
+
}
|
26
|
+
};
|
27
|
+
|
28
|
+
template<>
|
29
|
+
inline
|
30
|
+
IntArrayRef from_ruby<IntArrayRef>(Object x)
|
31
|
+
{
|
32
|
+
return IntArrayRef(x);
|
33
|
+
}
|
34
|
+
|
35
|
+
// for now
|
36
|
+
class Scalar {
|
37
|
+
torch::Scalar value;
|
38
|
+
public:
|
39
|
+
Scalar(Object o) {
|
40
|
+
// TODO cast based on Ruby type
|
41
|
+
if (o.rb_type() == T_FIXNUM) {
|
42
|
+
value = torch::Scalar(from_ruby<int64_t>(o));
|
43
|
+
} else {
|
44
|
+
value = torch::Scalar(from_ruby<float>(o));
|
45
|
+
}
|
46
|
+
}
|
47
|
+
operator torch::Scalar() {
|
48
|
+
return value;
|
49
|
+
}
|
50
|
+
};
|
51
|
+
|
52
|
+
template<>
|
53
|
+
inline
|
54
|
+
Scalar from_ruby<Scalar>(Object x)
|
55
|
+
{
|
56
|
+
return Scalar(x);
|
57
|
+
}
|
58
|
+
|
59
|
+
class TensorList {
|
60
|
+
std::vector<torch::Tensor> vec;
|
61
|
+
public:
|
62
|
+
TensorList(Object o) {
|
63
|
+
Array a = Array(o);
|
64
|
+
for (size_t i = 0; i < a.size(); i++) {
|
65
|
+
vec.push_back(from_ruby<torch::Tensor>(a[i]));
|
66
|
+
}
|
67
|
+
}
|
68
|
+
operator torch::TensorList() {
|
69
|
+
return torch::TensorList(vec);
|
70
|
+
}
|
71
|
+
};
|
72
|
+
|
73
|
+
template<>
|
74
|
+
inline
|
75
|
+
TensorList from_ruby<TensorList>(Object x)
|
76
|
+
{
|
77
|
+
return TensorList(x);
|
78
|
+
}
|
79
|
+
|
80
|
+
class FanModeType {
|
81
|
+
std::string s;
|
82
|
+
public:
|
83
|
+
FanModeType(Object o) {
|
84
|
+
s = String(o).str();
|
85
|
+
}
|
86
|
+
operator torch::nn::init::FanModeType() {
|
87
|
+
if (s == "fan_in") {
|
88
|
+
return torch::kFanIn;
|
89
|
+
} else if (s == "fan_out") {
|
90
|
+
return torch::kFanOut;
|
91
|
+
} else {
|
92
|
+
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
93
|
+
}
|
94
|
+
}
|
95
|
+
};
|
96
|
+
|
97
|
+
template<>
|
98
|
+
inline
|
99
|
+
FanModeType from_ruby<FanModeType>(Object x)
|
100
|
+
{
|
101
|
+
return FanModeType(x);
|
102
|
+
}
|
103
|
+
|
104
|
+
class NonlinearityType {
|
105
|
+
std::string s;
|
106
|
+
public:
|
107
|
+
NonlinearityType(Object o) {
|
108
|
+
s = String(o).str();
|
109
|
+
}
|
110
|
+
operator torch::nn::init::NonlinearityType() {
|
111
|
+
if (s == "linear") {
|
112
|
+
return torch::kLinear;
|
113
|
+
} else if (s == "conv1d") {
|
114
|
+
return torch::kConv1D;
|
115
|
+
} else if (s == "conv2d") {
|
116
|
+
return torch::kConv2D;
|
117
|
+
} else if (s == "conv3d") {
|
118
|
+
return torch::kConv3D;
|
119
|
+
} else if (s == "conv_transpose1d") {
|
120
|
+
return torch::kConvTranspose1D;
|
121
|
+
} else if (s == "conv_transpose2d") {
|
122
|
+
return torch::kConvTranspose2D;
|
123
|
+
} else if (s == "conv_transpose3d") {
|
124
|
+
return torch::kConvTranspose3D;
|
125
|
+
} else if (s == "sigmoid") {
|
126
|
+
return torch::kSigmoid;
|
127
|
+
} else if (s == "tanh") {
|
128
|
+
return torch::kTanh;
|
129
|
+
} else if (s == "relu") {
|
130
|
+
return torch::kReLU;
|
131
|
+
} else if (s == "leaky_relu") {
|
132
|
+
return torch::kLeakyReLU;
|
133
|
+
} else {
|
134
|
+
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
135
|
+
}
|
136
|
+
}
|
137
|
+
};
|
138
|
+
|
139
|
+
template<>
|
140
|
+
inline
|
141
|
+
NonlinearityType from_ruby<NonlinearityType>(Object x)
|
142
|
+
{
|
143
|
+
return NonlinearityType(x);
|
144
|
+
}
|
145
|
+
|
146
|
+
class MyReduction {
|
147
|
+
Object value;
|
148
|
+
public:
|
149
|
+
MyReduction(Object o) {
|
150
|
+
value = o;
|
151
|
+
}
|
152
|
+
operator int64_t() {
|
153
|
+
if (value.is_nil()) {
|
154
|
+
return torch::Reduction::None;
|
155
|
+
}
|
156
|
+
|
157
|
+
std::string s = String(value).str();
|
158
|
+
if (s == "mean") {
|
159
|
+
return torch::Reduction::Mean;
|
160
|
+
} else if (s == "sum") {
|
161
|
+
return torch::Reduction::Sum;
|
162
|
+
} else {
|
163
|
+
throw std::runtime_error("Unsupported reduction: " + s);
|
164
|
+
}
|
165
|
+
}
|
166
|
+
};
|
167
|
+
|
168
|
+
template<>
|
169
|
+
inline
|
170
|
+
MyReduction from_ruby<MyReduction>(Object x)
|
171
|
+
{
|
172
|
+
return MyReduction(x);
|
173
|
+
}
|
174
|
+
|
175
|
+
typedef torch::Tensor Tensor;
|
176
|
+
|
177
|
+
class OptionalTensor {
|
178
|
+
Object value;
|
179
|
+
public:
|
180
|
+
OptionalTensor(Object o) {
|
181
|
+
value = o;
|
182
|
+
}
|
183
|
+
operator torch::Tensor() {
|
184
|
+
if (value.is_nil()) {
|
185
|
+
return {};
|
186
|
+
}
|
187
|
+
return from_ruby<torch::Tensor>(value);
|
188
|
+
}
|
189
|
+
};
|
190
|
+
|
191
|
+
template<>
|
192
|
+
inline
|
193
|
+
OptionalTensor from_ruby<OptionalTensor>(Object x)
|
194
|
+
{
|
195
|
+
return OptionalTensor(x);
|
196
|
+
}
|
197
|
+
|
198
|
+
class ScalarType {
|
199
|
+
Object value;
|
200
|
+
public:
|
201
|
+
ScalarType(Object o) {
|
202
|
+
value = o;
|
203
|
+
}
|
204
|
+
operator at::ScalarType() {
|
205
|
+
throw std::runtime_error("ScalarType arguments not implemented yet");
|
206
|
+
}
|
207
|
+
};
|
208
|
+
|
209
|
+
template<>
|
210
|
+
inline
|
211
|
+
ScalarType from_ruby<ScalarType>(Object x)
|
212
|
+
{
|
213
|
+
return ScalarType(x);
|
214
|
+
}
|
215
|
+
|
216
|
+
class OptionalScalarType {
|
217
|
+
Object value;
|
218
|
+
public:
|
219
|
+
OptionalScalarType(Object o) {
|
220
|
+
value = o;
|
221
|
+
}
|
222
|
+
operator c10::optional<at::ScalarType>() {
|
223
|
+
if (value.is_nil()) {
|
224
|
+
return c10::nullopt;
|
225
|
+
}
|
226
|
+
return ScalarType(value);
|
227
|
+
}
|
228
|
+
};
|
229
|
+
|
230
|
+
template<>
|
231
|
+
inline
|
232
|
+
OptionalScalarType from_ruby<OptionalScalarType>(Object x)
|
233
|
+
{
|
234
|
+
return OptionalScalarType(x);
|
235
|
+
}
|
236
|
+
|
237
|
+
typedef torch::Device Device;
|
238
|
+
|
239
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
240
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
241
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
242
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
243
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
|
244
|
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
|
data/lib/torch.rb
CHANGED
@@ -1,6 +1,11 @@
|
|
1
1
|
# ext
|
2
2
|
require "torch/ext"
|
3
3
|
|
4
|
+
# native functions
|
5
|
+
require "torch/native/generator"
|
6
|
+
require "torch/native/parser"
|
7
|
+
require "torch/native/dispatcher"
|
8
|
+
|
4
9
|
# modules
|
5
10
|
require "torch/inspector"
|
6
11
|
require "torch/tensor"
|
@@ -22,31 +27,145 @@ require "torch/optim/sgd"
|
|
22
27
|
require "torch/optim/lr_scheduler/lr_scheduler"
|
23
28
|
require "torch/optim/lr_scheduler/step_lr"
|
24
29
|
|
25
|
-
# nn
|
30
|
+
# nn parameters
|
31
|
+
require "torch/nn/parameter"
|
32
|
+
require "torch/nn/utils"
|
33
|
+
|
34
|
+
# nn containers
|
26
35
|
require "torch/nn/module"
|
36
|
+
require "torch/nn/sequential"
|
37
|
+
|
38
|
+
# nn convolution layers
|
27
39
|
require "torch/nn/convnd"
|
28
|
-
require "torch/nn/
|
40
|
+
require "torch/nn/conv1d"
|
41
|
+
require "torch/nn/conv2d"
|
42
|
+
require "torch/nn/conv3d"
|
43
|
+
require "torch/nn/unfold"
|
44
|
+
require "torch/nn/fold"
|
45
|
+
|
46
|
+
# nn pooling layers
|
47
|
+
require "torch/nn/max_poolnd"
|
48
|
+
require "torch/nn/max_pool1d"
|
49
|
+
require "torch/nn/max_pool2d"
|
50
|
+
require "torch/nn/max_pool3d"
|
51
|
+
require "torch/nn/max_unpoolnd"
|
52
|
+
require "torch/nn/max_unpool1d"
|
53
|
+
require "torch/nn/max_unpool2d"
|
54
|
+
require "torch/nn/max_unpool3d"
|
55
|
+
require "torch/nn/avg_poolnd"
|
56
|
+
require "torch/nn/avg_pool1d"
|
57
|
+
require "torch/nn/avg_pool2d"
|
58
|
+
require "torch/nn/avg_pool3d"
|
59
|
+
require "torch/nn/lp_poolnd"
|
60
|
+
require "torch/nn/lp_pool1d"
|
61
|
+
require "torch/nn/lp_pool2d"
|
62
|
+
|
63
|
+
# nn padding layers
|
64
|
+
require "torch/nn/reflection_padnd"
|
65
|
+
require "torch/nn/reflection_pad1d"
|
66
|
+
require "torch/nn/reflection_pad2d"
|
67
|
+
require "torch/nn/replication_padnd"
|
68
|
+
require "torch/nn/replication_pad1d"
|
69
|
+
require "torch/nn/replication_pad2d"
|
70
|
+
require "torch/nn/replication_pad3d"
|
71
|
+
require "torch/nn/constant_padnd"
|
72
|
+
require "torch/nn/constant_pad1d"
|
73
|
+
require "torch/nn/constant_pad2d"
|
74
|
+
require "torch/nn/constant_pad3d"
|
75
|
+
require "torch/nn/zero_pad2d"
|
76
|
+
|
77
|
+
# nn normalization layers
|
78
|
+
require "torch/nn/batch_norm"
|
79
|
+
require "torch/nn/batch_norm1d"
|
80
|
+
require "torch/nn/batch_norm2d"
|
81
|
+
require "torch/nn/batch_norm3d"
|
82
|
+
require "torch/nn/group_norm"
|
83
|
+
require "torch/nn/instance_norm"
|
84
|
+
require "torch/nn/instance_norm1d"
|
85
|
+
require "torch/nn/instance_norm2d"
|
86
|
+
require "torch/nn/instance_norm3d"
|
87
|
+
require "torch/nn/layer_norm"
|
88
|
+
require "torch/nn/local_response_norm"
|
89
|
+
|
90
|
+
# nn recurrent layers
|
91
|
+
require "torch/nn/rnn_base"
|
92
|
+
require "torch/nn/rnn"
|
93
|
+
require "torch/nn/lstm"
|
94
|
+
require "torch/nn/gru"
|
95
|
+
|
96
|
+
# nn linear layers
|
97
|
+
require "torch/nn/bilinear"
|
98
|
+
require "torch/nn/identity"
|
99
|
+
require "torch/nn/linear"
|
29
100
|
|
30
|
-
# nn
|
101
|
+
# nn dropout layers
|
102
|
+
require "torch/nn/dropoutnd"
|
31
103
|
require "torch/nn/alpha_dropout"
|
32
|
-
require "torch/nn/conv2d"
|
33
104
|
require "torch/nn/dropout"
|
34
105
|
require "torch/nn/dropout2d"
|
35
106
|
require "torch/nn/dropout3d"
|
36
|
-
require "torch/nn/embedding"
|
37
107
|
require "torch/nn/feature_alpha_dropout"
|
108
|
+
|
109
|
+
# nn activations
|
110
|
+
require "torch/nn/hardshrink"
|
111
|
+
require "torch/nn/leaky_relu"
|
112
|
+
require "torch/nn/log_sigmoid"
|
113
|
+
require "torch/nn/prelu"
|
114
|
+
require "torch/nn/relu"
|
115
|
+
require "torch/nn/sigmoid"
|
116
|
+
require "torch/nn/softplus"
|
117
|
+
require "torch/nn/softshrink"
|
118
|
+
require "torch/nn/softsign"
|
119
|
+
require "torch/nn/tanh"
|
120
|
+
require "torch/nn/tanhshrink"
|
121
|
+
|
122
|
+
# nn activations other
|
123
|
+
require "torch/nn/log_softmax"
|
124
|
+
require "torch/nn/softmax"
|
125
|
+
require "torch/nn/softmax2d"
|
126
|
+
require "torch/nn/softmin"
|
127
|
+
|
128
|
+
# nn sparse layers
|
129
|
+
require "torch/nn/embedding"
|
130
|
+
require "torch/nn/embedding_bag"
|
131
|
+
|
132
|
+
# nn distance functions
|
133
|
+
require "torch/nn/cosine_similarity"
|
134
|
+
require "torch/nn/pairwise_distance"
|
135
|
+
|
136
|
+
# nn loss functions
|
137
|
+
require "torch/nn/loss"
|
138
|
+
require "torch/nn/weighted_loss"
|
139
|
+
require "torch/nn/bce_loss"
|
140
|
+
require "torch/nn/bce_with_logits_loss"
|
141
|
+
require "torch/nn/cosine_embedding_loss"
|
142
|
+
require "torch/nn/cross_entropy_loss"
|
143
|
+
require "torch/nn/ctc_loss"
|
144
|
+
require "torch/nn/hinge_embedding_loss"
|
145
|
+
require "torch/nn/kl_div_loss"
|
146
|
+
require "torch/nn/l1_loss"
|
147
|
+
require "torch/nn/margin_ranking_loss"
|
148
|
+
require "torch/nn/mse_loss"
|
149
|
+
require "torch/nn/multi_label_margin_loss"
|
150
|
+
require "torch/nn/multi_label_soft_margin_loss"
|
151
|
+
require "torch/nn/multi_margin_loss"
|
152
|
+
require "torch/nn/nll_loss"
|
153
|
+
require "torch/nn/poisson_nll_loss"
|
154
|
+
require "torch/nn/smooth_l1_loss"
|
155
|
+
require "torch/nn/soft_margin_loss"
|
156
|
+
require "torch/nn/triplet_margin_loss"
|
157
|
+
|
158
|
+
# nn other
|
38
159
|
require "torch/nn/functional"
|
39
160
|
require "torch/nn/init"
|
40
|
-
require "torch/nn/linear"
|
41
|
-
require "torch/nn/mse_loss"
|
42
|
-
require "torch/nn/parameter"
|
43
|
-
require "torch/nn/relu"
|
44
|
-
require "torch/nn/sequential"
|
45
161
|
|
46
162
|
# utils
|
47
163
|
require "torch/utils/data/data_loader"
|
48
164
|
require "torch/utils/data/tensor_dataset"
|
49
165
|
|
166
|
+
# random
|
167
|
+
require "torch/random"
|
168
|
+
|
50
169
|
module Torch
|
51
170
|
class Error < StandardError; end
|
52
171
|
class NotImplementedYet < StandardError
|
@@ -57,7 +176,6 @@ module Torch
|
|
57
176
|
|
58
177
|
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
|
59
178
|
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
|
60
|
-
# complex and quantized types not supported by PyTorch yet
|
61
179
|
DTYPE_TO_ENUM = {
|
62
180
|
uint8: 0,
|
63
181
|
int8: 1,
|
@@ -73,17 +191,52 @@ module Torch
|
|
73
191
|
float32: 6,
|
74
192
|
double: 7,
|
75
193
|
float64: 7,
|
76
|
-
|
77
|
-
|
78
|
-
|
194
|
+
complex_half: 8,
|
195
|
+
complex_float: 9,
|
196
|
+
complex_double: 10,
|
79
197
|
bool: 11,
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
198
|
+
qint8: 12,
|
199
|
+
quint8: 13,
|
200
|
+
qint32: 14,
|
201
|
+
bfloat16: 15
|
84
202
|
}
|
85
203
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
86
204
|
|
205
|
+
# Builds an anonymous class whose singleton +new+ produces a tensor of the
# given dtype on "cuda" (when +cuda+ is true) or "cpu".
#
# +new+ accepts one of:
# * a single Tensor - converted by calling the dtype method, then moved to the device
# * a single Array  - passed to Torch.tensor as data
# * anything else   - forwarded to Torch.empty as the size arguments
def self._make_tensor_class(dtype, cuda = false)
  target = cuda ? "cuda" : "cpu"
  Class.new.tap do |klass|
    klass.define_singleton_method("new") do |*args|
      # only a lone argument is inspected; otherwise fall through to empty
      source = args.first if args.size == 1
      case source
      when Tensor
        source.send(dtype).to(target)
      when Array
        Torch.tensor(source, dtype: dtype, device: target)
      else
        Torch.empty(*args, dtype: dtype, device: target)
      end
    end
  end
end
|
219
|
+
|
220
|
+
# Typed tensor classes: each name is defined twice, once at this level for
# the CPU variant and once under CUDA for the CUDA variant.
{
  FloatTensor: :float32,
  DoubleTensor: :float64,
  HalfTensor: :float16,
  ByteTensor: :uint8,
  CharTensor: :int8,
  ShortTensor: :int16,
  IntTensor: :int32,
  LongTensor: :int64,
  BoolTensor: :bool
}.each do |class_name, dtype|
  const_set(class_name, _make_tensor_class(dtype))
  CUDA.const_set(class_name, _make_tensor_class(dtype, true))
end
|
239
|
+
|
87
240
|
class << self
|
88
241
|
# Torch.float, Torch.long, etc
|
89
242
|
DTYPE_TO_ENUM.each_key do |dtype|
|
@@ -120,6 +273,8 @@ module Torch
|
|
120
273
|
# use method for cases when Numo not available
|
121
274
|
# or available after Torch loaded
|
122
275
|
def _dtype_to_numo
|
276
|
+
raise Error, "Numo not found" unless defined?(Numo::NArray)
|
277
|
+
|
123
278
|
{
|
124
279
|
uint8: Numo::UInt8,
|
125
280
|
int8: Numo::Int8,
|
@@ -131,6 +286,29 @@ module Torch
|
|
131
286
|
}
|
132
287
|
end
|
133
288
|
|
289
|
+
# Runs the given block with gradient tracking switched off and returns the
# block's value. The setting that was active before the call is restored
# afterwards, even if the block raises.
def no_grad
  was_enabled = grad_enabled?
  begin
    _set_grad_enabled(false)
    yield
  ensure
    _set_grad_enabled(was_enabled)
  end
end
|
298
|
+
|
299
|
+
# Shorthand for Device.new: returns a Device built from +str+.
def device(str)
  Device.new(str)
end
|
302
|
+
|
303
|
+
# Writes the serialized form of +obj+ (as produced by +_save+) to the path
# +f+ in binary mode. Only tensors are supported; anything else raises
# NotImplementedYet.
def save(obj, f)
  if obj.is_a?(Tensor)
    File.binwrite(f, _save(obj))
  else
    raise NotImplementedYet
  end
end
|
307
|
+
|
308
|
+
# Loading is not supported yet: always raises NotImplementedYet,
# whatever +f+ is.
def load(f)
  raise NotImplementedYet
end
|
311
|
+
|
134
312
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
135
313
|
|
136
314
|
def arange(start, finish = nil, step = 1, **options)
|
@@ -200,8 +378,12 @@ module Torch
|
|
200
378
|
data = [data].compact
|
201
379
|
end
|
202
380
|
|
203
|
-
if options[:dtype].nil?
|
204
|
-
|
381
|
+
if options[:dtype].nil?
|
382
|
+
if data.all? { |v| v.is_a?(Integer) }
|
383
|
+
options[:dtype] = :int64
|
384
|
+
elsif data.all? { |v| v == true || v == false }
|
385
|
+
options[:dtype] = :bool
|
386
|
+
end
|
205
387
|
end
|
206
388
|
|
207
389
|
_tensor(data, size, tensor_options(**options))
|
@@ -210,19 +392,19 @@ module Torch
|
|
210
392
|
# --- begin like ---
|
211
393
|
|
212
394
|
# Calls +ones+ with the size of +input+, passing the options computed by
# +like_options+ as keyword arguments.
def ones_like(input, **options)
  shape = input.size
  merged = like_options(input, options)
  ones(shape, **merged)
end
|
215
397
|
|
216
398
|
# Calls +empty+ with the size of +input+, passing the options computed by
# +like_options+ as keyword arguments.
def empty_like(input, **options)
  shape = input.size
  merged = like_options(input, options)
  empty(shape, **merged)
end
|
219
401
|
|
220
402
|
# Calls +full+ with the size of +input+ and +fill_value+, passing the
# options computed by +like_options+ as keyword arguments.
def full_like(input, fill_value, **options)
  shape = input.size
  merged = like_options(input, options)
  full(shape, fill_value, **merged)
end
|
223
405
|
|
224
406
|
# Calls +rand+ with the size of +input+, passing the options computed by
# +like_options+ as keyword arguments.
def rand_like(input, **options)
  shape = input.size
  merged = like_options(input, options)
  rand(shape, **merged)
end
|
227
409
|
|
228
410
|
def randint_like(input, low, high = nil, **options)
|
@@ -231,163 +413,19 @@ module Torch
|
|
231
413
|
high = low
|
232
414
|
low = 0
|
233
415
|
end
|
234
|
-
randint(low, high, input.size, like_options(input, options))
|
416
|
+
randint(low, high, input.size, **like_options(input, options))
|
235
417
|
end
|
236
418
|
|
237
419
|
# Calls +randn+ with the size of +input+, passing the options computed by
# +like_options+ as keyword arguments.
def randn_like(input, **options)
  shape = input.size
  merged = like_options(input, options)
  randn(shape, **merged)
end
|
240
422
|
|
241
423
|
def zeros_like(input, **options)
|
242
|
-
zeros(input.size, like_options(input, options))
|
243
|
-
end
|
244
|
-
|
245
|
-
# --- begin operations ---
|
246
|
-
|
247
|
-
%w(add sub mul div remainder).each do |op|
|
248
|
-
define_method(op) do |input, other, **options|
|
249
|
-
execute_op(op, input, other, **options)
|
250
|
-
end
|
251
|
-
end
|
252
|
-
|
253
|
-
def neg(input)
|
254
|
-
_neg(input)
|
255
|
-
end
|
256
|
-
|
257
|
-
def no_grad
|
258
|
-
previous_value = grad_enabled?
|
259
|
-
begin
|
260
|
-
_set_grad_enabled(false)
|
261
|
-
yield
|
262
|
-
ensure
|
263
|
-
_set_grad_enabled(previous_value)
|
264
|
-
end
|
265
|
-
end
|
266
|
-
|
267
|
-
# TODO support out
|
268
|
-
def mean(input, dim = nil, keepdim: false)
|
269
|
-
if dim
|
270
|
-
_mean_dim(input, dim, keepdim)
|
271
|
-
else
|
272
|
-
_mean(input)
|
273
|
-
end
|
274
|
-
end
|
275
|
-
|
276
|
-
# TODO support dtype
|
277
|
-
def sum(input, dim = nil, keepdim: false)
|
278
|
-
if dim
|
279
|
-
_sum_dim(input, dim, keepdim)
|
280
|
-
else
|
281
|
-
_sum(input)
|
282
|
-
end
|
283
|
-
end
|
284
|
-
|
285
|
-
def argmax(input, dim = nil, keepdim: false)
|
286
|
-
if dim
|
287
|
-
_argmax_dim(input, dim, keepdim)
|
288
|
-
else
|
289
|
-
_argmax(input)
|
290
|
-
end
|
291
|
-
end
|
292
|
-
|
293
|
-
def eq(input, other)
|
294
|
-
_eq(input, other)
|
295
|
-
end
|
296
|
-
|
297
|
-
def norm(input)
|
298
|
-
_norm(input)
|
299
|
-
end
|
300
|
-
|
301
|
-
def pow(input, exponent)
|
302
|
-
_pow(input, exponent)
|
303
|
-
end
|
304
|
-
|
305
|
-
def min(input)
|
306
|
-
_min(input)
|
307
|
-
end
|
308
|
-
|
309
|
-
def max(input, dim = nil, keepdim: false, out: nil)
|
310
|
-
if dim
|
311
|
-
raise NotImplementedYet unless out
|
312
|
-
_max_out(out[0], out[1], input, dim, keepdim)
|
313
|
-
else
|
314
|
-
_max(input)
|
315
|
-
end
|
316
|
-
end
|
317
|
-
|
318
|
-
def exp(input)
|
319
|
-
_exp(input)
|
320
|
-
end
|
321
|
-
|
322
|
-
def log(input)
|
323
|
-
_log(input)
|
324
|
-
end
|
325
|
-
|
326
|
-
def sign(input)
|
327
|
-
_sign(input)
|
328
|
-
end
|
329
|
-
|
330
|
-
def gt(input, other)
|
331
|
-
_gt(input, other)
|
332
|
-
end
|
333
|
-
|
334
|
-
def lt(input, other)
|
335
|
-
_lt(input, other)
|
336
|
-
end
|
337
|
-
|
338
|
-
def unsqueeze(input, dim)
|
339
|
-
_unsqueeze(input, dim)
|
340
|
-
end
|
341
|
-
|
342
|
-
def dot(input, tensor)
|
343
|
-
_dot(input, tensor)
|
344
|
-
end
|
345
|
-
|
346
|
-
def cat(tensors, dim = 0)
|
347
|
-
_cat(tensors, dim)
|
348
|
-
end
|
349
|
-
|
350
|
-
def matmul(input, other)
|
351
|
-
_matmul(input, other)
|
352
|
-
end
|
353
|
-
|
354
|
-
def reshape(input, shape)
|
355
|
-
_reshape(input, shape)
|
356
|
-
end
|
357
|
-
|
358
|
-
def flatten(input, start_dim: 0, end_dim: -1)
|
359
|
-
_flatten(input, start_dim, end_dim)
|
360
|
-
end
|
361
|
-
|
362
|
-
def sqrt(input)
|
363
|
-
_sqrt(input)
|
364
|
-
end
|
365
|
-
|
366
|
-
def abs(input)
|
367
|
-
_abs(input)
|
368
|
-
end
|
369
|
-
|
370
|
-
def device(str)
|
371
|
-
Device.new(str)
|
424
|
+
zeros(input.size, **like_options(input, options))
|
372
425
|
end
|
373
426
|
|
374
427
|
private
|
375
428
|
|
376
|
-
def execute_op(op, input, other, out: nil)
|
377
|
-
scalar = other.is_a?(Numeric)
|
378
|
-
if out
|
379
|
-
# TODO make work with scalars
|
380
|
-
raise Error, "out not supported with scalar yet" if scalar
|
381
|
-
send("_#{op}_out", out, input, other)
|
382
|
-
else
|
383
|
-
if scalar
|
384
|
-
send("_#{op}_scalar", input, other)
|
385
|
-
else
|
386
|
-
send("_#{op}", input, other)
|
387
|
-
end
|
388
|
-
end
|
389
|
-
end
|
390
|
-
|
391
429
|
# Flattens +size+ so that dimensions given as separate arguments and
# dimensions given as a single array end up as one flat list.
def tensor_size(size)
  size.flatten
end
|