torch-rb 0.8.0 → 0.9.0
This diff shows the changes between publicly released versions of the package as they appear in the public registry, and is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +22 -0
- data/README.md +23 -41
- data/codegen/generate_functions.rb +46 -8
- data/codegen/native_functions.yaml +1103 -373
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/ext.cpp +8 -0
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.h +17 -3
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +0 -37
- data/ext/torch/tensor.cpp +8 -8
- data/lib/torch/nn/convnd.rb +2 -0
- data/lib/torch/nn/functional_attention.rb +241 -0
- data/lib/torch/nn/module.rb +30 -0
- data/lib/torch/nn/module_list.rb +49 -0
- data/lib/torch/nn/multihead_attention.rb +123 -0
- data/lib/torch/nn/parameter.rb +6 -0
- data/lib/torch/nn/transformer.rb +92 -0
- data/lib/torch/nn/transformer_decoder.rb +25 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
- data/lib/torch/nn/transformer_encoder.rb +25 -0
- data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
- data/lib/torch/nn/utils.rb +12 -0
- data/lib/torch/tensor.rb +20 -0
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +6 -0
- metadata +18 -3
data/ext/torch/backends.cpp
ADDED
@@ -0,0 +1,17 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "utils.h"
+
+void init_backends(Rice::Module& m) {
+  auto rb_mBackends = Rice::define_module_under(m, "Backends");
+
+  Rice::define_module_under(rb_mBackends, "OpenMP")
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_function("available?", &torch::hasOpenMP);
+
+  Rice::define_module_under(rb_mBackends, "MKL")
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_function("available?", &torch::hasMKL);
+}
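With these bindings in place, the new backend checks are callable from Ruby. A minimal sketch (the return values depend on how the bundled LibTorch was built):

```ruby
require "torch"

# New in 0.9.0: query whether LibTorch was built with OpenMP / MKL support
puts Torch::Backends::OpenMP.available? # => true or false
puts Torch::Backends::MKL.available?    # => true or false
```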
data/ext/torch/ext.cpp
CHANGED
@@ -2,10 +2,14 @@
 
 #include <rice/rice.hpp>
 
+void init_fft(Rice::Module& m);
+void init_linalg(Rice::Module& m);
 void init_nn(Rice::Module& m);
+void init_special(Rice::Module& m);
 void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
 void init_torch(Rice::Module& m);
 
+void init_backends(Rice::Module& m);
 void init_cuda(Rice::Module& m);
 void init_device(Rice::Module& m);
 void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
@@ -27,7 +31,11 @@ void Init_ext()
   init_torch(m);
   init_tensor(m, rb_cTensor, rb_cTensorOptions);
   init_nn(m);
+  init_fft(m);
+  init_linalg(m);
+  init_special(m);
 
+  init_backends(m);
   init_cuda(m);
   init_device(m);
   init_ivalue(m, rb_cIValue);
data/ext/torch/fft.cpp
ADDED
@@ -0,0 +1,13 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "fft_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_fft(Rice::Module& m) {
+  auto rb_mFFT = Rice::define_module_under(m, "FFT");
+  rb_mFFT.add_handler<torch::Error>(handle_error);
+  add_fft_functions(rb_mFFT);
+}
data/ext/torch/linalg.cpp
ADDED
@@ -0,0 +1,13 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "linalg_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_linalg(Rice::Module& m) {
+  auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
+  rb_mLinalg.add_handler<torch::Error>(handle_error);
+  add_linalg_functions(rb_mLinalg);
+}
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -75,7 +75,7 @@ struct RubyArgs {
   int idx;
 
   inline at::Tensor tensor(int i);
-  inline
+  inline c10::optional<at::Tensor> optionalTensor(int i);
   inline at::Scalar scalar(int i);
   // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
   inline std::vector<at::Scalar> scalarlist(int i);
@@ -109,6 +109,9 @@ struct RubyArgs {
   // inline at::QScheme toQScheme(int i);
   inline std::string string(int i);
   inline c10::optional<std::string> stringOptional(int i);
+  inline c10::string_view stringView(int i);
+  // inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
+  inline c10::optional<c10::string_view> stringViewOptional(int i);
   // inline PyObject* pyobject(int i);
   inline int64_t toInt64(int i);
   // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -125,8 +128,8 @@ inline at::Tensor RubyArgs::tensor(int i) {
   return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
 }
 
-inline
-  if (NIL_P(args[i])) return
+inline c10::optional<at::Tensor> RubyArgs::optionalTensor(int i) {
+  if (NIL_P(args[i])) return c10::nullopt;
   return tensor(i);
 }
 
@@ -322,6 +325,17 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
   return Rice::detail::From_Ruby<std::string>().convert(args[i]);
 }
 
+inline c10::string_view RubyArgs::stringView(int i) {
+  auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
+  return c10::string_view(str.data(), str.size());
+}
+
+inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
+  if (NIL_P(args[i])) return c10::nullopt;
+  auto str = Rice::detail::From_Ruby<std::string>().convert(args[i]);
+  return c10::string_view(str.data(), str.size());
+}
+
 inline int64_t RubyArgs::toInt64(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_int;
   return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
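The new string_view accessors exist so that generated functions can take string-valued options. A hedged sketch at the Ruby level, assuming the div overload with rounding_mode is among the functions generated from native_functions.yaml in this release (the functional_attention.rb code below calls the same option on Tensor#div):

```ruby
require "torch"

a = Torch.tensor([7.0, -7.0])
b = Torch.tensor([2.0, 2.0])

# rounding_mode is a string-valued option of the kind the new
# stringView/stringViewOptional accessors are meant to handle
p a.div(b, rounding_mode: "floor") # => roughly tensor([ 3., -4.])
p a.div(b, rounding_mode: "trunc") # => roughly tensor([ 3., -3.])
```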
data/ext/torch/special.cpp
ADDED
@@ -0,0 +1,13 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "special_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_special(Rice::Module& m) {
+  auto rb_mSpecial = Rice::define_module_under(m, "Special");
+  rb_mSpecial.add_handler<torch::Error>(handle_error);
+  add_special_functions(rb_mSpecial);
+}
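Together with ext.cpp above, this registers three new namespaces: Torch::FFT, Torch::Linalg and Torch::Special, whose functions come from the regenerated native_functions.yaml. A minimal sketch, assuming the generated set mirrors PyTorch's linalg and special namespaces (norm and erf are assumptions, not confirmed by this diff):

```ruby
require "torch"

x = Torch.tensor([1.0, 2.0, 3.0])

# assumed to mirror torch.linalg.norm and torch.special.erf
p Torch::Linalg.norm(x)
p Torch::Special.erf(x)

# Torch::FFT is registered the same way; its functions (e.g. an assumed
# Torch::FFT.fft) return complex tensors where supported
```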
data/ext/torch/templates.h
CHANGED
@@ -41,24 +41,6 @@ using torch::nn::init::NonlinearityType;
 #define RETURN_NIL \
   return Qnil;
 
-class OptionalTensor {
-  torch::Tensor value;
-  public:
-    OptionalTensor(Object o) {
-      if (o.is_nil()) {
-        value = {};
-      } else {
-        value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
-      }
-    }
-    OptionalTensor(torch::Tensor o) {
-      value = o;
-    }
-    operator torch::Tensor() const {
-      return value;
-    }
-};
-
 namespace Rice::detail
 {
   template<>
@@ -131,25 +113,6 @@ namespace Rice::detail
    }
  };
 
-  template<>
-  struct Type<OptionalTensor>
-  {
-    static bool verify()
-    {
-      return true;
-    }
-  };
-
-  template<>
-  class From_Ruby<OptionalTensor>
-  {
-  public:
-    OptionalTensor convert(VALUE x)
-    {
-      return OptionalTensor(x);
-    }
-  };
-
   template<>
   struct Type<Scalar>
   {
data/ext/torch/tensor.cpp
CHANGED
@@ -107,7 +107,7 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
   ParsedArgs<4> parsed_args;
   auto _r = parser.parse(self_, argc, argv, parsed_args);
   // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
-  auto dispatch__backward = [](const Tensor & self, TensorList inputs, const
+  auto dispatch__backward = [](const Tensor & self, TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
     // in future, release GVL
     self._backward(inputs, gradient, retain_graph, create_graph);
   };
@@ -125,13 +125,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
   rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
 
   rb_cTensor
-    .define_method("cuda?", &
-    .define_method("sparse?", &
-    .define_method("quantized?", &
-    .define_method("dim", &
-    .define_method("numel", &
-    .define_method("element_size", &
-    .define_method("requires_grad", &
+    .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
+    .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
+    .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
+    .define_method("dim", [](Tensor& self) { return self.dim(); })
+    .define_method("numel", [](Tensor& self) { return self.numel(); })
+    .define_method("element_size", [](Tensor& self) { return self.element_size(); })
+    .define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
     .define_method(
       "_size",
       [](Tensor& self, int64_t dim) {
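The rebinding above keeps the same Ruby-facing method names; a quick sketch of what they return:

```ruby
require "torch"

x = Torch.rand(2, 3)

x.cuda?          # => false (unless the tensor lives on a CUDA device)
x.sparse?        # => false
x.dim            # => 2
x.numel          # => 6
x.element_size   # => 4 (bytes per float32 element)
x.requires_grad  # => false
```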
data/lib/torch/nn/convnd.rb
CHANGED
@@ -1,6 +1,8 @@
 module Torch
   module NN
     class ConvNd < Module
+      attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_paddding, :groups, :padding_mode
+
       def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
         super()
         raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
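These readers expose the constructor arguments on every convolution subclass (note the released code spells the attribute :output_paddding, with three d's). A hedged sketch using Conv2d:

```ruby
require "torch"

conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 2)

conv.in_channels   # => 3
conv.out_channels  # => 16
conv.kernel_size   # => [3, 3]
conv.stride        # => [2, 2]
conv.padding_mode  # => "zeros"
```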
data/lib/torch/nn/functional_attention.rb
ADDED
@@ -0,0 +1,241 @@
+module Torch
+  module NN
+    class Functional
+      class << self
+        def in_projection_packed(q, k, v, w, b: nil)
+          e = q.size(-1)
+
+          if k.eql? v
+            if q.eql? k
+              # self-attention
+              return linear(q, w, b).chunk(3, dim: -1)
+            else
+              # encoder-decoder attention
+              w_q, w_kv = w.split_with_sizes([e, e * 2])
+              if b.nil?
+                b_q = b_kv = nil
+              else
+                b_q, b_kv = b.split_with_sizes([e, e * 2])
+              end
+
+              return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
+            end
+          else
+            w_q, w_k, w_v = w.chunk(3)
+            if b.nil?
+              b_q = b_k = b_v = nil
+            else
+              b_q, b_k, b_v = b.chunk(3)
+            end
+
+            return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+          end
+        end
+
+        def in_projection(
+          q, k, v,
+          w_q, w_k, w_v,
+          b_q: nil, b_k: nil, b_v: nil
+        )
+
+          e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)
+
+          raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
+          raise ArgumentError, "Expecting key weights shape of #{[e_k, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_k, e_k]
+          raise ArgumentError, "Expecting value weights shape of #{[e_v, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_v, e_v]
+
+          raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
+          raise ArgumentError, "Expecting key bias shape of #{[e_k]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_k]
+          raise ArgumentError, "Expecting value bias shape of #{[e_v]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_v]
+
+          [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+        end
+
+        def scaled_dot_product_attention(
+          q, k, v,
+          attn_mask: nil, dropout_p: 0.0
+        )
+
+          b, nt, e = q.shape
+
+          q = q / Math.sqrt(e)
+
+          attn = Torch.bmm(q, k.transpose(-2, -1))
+          attn += attn_mask if attn_mask
+          attn = softmax(attn, dim: -1)
+          attn = dropout(attn, p: dropout_p) if dropout_p > 0
+
+          output = Torch.bmm(attn, v)
+
+          [output, attn]
+        end
+
+        def multi_head_attention_forward(
+          query, key, value,
+          embed_dim_to_check, num_heads,
+          in_proj_weight, in_proj_bias,
+          bias_k, bias_v,
+          add_zero_attn,
+          dropout_p,
+          out_proj_weight, out_proj_bias,
+          training: true,
+          key_padding_mask: nil,
+          need_weights: true,
+          attn_mask: nil,
+          use_separate_proj_weight: false,
+          q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
+          static_k: nil, static_v: nil
+        )
+
+          tgt_len, bsz, embed_dim = query.shape
+          src_len = key.shape.first
+
+          raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check
+
+          head_dim = if embed_dim.is_a?(Torch::Tensor)
+            embed_dim.div(num_heads, rounding_mode: 'trunc')
+          else
+            head_dim = embed_dim.div num_heads
+          end
+
+          if use_separate_proj_weight
+            raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
+          else
+            raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
+          end
+
+          # compute in-projection
+          q, k, v =
+            if use_separate_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight
+
+              if in_proj_bias
+                b_q, b_k, b_v = in_proj_bias.chunk(3)
+              else
+                b_q = b_k = b_v = nil
+              end
+
+              in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
+            else
+              in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
+            end
+
+          # prep attention mask
+          if attn_mask
+            if attn_mask.dtype == :uint8
+              puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
+              attn_mask = attn_mask.bool
+            else
+              raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
+            end
+
+            if attn_mask.dim == 2
+              correct_2d_size = [tgt_len, src_len]
+              raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size
+
+              attn_mask = attn_mask.unsqueeze(0)
+            elsif attn_mask.dim == 3
+              correct_3d_size = [bsz * num_heads, tgt_len, src_len]
+              raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
+            else
+              raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
+            end
+          end
+
+          # prep key padding mask
+          if key_padding_mask && key_padding_mask.dtype == :uint8
+            puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
+            key_padding_mask = key_padding_mask.bool
+          end
+
+          # add bias along batch dimension (currently second)
+          if bias_k && bias_v
+            raise ArgumentError, "bias cannot be added to static key." if static_k
+            raise ArgumentError, "bias cannot be added to static value." if static_v
+
+            k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
+            v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
+
+            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+          else
+            raise ArgumentError unless bias_k.nil?
+            raise ArgumentError unless bias_v.nil?
+          end
+
+          # reshape q, k, v for multihead attention and make em batch first
+          q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+
+          if static_k.nil?
+            k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+          else
+            raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
+            raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim
+
+            k = static_k
+          end
+
+          if static_v.nil?
+            v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+          else
+            raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
+            raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim
+
+            v = static_v
+          end
+
+          # add zero attention along batch dimension (now first)
+          if add_zero_attn
+            zero_attn_shape = [bsz * num_heads, 1, head_dim]
+            k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
+            v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)
+
+            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+          end
+
+          # update source sequence length after adjustments
+          src_len = k.size(1)
+
+          # merge key padding and attention masks
+          if key_padding_mask
+            raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]
+
+            key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+
+            attn_mask = if attn_mask.nil?
+              key_padding_mask
+            elsif attn_mask.dtype == :bool
+              attn_mask.logical_or(key_padding_mask)
+            else
+              attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
+            end
+          end
+
+          # convert mask to float
+          if attn_mask && attn_mask.dtype == :bool
+            new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
+            attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
+          end
+
+          dropout_p = 0.0 unless training
+
+          # (deep breath) calculate attention and out projection
+          attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
+          attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
+          attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+          if need_weights
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+            [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
+          else
+            [attn_output, nil]
+          end
+        end
+      end
+    end
+  end
+end
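Functional.scaled_dot_product_attention can be used on its own (linear, softmax, dropout and pad are the existing Functional helpers it assumes). A minimal sketch with the batch-first [b, nt, e] inputs the method expects:

```ruby
require "torch"

q = Torch.rand(2, 5, 8) # [batch * heads, target length, head dim]
k = Torch.rand(2, 7, 8)
v = Torch.rand(2, 7, 8)

output, attn = Torch::NN::Functional.scaled_dot_product_attention(q, k, v, dropout_p: 0.0)
p output.shape # => [2, 5, 8]
p attn.shape   # => [2, 5, 7]
```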
data/lib/torch/nn/module.rb
CHANGED
@@ -3,6 +3,8 @@ module Torch
     class Module
       include Utils
 
+      attr_reader :training
+
       def initialize
         @training = true
         @parameters = {}
@@ -278,6 +280,11 @@
         end
       end
 
+      def deep_dup
+        memo = {}
+        dup_value(self, memo)
+      end
+
       def method_missing(method, *args, &block)
         name = method.to_s
         if named_parameters.key?(name)
@@ -386,6 +393,29 @@
          destination[prefix + k] = v
        end
      end
+
+      # keep memo hash like Python deepcopy
+      # https://docs.python.org/3/library/copy.html
+      def dup_value(v, memo)
+        memo[v.object_id] ||= begin
+          case v
+          when Method, UnboundMethod
+            v
+          when Hash
+            v.to_h { |k, v2| [dup_value(k, memo), dup_value(v2, memo)] }
+          when Array
+            v.map { |v2| dup_value(v2, memo) }
+          when Torch::NN::Module
+            copy = v.dup
+            v.instance_variables.each do |var|
+              copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
+            end
+            copy
+          else
+            v.dup
+          end
+        end
+      end
     end
   end
 end
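deep_dup produces a deep, memoized copy of a module tree, in the spirit of Python's copy.deepcopy. A small sketch:

```ruby
require "torch"

model = Torch::NN::Linear.new(4, 2)
copy = model.deep_dup

p copy.class         # => Torch::NN::Linear
p copy.weight.shape  # => [2, 4]
p copy.equal?(model) # => false (a distinct module object)
```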
data/lib/torch/nn/module_list.rb
ADDED
@@ -0,0 +1,49 @@
+module Torch
+  module NN
+    class ModuleList < Module
+      include Enumerable
+
+      def initialize(mods = nil)
+        super()
+
+        self.concat(mods) if mods
+      end
+
+      def length
+        @modules.length
+      end
+      alias_method :count, :length
+      alias_method :size, :length
+
+      def concat(mods)
+        raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)
+
+        mods.each { |m| append m }
+
+        self
+      end
+
+      def each(&block)
+        if block_given?
+          @modules.values.each(&block)
+        else
+          to_enum(:each)
+        end
+      end
+
+      def append(mod)
+        raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
+        add_module(length.to_s, mod)
+        self
+      end
+
+      def [](idx)
+        if idx.is_a?(Range)
+          self.class.new(@modules.values[idx])
+        else
+          @modules[idx.to_s]
+        end
+      end
+    end
+  end
+end
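ModuleList registers its children under numeric keys so their parameters are tracked like any other submodule. A short sketch (Linear, ReLU and Sigmoid are existing torch-rb modules used for illustration):

```ruby
require "torch"

layers = Torch::NN::ModuleList.new([
  Torch::NN::Linear.new(8, 8),
  Torch::NN::ReLU.new,
  Torch::NN::Linear.new(8, 2)
])

layers.length       # => 3
layers[0].class     # => Torch::NN::Linear
layers.map(&:class) # => [Torch::NN::Linear, Torch::NN::ReLU, Torch::NN::Linear]

layers.append(Torch::NN::Sigmoid.new)
layers.length       # => 4
```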