torch-rb 0.8.1 → 0.9.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +20 -0
- data/README.md +26 -44
- data/codegen/generate_functions.rb +13 -5
- data/codegen/native_functions.yaml +1103 -373
- data/ext/torch/backends.cpp +2 -2
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +19 -5
- data/ext/torch/templates.h +0 -37
- data/ext/torch/tensor.cpp +8 -8
- data/ext/torch/utils.h +0 -6
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/convnd.rb +2 -0
- data/lib/torch/nn/functional.rb +1 -1
- data/lib/torch/nn/functional_attention.rb +241 -0
- data/lib/torch/nn/module.rb +30 -0
- data/lib/torch/nn/module_list.rb +49 -0
- data/lib/torch/nn/multihead_attention.rb +123 -0
- data/lib/torch/nn/parameter.rb +6 -0
- data/lib/torch/nn/transformer.rb +92 -0
- data/lib/torch/nn/transformer_decoder.rb +25 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
- data/lib/torch/nn/transformer_encoder.rb +25 -0
- data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
- data/lib/torch/nn/utils.rb +12 -0
- data/lib/torch/tensor.rb +21 -8
- data/lib/torch/utils/data/data_loader.rb +3 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +6 -45
- metadata +11 -3
data/ext/torch/backends.cpp
CHANGED
@@ -7,11 +7,11 @@
|
|
7
7
|
void init_backends(Rice::Module& m) {
|
8
8
|
auto rb_mBackends = Rice::define_module_under(m, "Backends");
|
9
9
|
|
10
|
-
|
10
|
+
Rice::define_module_under(rb_mBackends, "OpenMP")
|
11
11
|
.add_handler<torch::Error>(handle_error)
|
12
12
|
.define_singleton_function("available?", &torch::hasOpenMP);
|
13
13
|
|
14
|
-
|
14
|
+
Rice::define_module_under(rb_mBackends, "MKL")
|
15
15
|
.add_handler<torch::Error>(handle_error)
|
16
16
|
.define_singleton_function("available?", &torch::hasMKL);
|
17
17
|
}
|
@@ -472,12 +472,12 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
472
472
|
auto param_idx = find_param(signature, key);
|
473
473
|
if (param_idx < 0) {
|
474
474
|
rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
475
|
-
signature.name.c_str(),
|
475
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
476
476
|
}
|
477
477
|
|
478
478
|
if (param_idx < num_pos_args) {
|
479
479
|
rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
|
480
|
-
signature.name.c_str(),
|
480
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
481
481
|
}
|
482
482
|
}
|
483
483
|
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -75,7 +75,7 @@ struct RubyArgs {
|
|
75
75
|
int idx;
|
76
76
|
|
77
77
|
inline at::Tensor tensor(int i);
|
78
|
-
inline
|
78
|
+
inline c10::optional<at::Tensor> optionalTensor(int i);
|
79
79
|
inline at::Scalar scalar(int i);
|
80
80
|
// inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
|
81
81
|
inline std::vector<at::Scalar> scalarlist(int i);
|
@@ -109,6 +109,9 @@ struct RubyArgs {
|
|
109
109
|
// inline at::QScheme toQScheme(int i);
|
110
110
|
inline std::string string(int i);
|
111
111
|
inline c10::optional<std::string> stringOptional(int i);
|
112
|
+
inline c10::string_view stringView(int i);
|
113
|
+
// inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
|
114
|
+
inline c10::optional<c10::string_view> stringViewOptional(int i);
|
112
115
|
// inline PyObject* pyobject(int i);
|
113
116
|
inline int64_t toInt64(int i);
|
114
117
|
// inline int64_t toInt64WithDefault(int i, int64_t default_int);
|
@@ -125,8 +128,8 @@ inline at::Tensor RubyArgs::tensor(int i) {
|
|
125
128
|
return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
|
126
129
|
}
|
127
130
|
|
128
|
-
inline
|
129
|
-
if (NIL_P(args[i])) return
|
131
|
+
inline c10::optional<at::Tensor> RubyArgs::optionalTensor(int i) {
|
132
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
130
133
|
return tensor(i);
|
131
134
|
}
|
132
135
|
|
@@ -232,7 +235,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
232
235
|
|
233
236
|
auto it = dtype_map.find(args[i]);
|
234
237
|
if (it == dtype_map.end()) {
|
235
|
-
rb_raise(rb_eArgError, "invalid dtype: %s",
|
238
|
+
rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
|
236
239
|
}
|
237
240
|
return it->second;
|
238
241
|
}
|
@@ -290,7 +293,7 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
290
293
|
|
291
294
|
auto it = layout_map.find(args[i]);
|
292
295
|
if (it == layout_map.end()) {
|
293
|
-
rb_raise(rb_eArgError, "invalid layout: %s",
|
296
|
+
rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
|
294
297
|
}
|
295
298
|
return it->second;
|
296
299
|
}
|
@@ -322,6 +325,17 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
|
322
325
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
323
326
|
}
|
324
327
|
|
328
|
+
// string_view does not own data
|
329
|
+
inline c10::string_view RubyArgs::stringView(int i) {
|
330
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
331
|
+
}
|
332
|
+
|
333
|
+
// string_view does not own data
|
334
|
+
inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
|
335
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
336
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
337
|
+
}
|
338
|
+
|
325
339
|
inline int64_t RubyArgs::toInt64(int i) {
|
326
340
|
if (NIL_P(args[i])) return signature.params[i].default_int;
|
327
341
|
return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
|
data/ext/torch/templates.h
CHANGED
@@ -41,24 +41,6 @@ using torch::nn::init::NonlinearityType;
|
|
41
41
|
#define RETURN_NIL \
|
42
42
|
return Qnil;
|
43
43
|
|
44
|
-
class OptionalTensor {
|
45
|
-
torch::Tensor value;
|
46
|
-
public:
|
47
|
-
OptionalTensor(Object o) {
|
48
|
-
if (o.is_nil()) {
|
49
|
-
value = {};
|
50
|
-
} else {
|
51
|
-
value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
|
52
|
-
}
|
53
|
-
}
|
54
|
-
OptionalTensor(torch::Tensor o) {
|
55
|
-
value = o;
|
56
|
-
}
|
57
|
-
operator torch::Tensor() const {
|
58
|
-
return value;
|
59
|
-
}
|
60
|
-
};
|
61
|
-
|
62
44
|
namespace Rice::detail
|
63
45
|
{
|
64
46
|
template<>
|
@@ -131,25 +113,6 @@ namespace Rice::detail
|
|
131
113
|
}
|
132
114
|
};
|
133
115
|
|
134
|
-
template<>
|
135
|
-
struct Type<OptionalTensor>
|
136
|
-
{
|
137
|
-
static bool verify()
|
138
|
-
{
|
139
|
-
return true;
|
140
|
-
}
|
141
|
-
};
|
142
|
-
|
143
|
-
template<>
|
144
|
-
class From_Ruby<OptionalTensor>
|
145
|
-
{
|
146
|
-
public:
|
147
|
-
OptionalTensor convert(VALUE x)
|
148
|
-
{
|
149
|
-
return OptionalTensor(x);
|
150
|
-
}
|
151
|
-
};
|
152
|
-
|
153
116
|
template<>
|
154
117
|
struct Type<Scalar>
|
155
118
|
{
|
data/ext/torch/tensor.cpp
CHANGED
@@ -107,7 +107,7 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
|
|
107
107
|
ParsedArgs<4> parsed_args;
|
108
108
|
auto _r = parser.parse(self_, argc, argv, parsed_args);
|
109
109
|
// _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
|
110
|
-
auto dispatch__backward = [](const Tensor & self, TensorList inputs, const
|
110
|
+
auto dispatch__backward = [](const Tensor & self, TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
|
111
111
|
// in future, release GVL
|
112
112
|
self._backward(inputs, gradient, retain_graph, create_graph);
|
113
113
|
};
|
@@ -125,13 +125,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
125
125
|
rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
|
126
126
|
|
127
127
|
rb_cTensor
|
128
|
-
.define_method("cuda?", &
|
129
|
-
.define_method("sparse?", &
|
130
|
-
.define_method("quantized?", &
|
131
|
-
.define_method("dim", &
|
132
|
-
.define_method("numel", &
|
133
|
-
.define_method("element_size", &
|
134
|
-
.define_method("requires_grad", &
|
128
|
+
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
|
129
|
+
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
|
130
|
+
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
|
131
|
+
.define_method("dim", [](Tensor& self) { return self.dim(); })
|
132
|
+
.define_method("numel", [](Tensor& self) { return self.numel(); })
|
133
|
+
.define_method("element_size", [](Tensor& self) { return self.element_size(); })
|
134
|
+
.define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
|
135
135
|
.define_method(
|
136
136
|
"_size",
|
137
137
|
[](Tensor& self, int64_t dim) {
|
data/ext/torch/utils.h
CHANGED
@@ -16,12 +16,6 @@ inline VALUE THPUtils_internSymbol(const std::string& str) {
|
|
16
16
|
return Rice::Symbol(str);
|
17
17
|
}
|
18
18
|
|
19
|
-
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
20
|
-
Check_Type(obj, T_SYMBOL);
|
21
|
-
obj = rb_funcall(obj, rb_intern("to_s"), 0);
|
22
|
-
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
23
|
-
}
|
24
|
-
|
25
19
|
inline std::string THPUtils_unpackString(VALUE obj) {
|
26
20
|
Check_Type(obj, T_STRING);
|
27
21
|
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
data/lib/torch/inspector.rb
CHANGED
@@ -247,7 +247,7 @@ module Torch
|
|
247
247
|
# length includes spaces and comma between elements
|
248
248
|
element_length = formatter.width + 2
|
249
249
|
elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
|
250
|
-
|
250
|
+
_char_per_line = element_length * elements_per_line
|
251
251
|
|
252
252
|
if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
|
253
253
|
data = (
|
data/lib/torch/nn/convnd.rb
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class ConvNd < Module
|
4
|
+
attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_padding, :groups, :padding_mode
# misspelled reader kept so existing callers of `output_paddding` keep working
alias_method :output_paddding, :output_padding
|
5
|
+
|
4
6
|
def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
|
5
7
|
super()
|
6
8
|
raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -571,7 +571,7 @@ module Torch
|
|
571
571
|
end
|
572
572
|
|
573
573
|
def _interp_output_size(closed_over_args)
|
574
|
-
input, size, scale_factor,
|
574
|
+
input, size, scale_factor, _recompute_scale_factor = closed_over_args
|
575
575
|
dim = input.dim - 2
|
576
576
|
if size.nil? && scale_factor.nil?
|
577
577
|
raise ArgumentError, "either size or scale_factor should be defined"
|
@@ -0,0 +1,241 @@
|
|
1
|
+
module Torch
  module NN
    class Functional
      class << self
        # Projects q, k and v using one packed weight matrix `w` whose rows hold
        # the q, k and v weights in that order (and an optional packed bias `b`).
        #
        # Object identity (eql?) decides how much work can be shared:
        # - q, k and v identical (self-attention): one matmul, split in three
        # - only k and v identical (encoder-decoder): q alone, k/v share a matmul
        # - otherwise: three independent projections
        #
        # Returns [q_proj, k_proj, v_proj].
        def in_projection_packed(q, k, v, w, b: nil)
          e = q.size(-1)

          if k.eql?(v)
            if q.eql?(k)
              # self-attention
              return linear(q, w, b).chunk(3, dim: -1)
            end

            # encoder-decoder attention
            w_q, w_kv = w.split_with_sizes([e, e * 2])
            b_q, b_kv = b.nil? ? [nil, nil] : b.split_with_sizes([e, e * 2])

            [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
          else
            w_q, w_k, w_v = w.chunk(3)
            b_q, b_k, b_v = b.nil? ? [nil, nil, nil] : b.chunk(3)

            [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
          end
        end

        # Projects q, k and v with separate weights (used when key/value embedding
        # sizes differ from the query's).
        #
        # Every projection maps into the query embedding space, so with
        # e_q = q.size(-1), e_k = k.size(-1), e_v = v.size(-1):
        #   w_q: [e_q, e_q], w_k: [e_q, e_k], w_v: [e_q, e_v]
        #   b_q, b_k, b_v (optional): [e_q]
        #
        # Raises ArgumentError when a weight or bias has an unexpected shape.
        # Returns [q_proj, k_proj, v_proj].
        def in_projection(
          q, k, v,
          w_q, w_k, w_v,
          b_q: nil, b_k: nil, b_v: nil
        )

          e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)

          raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
          raise ArgumentError, "Expecting key weights shape of #{[e_q, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_q, e_k]
          raise ArgumentError, "Expecting value weights shape of #{[e_q, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_q, e_v]

          raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
          raise ArgumentError, "Expecting key bias shape of #{[e_q]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_q]
          raise ArgumentError, "Expecting value bias shape of #{[e_q]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_q]

          [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
        end

        # Batched attention: softmax(q k^T / sqrt(e) + attn_mask) v.
        #
        # q, k and v are 3-D with the batch dimension first (q is unpacked as
        # [batch, target_len, e]). attn_mask, when given, is added to the raw
        # scores before the softmax. Dropout is applied to the attention weights
        # only when dropout_p > 0.
        #
        # Returns [output, attention_weights].
        def scaled_dot_product_attention(
          q, k, v,
          attn_mask: nil, dropout_p: 0.0
        )

          _b, _nt, e = q.shape

          q = q / Math.sqrt(e)

          attn = Torch.bmm(q, k.transpose(-2, -1))
          attn += attn_mask if attn_mask
          attn = softmax(attn, dim: -1)
          attn = dropout(attn, p: dropout_p) if dropout_p > 0

          output = Torch.bmm(attn, v)

          [output, attn]
        end

        # Full multi-head attention forward pass (port of PyTorch's
        # F.multi_head_attention_forward).
        #
        # query is unpacked as [tgt_len, bsz, embed_dim]; key's first dimension is
        # the source length. Byte masks are converted to bool (with a warning),
        # bool masks become additive float masks (-inf at masked positions), and
        # the key padding mask is merged into attn_mask.
        #
        # Returns [attn_output, averaged_weights], where averaged_weights is the
        # attention averaged over heads, or nil when need_weights is false.
        # Raises ArgumentError for inconsistent shapes, dtypes or arguments.
        def multi_head_attention_forward(
          query, key, value,
          embed_dim_to_check, num_heads,
          in_proj_weight, in_proj_bias,
          bias_k, bias_v,
          add_zero_attn,
          dropout_p,
          out_proj_weight, out_proj_bias,
          training: true,
          key_padding_mask: nil,
          need_weights: true,
          attn_mask: nil,
          use_separate_proj_weight: false,
          q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
          static_k: nil, static_v: nil
        )

          tgt_len, bsz, embed_dim = query.shape
          src_len = key.shape.first

          raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check

          head_dim =
            if embed_dim.is_a?(Torch::Tensor)
              embed_dim.div(num_heads, rounding_mode: 'trunc')
            else
              # fail early with a clear message instead of a later view() error
              raise ArgumentError, "embed_dim #{embed_dim} not divisible by num_heads #{num_heads}" unless (embed_dim % num_heads).zero?

              embed_dim.div(num_heads)
            end

          if use_separate_proj_weight
            raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
          else
            raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
          end

          # compute in-projection
          q, k, v =
            if use_separate_proj_weight
              raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
              raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
              raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight

              if in_proj_bias
                b_q, b_k, b_v = in_proj_bias.chunk(3)
              else
                b_q = b_k = b_v = nil
              end

              in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
            else
              in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
            end

          # prep attention mask
          if attn_mask
            if attn_mask.dtype == :uint8
              puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
              attn_mask = attn_mask.bool
            else
              raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
            end

            if attn_mask.dim == 2
              correct_2d_size = [tgt_len, src_len]
              raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size

              attn_mask = attn_mask.unsqueeze(0)
            elsif attn_mask.dim == 3
              correct_3d_size = [bsz * num_heads, tgt_len, src_len]
              raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
            else
              raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
            end
          end

          # prep key padding mask
          if key_padding_mask && key_padding_mask.dtype == :uint8
            puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
            key_padding_mask = key_padding_mask.bool
          end

          # add bias along batch dimension (currently second)
          if bias_k && bias_v
            raise ArgumentError, "bias cannot be added to static key." if static_k
            raise ArgumentError, "bias cannot be added to static value." if static_v

            k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])

            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
          else
            raise ArgumentError, "bias_k and bias_v must either both be given or both be nil" unless bias_k.nil? && bias_v.nil?
          end

          # reshape q, k, v for multihead attention and make them batch first
          q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

          if static_k.nil?
            k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
          else
            raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
            raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim

            k = static_k
          end

          if static_v.nil?
            v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
          else
            raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
            raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim

            v = static_v
          end

          # add zero attention along batch dimension (now first)
          if add_zero_attn
            zero_attn_shape = [bsz * num_heads, 1, head_dim]
            k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
            v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)

            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
          end

          # update source sequence length after adjustments
          src_len = k.size(1)

          # merge key padding and attention masks
          if key_padding_mask
            raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]

            key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)

            attn_mask = if attn_mask.nil?
              key_padding_mask
            elsif attn_mask.dtype == :bool
              attn_mask.logical_or(key_padding_mask)
            else
              attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
            end
          end

          # convert a boolean mask into an additive float mask
          if attn_mask && attn_mask.dtype == :bool
            new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
            attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
          end

          dropout_p = 0.0 unless training

          # calculate attention and out projection
          attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
          attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
          attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

          if need_weights
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
            [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
          else
            [attn_output, nil]
          end
        end
      end
    end
  end
end
|
data/lib/torch/nn/module.rb
CHANGED
@@ -3,6 +3,8 @@ module Torch
|
|
3
3
|
class Module
|
4
4
|
include Utils
|
5
5
|
|
6
|
+
attr_reader :training
|
7
|
+
|
6
8
|
def initialize
|
7
9
|
@training = true
|
8
10
|
@parameters = {}
|
@@ -278,6 +280,11 @@ module Torch
|
|
278
280
|
end
|
279
281
|
end
|
280
282
|
|
283
|
+
def deep_dup
  # every top-level copy starts with an empty memo, so objects shared inside
  # this module graph are copied once and stay shared in the copy
  dup_value(self, {})
end
|
287
|
+
|
281
288
|
def method_missing(method, *args, &block)
|
282
289
|
name = method.to_s
|
283
290
|
if named_parameters.key?(name)
|
@@ -386,6 +393,29 @@ module Torch
|
|
386
393
|
destination[prefix + k] = v
|
387
394
|
end
|
388
395
|
end
|
396
|
+
|
397
|
+
# keep memo hash like Python deepcopy
|
398
|
+
# https://docs.python.org/3/library/copy.html
|
399
|
+
# Recursively copies v, reusing copies recorded in memo (keyed by object_id).
#
# - Method / UnboundMethod objects are shared, not copied.
# - Hash and Array contents are copied element by element.
# - Torch::NN::Module instances are dup'ed and each instance variable is
#   deep-copied into the duplicate.
# - Everything else is copied with #dup.
#
# Container and module copies are registered in memo *before* their children
# are visited (as Python's copy.deepcopy does), so cyclic references resolve
# to the copy under construction instead of recursing forever.
def dup_value(v, memo)
  key = v.object_id
  return memo[key] if memo.key?(key)

  case v
  when Method, UnboundMethod
    memo[key] = v
  when Hash
    copy = memo[key] = {}
    v.each { |k, v2| copy[dup_value(k, memo)] = dup_value(v2, memo) }
    copy
  when Array
    copy = memo[key] = []
    v.each { |v2| copy << dup_value(v2, memo) }
    copy
  when Torch::NN::Module
    copy = memo[key] = v.dup
    v.instance_variables.each do |var|
      copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
    end
    copy
  else
    memo[key] = v.dup
  end
end
|
389
419
|
end
|
390
420
|
end
|
391
421
|
end
|
@@ -0,0 +1,49 @@
|
|
1
|
+
module Torch
  module NN
    # Ordered container of submodules. Each child is registered with
    # add_module under its position ("0", "1", ...), so it takes part in the
    # regular Module bookkeeping.
    class ModuleList < Module
      include Enumerable

      # mods - optional collection (anything responding to #each) of modules,
      #        appended in iteration order
      def initialize(mods = nil)
        super()

        concat(mods) if mods
      end

      # Number of registered submodules.
      def length
        @modules.length
      end
      alias_method :size, :length

      # Without an argument or block this is #length. With an argument or a
      # block it defers to Enumerable#count; a plain alias to #length would
      # silently ignore the block and reject the argument.
      def count(*args, &block)
        return length if args.empty? && block.nil?

        super
      end

      # Appends every module yielded by mods.each; returns self.
      # Raises ArgumentError if mods does not respond to #each.
      def concat(mods)
        raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)

        mods.each { |m| append(m) }
        self
      end

      # Yields each submodule in insertion order; returns an Enumerator
      # when no block is given.
      def each(&block)
        return to_enum(:each) { length } unless block_given?

        @modules.values.each(&block)
      end

      # Registers mod at the next position; returns self.
      # Raises ArgumentError if mod is not a Module.
      def append(mod)
        raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)

        add_module(length.to_s, mod)
        self
      end

      # idx - Range: returns a new ModuleList holding that slice.
      #       Integer: position; negative values count from the end like Array#[].
      #       Anything else is looked up by its #to_s name.
      # Returns nil for an out-of-range position.
      def [](idx)
        case idx
        when Range
          self.class.new(@modules.values[idx])
        when Integer
          idx += length if idx.negative?
          @modules[idx.to_s]
        else
          @modules[idx.to_s]
        end
      end
    end
  end
end
|
@@ -0,0 +1,123 @@
|
|
1
|
+
module Torch
  module NN
    # Multi-head attention layer mirroring PyTorch's nn.MultiheadAttention.
    #
    # When the key and value sizes (kdim / vdim, defaulting to embed_dim) equal
    # embed_dim, the q/k/v projections live in a single packed in_proj_weight;
    # otherwise separate q_proj_weight, k_proj_weight and v_proj_weight are used.
    # The device: and dtype: keywords are accepted but not used here.
    class MultiheadAttention < Module
      def initialize(
        embed_dim, num_heads,
        dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false,
        kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil
      )

        super()

        @embed_dim = embed_dim
        @kdim = kdim || @embed_dim
        @vdim = vdim || @embed_dim

        # packed projection is only possible when k and v share embed_dim
        @qkv_same_embed_dim = @kdim == @embed_dim && @vdim == @embed_dim

        @num_heads = num_heads
        @dropout = dropout
        @batch_first = batch_first

        @head_dim = @embed_dim.div(@num_heads)
        unless @head_dim * @num_heads == @embed_dim
          raise ArgumentError, "embed_dim must be divisible by num_heads"
        end

        if @qkv_same_embed_dim
          @in_proj_weight = Parameter.new(Torch.empty([3 * @embed_dim, @embed_dim]))
          %w[q k v].each { |prefix| register_parameter("#{prefix}_proj_weight", nil) }
        else
          @q_proj_weight = Parameter.new(Torch.empty([@embed_dim, @embed_dim]))
          @k_proj_weight = Parameter.new(Torch.empty([@embed_dim, @kdim]))
          @v_proj_weight = Parameter.new(Torch.empty([@embed_dim, @vdim]))

          register_parameter("in_proj_weight", nil)
        end

        if bias
          @in_proj_bias = Parameter.new(Torch.empty(3 * @embed_dim))
        else
          register_parameter("in_proj_bias", nil)
        end

        @out_proj = Linear.new(@embed_dim, @embed_dim, bias: bias)

        if add_bias_kv
          @bias_k = Parameter.new(Torch.empty([1, 1, @embed_dim]))
          @bias_v = Parameter.new(Torch.empty([1, 1, @embed_dim]))
        else
          @bias_k = @bias_v = nil
        end

        @add_zero_attn = add_zero_attn

        reset_parameters
      end

      # True when inputs and outputs are (batch, seq, feature).
      def batch_first?
        !!@batch_first
      end

      # Xavier-uniform for projection weights (and bias_k / bias_v when present),
      # zeros for the input and output projection biases.
      def reset_parameters
        projection_weights =
          @qkv_same_embed_dim ? [@in_proj_weight] : [@q_proj_weight, @k_proj_weight, @v_proj_weight]
        projection_weights.each { |weight| Init.xavier_uniform!(weight) }

        if @in_proj_bias
          [@in_proj_bias, @out_proj.bias].each { |b| Init.constant!(b, 0.0) }
        end

        [@bias_k, @bias_v].compact.each { |b| Init.xavier_uniform!(b) }
      end

      # Runs attention over query/key/value via F.multi_head_attention_forward.
      # With batch_first? the inputs are swapped to (seq, batch, feature) for the
      # functional call and the output is swapped back; the returned weights are
      # passed through unchanged.
      #
      # Returns [attn_output, attn_output_weights].
      def forward(
        query, key, value,
        key_padding_mask: nil, need_weights: true, attn_mask: nil
      )

        query, key, value = [query, key, value].map { |t| t.transpose(1, 0) } if batch_first?

        options = {
          training: @training,
          key_padding_mask: key_padding_mask,
          need_weights: need_weights,
          attn_mask: attn_mask
        }
        unless @qkv_same_embed_dim
          options[:use_separate_proj_weight] = true
          options[:q_proj_weight] = @q_proj_weight
          options[:k_proj_weight] = @k_proj_weight
          options[:v_proj_weight] = @v_proj_weight
        end

        attn_output, attn_output_weights = F.multi_head_attention_forward(
          query, key, value,
          @embed_dim, @num_heads,
          @in_proj_weight, @in_proj_bias,
          @bias_k, @bias_v, @add_zero_attn,
          @dropout, @out_proj.weight, @out_proj.bias,
          **options
        )

        attn_output = attn_output.transpose(1, 0) if batch_first?

        [attn_output, attn_output_weights]
      end
    end
  end
end
|