torch-rb 0.20.0 → 0.21.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +12 -10
- data/codegen/native_functions.yaml +286 -244
- data/ext/torch/device.cpp +3 -0
- data/ext/torch/ext.cpp +1 -2
- data/ext/torch/ivalue.cpp +2 -0
- data/ext/torch/nn.cpp +3 -1
- data/ext/torch/ruby_arg_parser.cpp +7 -3
- data/ext/torch/ruby_arg_parser.h +5 -2
- data/ext/torch/templates.h +18 -36
- data/ext/torch/tensor.cpp +11 -8
- data/ext/torch/torch.cpp +6 -3
- data/ext/torch/utils.h +3 -1
- data/lib/torch/nn/conv1d.rb +11 -3
- data/lib/torch/nn/conv2d.rb +11 -3
- data/lib/torch/nn/conv3d.rb +11 -3
- data/lib/torch/nn/convnd.rb +1 -1
- data/lib/torch/nn/embedding.rb +10 -3
- data/lib/torch/nn/embedding_bag.rb +10 -3
- data/lib/torch/nn/functional.rb +20 -6
- data/lib/torch/nn/functional_attention.rb +30 -15
- data/lib/torch/nn/multihead_attention.rb +17 -7
- data/lib/torch/nn/rnn_base.rb +10 -3
- data/lib/torch/nn/transformer.rb +19 -10
- data/lib/torch/nn/transformer_decoder_layer.rb +7 -4
- data/lib/torch/nn/transformer_encoder_layer.rb +7 -4
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -1
- metadata +3 -3
data/ext/torch/device.cpp
CHANGED
data/ext/torch/ext.cpp
CHANGED
@@ -17,8 +17,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
 void init_random(Rice::Module& m);
 
 extern "C"
-void Init_ext()
-{
+void Init_ext() {
   auto m = Rice::define_module("Torch");
 
   // need to define certain classes up front to keep Rice happy
data/ext/torch/ivalue.cpp
CHANGED
data/ext/torch/nn.cpp
CHANGED
@@ -1,3 +1,5 @@
+#include <utility>
+
 #include <torch/torch.h>
 
 #include <rice/rice.hpp>
@@ -93,7 +95,7 @@ void init_nn(Rice::Module& m) {
       "grad",
       [](Parameter& self) {
         auto grad = self.grad();
-        return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
+        return grad.defined() ? Rice::Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Rice::Nil;
       })
     // can't use grad=
     // assignment methods fail with Ruby 3.0
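At the Ruby level the grad binding is behavior-preserving: Parameter#grad still returns nil until a gradient is populated; the change only fully qualifies the Rice names. A minimal sketch of that contract (hypothetical values, assuming the usual torch-rb API):

    w = Torch::NN::Parameter.new(Torch.ones(2))
    w.grad                  # => nil (no backward pass yet)
    (w * 3).sum.backward
    w.grad                  # => gradient tensor, [3.0, 3.0]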
data/ext/torch/ruby_arg_parser.cpp
CHANGED
@@ -1,5 +1,10 @@
 // adapted from PyTorch - python_arg_parser.cpp
 
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
 #include "ruby_arg_parser.h"
 
 VALUE THPGeneratorClass = Qnil;
@@ -99,7 +104,7 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
     ruby_name = THPUtils_internSymbol(name);
     auto np_compat_it = numpy_compatibility_arg_names.find(name);
     if (np_compat_it != numpy_compatibility_arg_names.end()) {
-      for (const auto& str: np_compat_it->second) {
+      for (const auto& str : np_compat_it->second) {
        numpy_python_names.push_back(THPUtils_internSymbol(str));
      }
    }
@@ -190,8 +195,7 @@ static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
 }
 
 // argnum is needed for raising the TypeError, it's used in the error message.
-auto FunctionParameter::check(VALUE obj, int argnum) -> bool
-{
+auto FunctionParameter::check(VALUE obj, int argnum) -> bool {
   switch (type_) {
     case ParameterType::TENSOR: {
       if (THPVariable_Check(obj)) {
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -3,6 +3,9 @@
 #pragma once
 
 #include <sstream>
+#include <unordered_map>
+#include <string>
+#include <vector>
 
 #include <torch/torch.h>
 #include <rice/rice.hpp>
@@ -162,7 +165,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
   Check_Type(arg, T_ARRAY);
   auto size = RARRAY_LEN(arg);
   if (size != N) {
-    rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, (
+    rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
   }
   for (int idx = 0; idx < size; idx++) {
     VALUE obj = rb_ary_entry(arg, idx);
@@ -463,7 +466,7 @@ struct RubyArgParser {
   template<int N>
   inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
     if (N < max_args) {
-      rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", (
+      rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
     }
     return raw_parse(self, argc, argv, dst.args);
   }
data/ext/torch/templates.h
CHANGED
@@ -1,13 +1,13 @@
 #pragma once
 
+#include <string>
+
 #ifdef isfinite
 #undef isfinite
 #endif
 
 #include <rice/rice.hpp>
 
-using namespace Rice;
-
 using torch::Device;
 using torch::Scalar;
 using torch::ScalarType;
@@ -41,55 +41,43 @@ using torch::nn::init::NonlinearityType;
 #define RETURN_NIL \
   return Qnil;
 
-namespace Rice::detail
-{
+namespace Rice::detail {
   template<typename T>
-  struct Type<c10::complex<T>>
-  {
+  struct Type<c10::complex<T>> {
     static bool verify() { return true; }
   };
 
   template<typename T>
-  class To_Ruby<c10::complex<T>>
-  {
+  class To_Ruby<c10::complex<T>> {
   public:
-    VALUE convert(c10::complex<T> const& x)
-    {
+    VALUE convert(c10::complex<T> const& x) {
       return rb_dbl_complex_new(x.real(), x.imag());
     }
   };
 
   template<typename T>
-  class From_Ruby<c10::complex<T>>
-  {
+  class From_Ruby<c10::complex<T>> {
   public:
     Convertible is_convertible(VALUE value) { return Convertible::Cast; }
 
-    c10::complex<T> convert(VALUE x)
-    {
+    c10::complex<T> convert(VALUE x) {
       VALUE real = rb_funcall(x, rb_intern("real"), 0);
       VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
       return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
     }
   };
-}
 
-namespace Rice::detail
-{
   template<>
-  struct Type<FanModeType>
-  {
+  struct Type<FanModeType> {
     static bool verify() { return true; }
   };
 
   template<>
-  class From_Ruby<FanModeType>
-  {
+  class From_Ruby<FanModeType> {
   public:
     Convertible is_convertible(VALUE value) { return Convertible::Cast; }
 
-    FanModeType convert(VALUE x)
-    {
+    FanModeType convert(VALUE x) {
       auto s = String(x).str();
       if (s == "fan_in") {
         return torch::kFanIn;
@@ -102,19 +90,16 @@ namespace Rice::detail
   };
 
   template<>
-  struct Type<NonlinearityType>
-  {
+  struct Type<NonlinearityType> {
     static bool verify() { return true; }
   };
 
   template<>
-  class From_Ruby<NonlinearityType>
-  {
+  class From_Ruby<NonlinearityType> {
   public:
     Convertible is_convertible(VALUE value) { return Convertible::Cast; }
 
-    NonlinearityType convert(VALUE x)
-    {
+    NonlinearityType convert(VALUE x) {
       auto s = String(x).str();
       if (s == "linear") {
         return torch::kLinear;
@@ -145,19 +130,16 @@ namespace Rice::detail
   };
 
   template<>
-  struct Type<Scalar>
-  {
+  struct Type<Scalar> {
     static bool verify() { return true; }
   };
 
   template<>
-  class From_Ruby<Scalar>
-  {
+  class From_Ruby<Scalar> {
   public:
     Convertible is_convertible(VALUE value) { return Convertible::Cast; }
 
-    Scalar convert(VALUE x)
-    {
+    Scalar convert(VALUE x) {
       if (FIXNUM_P(x)) {
         return torch::Scalar(From_Ruby<int64_t>().convert(x));
       } else {
@@ -165,4 +147,4 @@ namespace Rice::detail
       }
     }
   };
-}
+} // namespace Rice::detail
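The From_Ruby<c10::complex<T>> specialization above only requires the Ruby value to respond to real and imag, so Ruby's built-in Complex maps over directly. A rough sketch, assuming the installed build supports complex dtypes:

    t = Torch.tensor([Complex(1, 2), Complex(3, 4)])
    t.imag   # => tensor of the imaginary parts, [2.0, 4.0]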
data/ext/torch/tensor.cpp
CHANGED
@@ -1,3 +1,6 @@
+#include <string>
+#include <vector>
+
 #include <torch/torch.h>
 
 #include <rice/rice.hpp>
@@ -7,7 +10,8 @@
 #include "templates.h"
 #include "utils.h"
 
-using
+using Rice::Array;
+using Rice::Object;
 using torch::indexing::TensorIndex;
 
 template<typename T>
@@ -21,7 +25,7 @@ Array flat_data(Tensor& tensor) {
   return a;
 }
 
-Class rb_cTensor;
+Rice::Class rb_cTensor;
 
 std::vector<TensorIndex> index_vector(Array a) {
   Object obj;
@@ -62,10 +66,10 @@ std::vector<TensorIndex> index_vector(Array a) {
       indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
     } else if (obj.is_nil()) {
       indices.push_back(torch::indexing::None);
-    } else if (obj == True || obj == False) {
+    } else if (obj == Rice::True || obj == Rice::False) {
       indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
     } else {
-      throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+      throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
     }
   }
   return indices;
@@ -75,8 +79,7 @@ std::vector<TensorIndex> index_vector(Array a) {
 // https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
 // TODO add support for inputs argument
 // _backward
-static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
-{
+static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_) {
   HANDLE_TH_ERRORS
   Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
   static RubyArgParser parser({
@@ -165,7 +168,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
       "grad",
       [](Tensor& self) {
         auto grad = self.grad();
-        return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
+        return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Rice::Nil;
       })
     // can't use grad=
     // assignment methods fail with Ruby 3.0
@@ -197,7 +200,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
     .define_method(
       "_dtype",
       [](Tensor& self) {
-        return (
+        return static_cast<int>(at::typeMetaToScalarType(self.dtype()));
       })
     .define_method(
       "_type",
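index_vector is what backs Tensor#[]: Integers, Ranges, Tensors, nil (a new axis, like PyTorch's None), and booleans are accepted, and anything else raises the now fully qualified Rice::Exception. For illustration (a sketch using the usual torch-rb indexing API):

    x = Torch.arange(12).reshape([3, 4])
    x[0, 1..2]                 # integer and range indexing
    x[nil, 0]                  # nil inserts an axis
    x[Torch.tensor([0, 2])]    # tensor indexing
    x["a"]                     # => ArgumentError: Unsupported index type: String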
data/ext/torch/torch.cpp
CHANGED
@@ -1,8 +1,11 @@
+#include <fstream>
+#include <string>
+#include <vector>
+
 #include <torch/torch.h>
 
 #include <rice/rice.hpp>
-
-#include <fstream>
+#include <rice/stl.hpp>
 
 #include "torch_functions.h"
 #include "templates.h"
@@ -60,7 +63,7 @@ void init_torch(Rice::Module& m) {
       "_save",
       [](const torch::IValue &value) {
         auto v = torch::pickle_save(value);
-        return Object(rb_str_new(v.data(), v.size()));
+        return Rice::Object(rb_str_new(v.data(), v.size()));
       })
     .define_singleton_function(
       "_load",
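These bindings back the public Torch.save and Torch.load, which round-trip objects through LibTorch's pickle format, e.g.:

    Torch.save(Torch.ones(2, 3), "example.pt")
    t = Torch.load("example.pt")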
data/ext/torch/utils.h
CHANGED
@@ -1,12 +1,14 @@
 #pragma once
 
+#include <string>
+
 #include <torch/torch.h>
 
 #include <rice/rice.hpp>
 #include <rice/stl.hpp>
 
 static_assert(
-  TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR ==
+  TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 8,
   "Incompatible LibTorch version"
 );
 
data/lib/torch/nn/conv1d.rb
CHANGED
@@ -1,9 +1,17 @@
 module Torch
   module NN
     class Conv1d < ConvNd
-      def initialize(
-
-
+      def initialize(
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride: 1,
+        padding: 0,
+        dilation: 1,
+        groups: 1,
+        bias: true,
+        padding_mode: "zeros"
+      )
         kernel_size = _single(kernel_size)
         stride = _single(stride)
         padding = _single(padding)
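The signature is only reflowed onto one parameter per line; the keyword defaults are unchanged, so construction still looks like this minimal sketch. Conv2d and Conv3d below get the identical treatment.

    conv = Torch::NN::Conv1d.new(16, 33, 3, stride: 2)
    x = Torch.randn(20, 16, 50)   # (batch, channels, length)
    conv.call(x).shape            # => [20, 33, 24]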
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -1,9 +1,17 @@
 module Torch
   module NN
     class Conv2d < ConvNd
-      def initialize(
-
-
+      def initialize(
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride: 1,
+        padding: 0,
+        dilation: 1,
+        groups: 1,
+        bias: true,
+        padding_mode: "zeros"
+      )
         kernel_size = _pair(kernel_size)
         stride = _pair(stride)
         padding = _pair(padding)
data/lib/torch/nn/conv3d.rb
CHANGED
@@ -1,9 +1,17 @@
 module Torch
   module NN
     class Conv3d < ConvNd
-      def initialize(
-
-
+      def initialize(
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride: 1,
+        padding: 0,
+        dilation: 1,
+        groups: 1,
+        bias: true,
+        padding_mode: "zeros"
+      )
         kernel_size = _triple(kernel_size)
         stride = _triple(stride)
         padding = _triple(padding)
data/lib/torch/nn/convnd.rb
CHANGED
@@ -1,7 +1,7 @@
 module Torch
   module NN
     class ConvNd < Module
-      attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :
+      attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_padding, :groups, :padding_mode
 
       def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
         super()
data/lib/torch/nn/embedding.rb
CHANGED
@@ -2,9 +2,16 @@
 module Torch
   module NN
     class Embedding < Module
-      def initialize(
-
-
+      def initialize(
+        num_embeddings,
+        embedding_dim,
+        padding_idx: nil,
+        max_norm: nil,
+        norm_type: 2.0,
+        scale_grad_by_freq: false,
+        sparse: false,
+        _weight: nil
+      )
         super()
         @num_embeddings = num_embeddings
         @embedding_dim = embedding_dim
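Same reflow for Embedding; for reference, typical construction with the keyword options:

    emb = Torch::NN::Embedding.new(10, 3, padding_idx: 0)
    input = Torch.tensor([[1, 2, 4, 5], [4, 3, 2, 0]])
    emb.call(input).shape   # => [2, 4, 3]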
data/lib/torch/nn/embedding_bag.rb
CHANGED
@@ -2,9 +2,16 @@
 module Torch
   module NN
     class EmbeddingBag < Module
-      def initialize(
-
-
+      def initialize(
+        num_embeddings,
+        embedding_dim,
+        max_norm: nil,
+        norm_type: 2.0,
+        scale_grad_by_freq: false,
+        mode: "mean",
+        sparse: false,
+        _weight: nil
+      )
         super()
         @num_embeddings = num_embeddings
         @embedding_dim = embedding_dim
data/lib/torch/nn/functional.rb
CHANGED
@@ -250,9 +250,16 @@ module Torch
 
       # normalization layers
 
-      def batch_norm(
-
-
+      def batch_norm(
+        input,
+        running_mean,
+        running_var,
+        weight: nil,
+        bias: nil,
+        training: false,
+        momentum: 0.1,
+        eps: 1e-5
+      )
         if training
           size = input.size
           size_prods = size[0]
@@ -274,9 +281,16 @@ module Torch
         Torch.group_norm(input, num_groups, weight, bias, eps, false)
       end
 
-      def instance_norm(
-
-
+      def instance_norm(
+        input,
+        running_mean: nil,
+        running_var: nil,
+        weight: nil,
+        bias: nil,
+        use_input_stats: true,
+        momentum: 0.1,
+        eps: 1e-5
+      )
         Torch.instance_norm(
           input, weight, bias, running_mean, running_var,
           use_input_stats, momentum, eps, false
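Again formatting only; the functional API keeps its keyword defaults. Inference-mode usage looks roughly like this (a sketch, assuming Functional's singleton-method style):

    x = Torch.randn(4, 3, 10)
    mean = Torch.zeros(3)
    var = Torch.ones(3)
    y = Torch::NN::Functional.batch_norm(x, mean, var, training: false)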
data/lib/torch/nn/functional_attention.rb
CHANGED
@@ -33,11 +33,16 @@ module Torch
      end
 
      def in_projection(
-        q,
-
-
+        q,
+        k,
+        v,
+        w_q,
+        w_k,
+        w_v,
+        b_q: nil,
+        b_k: nil,
+        b_v: nil
      )
-
        e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)
 
        raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
@@ -52,10 +57,12 @@ module Torch
      end
 
      def scaled_dot_product_attention(
-        q,
-
+        q,
+        k,
+        v,
+        attn_mask: nil,
+        dropout_p: 0.0
      )
-
        _b, _nt, e = q.shape
 
        q = q / Math.sqrt(e)
@@ -71,22 +78,30 @@ module Torch
      end
 
      def multi_head_attention_forward(
-        query,
-
-
-
+        query,
+        key,
+        value,
+        embed_dim_to_check,
+        num_heads,
+        in_proj_weight,
+        in_proj_bias,
+        bias_k,
+        bias_v,
        add_zero_attn,
        dropout_p,
-        out_proj_weight,
+        out_proj_weight,
+        out_proj_bias,
        training: true,
        key_padding_mask: nil,
        need_weights: true,
        attn_mask: nil,
        use_separate_proj_weight: false,
-        q_proj_weight: nil,
-
+        q_proj_weight: nil,
+        k_proj_weight: nil,
+        v_proj_weight: nil,
+        static_k: nil,
+        static_v: nil
      )
-
        tgt_len, bsz, embed_dim = query.shape
        src_len = key.shape.first
 
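These helpers port PyTorch's private attention functions, so treat the exact names and return values as internal. Under that assumption, the scaled-dot-product step is roughly:

    q = Torch.randn(2, 4, 8)   # (batch, tokens, embed)
    k = Torch.randn(2, 4, 8)
    v = Torch.randn(2, 4, 8)
    output, weights = Torch::NN::Functional.scaled_dot_product_attention(q, k, v)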
data/lib/torch/nn/multihead_attention.rb
CHANGED
@@ -2,11 +2,18 @@ module Torch
   module NN
     class MultiheadAttention < Module
       def initialize(
-        embed_dim,
-
-
+        embed_dim,
+        num_heads,
+        dropout: 0.0,
+        bias: true,
+        add_bias_kv: false,
+        add_zero_attn: false,
+        kdim: nil,
+        vdim: nil,
+        batch_first: false,
+        device: nil,
+        dtype: nil
       )
-
         super()
 
         @embed_dim = embed_dim
@@ -77,10 +84,13 @@ module Torch
       end
 
       def forward(
-        query,
-
+        query,
+        key,
+        value,
+        key_padding_mask: nil,
+        need_weights: true,
+        attn_mask: nil
       )
-
         if batch_first?
           query, key, value = [query, key, value].map { |t| t.transpose(1, 0) }
         end
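The public entry points are unchanged apart from the reflowed parameter lists; typical self-attention usage with the batch_first option (a sketch):

    mha = Torch::NN::MultiheadAttention.new(16, 4, batch_first: true)
    x = Torch.randn(2, 5, 16)   # (batch, seq, embed_dim)
    attn_output, attn_weights = mha.call(x, x, x)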
data/lib/torch/nn/rnn_base.rb
CHANGED
@@ -1,9 +1,16 @@
 module Torch
   module NN
     class RNNBase < Module
-      def initialize(
-
-
+      def initialize(
+        mode,
+        input_size,
+        hidden_size,
+        num_layers: 1,
+        bias: true,
+        batch_first: false,
+        dropout: 0.0,
+        bidirectional: false
+      )
         super()
         @mode = mode
         @input_size = input_size
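RNNBase is the shared base for RNN, LSTM, and GRU, which supply mode and forward these keywords; e.g. (assuming the usual subclass plumbing):

    lstm = Torch::NN::LSTM.new(10, 20, num_layers: 2)
    input = Torch.randn(5, 3, 10)   # (seq, batch, input_size)
    output, (hn, cn) = lstm.call(input)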
data/lib/torch/nn/transformer.rb
CHANGED
@@ -7,13 +7,18 @@ module Torch
   module NN
     class Transformer < Module
       def initialize(
-        d_model: 512,
-
-
-
-
+        d_model: 512,
+        nhead: 8,
+        num_encoder_layers: 6,
+        num_decoder_layers: 6,
+        dim_feedforward: 2048,
+        dropout: 0.1,
+        activation: :relu,
+        custom_encoder: nil,
+        custom_decoder: nil,
+        layer_norm_eps: 1e-5,
+        batch_first: false
       )
-
         super()
 
         @encoder =
@@ -60,11 +65,15 @@ module Torch
       end
 
       def forward(
-        src,
-
-
+        src,
+        tgt,
+        src_mask: nil,
+        tgt_mask: nil,
+        memory_mask: nil,
+        src_key_padding_mask: nil,
+        tgt_key_padding_mask: nil,
+        memory_key_padding_mask: nil
       )
-
         if (!batch_first? && src.size(1) != tgt.size(1)) ||
            (batch_first? && src.size(0) != tgt.size(0))
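For reference, the reflowed signatures leave usage as in the PyTorch example this class mirrors; with the default batch_first: false, tensors are (seq, batch, d_model). A sketch:

    model = Torch::NN::Transformer.new(nhead: 8, num_encoder_layers: 4)
    src = Torch.rand(10, 32, 512)
    tgt = Torch.rand(20, 32, 512)
    out = model.call(src, tgt)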