torch-rb 0.7.0 → 0.8.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +21 -0
- data/README.md +23 -41
- data/codegen/function.rb +2 -0
- data/codegen/generate_functions.rb +41 -4
- data/codegen/native_functions.yaml +2007 -1327
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/ext.cpp +8 -0
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.h +7 -1
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +1 -0
- data/lib/torch/nn/convnd.rb +2 -0
- data/lib/torch/nn/functional_attention.rb +241 -0
- data/lib/torch/nn/module.rb +30 -0
- data/lib/torch/nn/module_list.rb +49 -0
- data/lib/torch/nn/multihead_attention.rb +123 -0
- data/lib/torch/nn/parameter.rb +6 -0
- data/lib/torch/nn/transformer.rb +92 -0
- data/lib/torch/nn/transformer_decoder.rb +25 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
- data/lib/torch/nn/transformer_encoder.rb +25 -0
- data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
- data/lib/torch/nn/utils.rb +12 -0
- data/lib/torch/tensor.rb +8 -0
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +6 -0
- metadata +18 -3
data/ext/torch/backends.cpp
ADDED
@@ -0,0 +1,17 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "utils.h"
+
+void init_backends(Rice::Module& m) {
+  auto rb_mBackends = Rice::define_module_under(m, "Backends");
+
+  Rice::define_module_under(rb_mBackends, "OpenMP")
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_function("available?", &torch::hasOpenMP);
+
+  Rice::define_module_under(rb_mBackends, "MKL")
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_function("available?", &torch::hasMKL);
+}
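The new `Torch::Backends` module exposes build-time capability checks from Ruby. A minimal usage sketch (results depend on how libtorch was compiled):

```ruby
require "torch"

# Both predicates come straight from the bindings above and return true or false
# depending on whether libtorch was built with OpenMP / MKL support.
puts Torch::Backends::OpenMP.available?
puts Torch::Backends::MKL.available?
```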
data/ext/torch/ext.cpp
CHANGED
@@ -2,10 +2,14 @@
 
 #include <rice/rice.hpp>
 
+void init_fft(Rice::Module& m);
+void init_linalg(Rice::Module& m);
 void init_nn(Rice::Module& m);
+void init_special(Rice::Module& m);
 void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
 void init_torch(Rice::Module& m);
 
+void init_backends(Rice::Module& m);
 void init_cuda(Rice::Module& m);
 void init_device(Rice::Module& m);
 void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
@@ -27,7 +31,11 @@ void Init_ext()
   init_torch(m);
   init_tensor(m, rb_cTensor, rb_cTensorOptions);
   init_nn(m);
+  init_fft(m);
+  init_linalg(m);
+  init_special(m);
 
+  init_backends(m);
   init_cuda(m);
   init_device(m);
   init_ivalue(m, rb_cIValue);
data/ext/torch/fft.cpp
ADDED
@@ -0,0 +1,13 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "fft_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_fft(Rice::Module& m) {
+  auto rb_mFFT = Rice::define_module_under(m, "FFT");
+  rb_mFFT.add_handler<torch::Error>(handle_error);
+  add_fft_functions(rb_mFFT);
+}
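`init_fft` hangs the generated FFT bindings (see `add_fft_functions` in the codegen) off a `Torch::FFT` module. A rough sketch, assuming the generated functions mirror `torch.fft` (e.g. `fft`/`ifft`):

```ruby
require "torch"

x = Torch.rand(8)
# Assumes fft/ifft are among the functions generated from native_functions.yaml
freq = Torch::FFT.fft(x)
back = Torch::FFT.ifft(freq)
```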
data/ext/torch/linalg.cpp
ADDED
@@ -0,0 +1,13 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "linalg_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_linalg(Rice::Module& m) {
+  auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
+  rb_mLinalg.add_handler<torch::Error>(handle_error);
+  add_linalg_functions(rb_mLinalg);
+}
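`Torch::Linalg` follows the same pattern via `add_linalg_functions`. A hedged sketch, assuming `norm` is among the generated bindings (mirroring `torch.linalg.norm`):

```ruby
require "torch"

m = Torch.rand(3, 3)
# Assumes Torch::Linalg.norm is generated from native_functions.yaml
puts Torch::Linalg.norm(m).item
```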
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -78,6 +78,7 @@ struct RubyArgs {
   inline OptionalTensor optionalTensor(int i);
   inline at::Scalar scalar(int i);
   // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
+  inline std::vector<at::Scalar> scalarlist(int i);
   inline std::vector<at::Tensor> tensorlist(int i);
   template<int N>
   inline std::array<at::Tensor, N> tensorlist_n(int i);
@@ -134,6 +135,11 @@ inline at::Scalar RubyArgs::scalar(int i) {
   return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
 }
 
+inline std::vector<at::Scalar> RubyArgs::scalarlist(int i) {
+  if (NIL_P(args[i])) return std::vector<at::Scalar>();
+  return Rice::detail::From_Ruby<std::vector<at::Scalar>>().convert(args[i]);
+}
+
 inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
   if (NIL_P(args[i])) return std::vector<at::Tensor>();
   return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
@@ -312,7 +318,7 @@ inline std::string RubyArgs::string(int i) {
 }
 
 inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
-  if (
+  if (NIL_P(args[i])) return c10::nullopt;
   return Rice::detail::From_Ruby<std::string>().convert(args[i]);
 }
 
data/ext/torch/special.cpp
ADDED
@@ -0,0 +1,13 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "special_functions.h"
+#include "templates.h"
+#include "utils.h"
+
+void init_special(Rice::Module& m) {
+  auto rb_mSpecial = Rice::define_module_under(m, "Special");
+  rb_mSpecial.add_handler<torch::Error>(handle_error);
+  add_special_functions(rb_mSpecial);
+}
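`Torch::Special` is wired up the same way through `add_special_functions`. A sketch, assuming `expit` (the sigmoid alias from `torch.special`) is among the generated functions:

```ruby
require "torch"

x = Torch.tensor([-1.0, 0.0, 1.0])
# Assumes Torch::Special.expit is generated from native_functions.yaml
p Torch::Special.expit(x)
```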
data/ext/torch/templates.h
CHANGED
data/lib/torch/nn/convnd.rb
CHANGED
@@ -1,6 +1,8 @@
 module Torch
   module NN
     class ConvNd < Module
+      attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_paddding, :groups, :padding_mode
+
       def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
         super()
         raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
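Since `Conv1d`/`Conv2d`/`Conv3d` inherit from `ConvNd`, the new readers make a layer's configuration inspectable after construction. A quick sketch:

```ruby
require "torch"

conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 2)
p conv.in_channels  # => 3
p conv.out_channels # => 16
p conv.kernel_size  # => [3, 3]
p conv.stride       # => [2, 2]
```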
data/lib/torch/nn/functional_attention.rb
ADDED
@@ -0,0 +1,241 @@
+module Torch
+  module NN
+    class Functional
+      class << self
+        def in_projection_packed(q, k, v, w, b: nil)
+          e = q.size(-1)
+
+          if k.eql? v
+            if q.eql? k
+              # self-attention
+              return linear(q, w, b).chunk(3, dim: -1)
+            else
+              # encoder-decoder attention
+              w_q, w_kv = w.split_with_sizes([e, e * 2])
+              if b.nil?
+                b_q = b_kv = nil
+              else
+                b_q, b_kv = b.split_with_sizes([e, e * 2])
+              end
+
+              return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
+            end
+          else
+            w_q, w_k, w_v = w.chunk(3)
+            if b.nil?
+              b_q = b_k = b_v = nil
+            else
+              b_q, b_k, b_v = b.chunk(3)
+            end
+
+            return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+          end
+        end
+
+        def in_projection(
+          q, k, v,
+          w_q, w_k, w_v,
+          b_q: nil, b_k: nil, b_v: nil
+        )
+
+          e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)
+
+          raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
+          raise ArgumentError, "Expecting key weights shape of #{[e_k, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_k, e_k]
+          raise ArgumentError, "Expecting value weights shape of #{[e_v, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_v, e_v]
+
+          raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
+          raise ArgumentError, "Expecting key bias shape of #{[e_k]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_k]
+          raise ArgumentError, "Expecting value bias shape of #{[e_v]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_v]
+
+          [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+        end
+
+        def scaled_dot_product_attention(
+          q, k, v,
+          attn_mask: nil, dropout_p: 0.0
+        )
+
+          b, nt, e = q.shape
+
+          q = q / Math.sqrt(e)
+
+          attn = Torch.bmm(q, k.transpose(-2, -1))
+          attn += attn_mask if attn_mask
+          attn = softmax(attn, dim: -1)
+          attn = dropout(attn, p: dropout_p) if dropout_p > 0
+
+          output = Torch.bmm(attn, v)
+
+          [output, attn]
+        end
+
+        def multi_head_attention_forward(
+          query, key, value,
+          embed_dim_to_check, num_heads,
+          in_proj_weight, in_proj_bias,
+          bias_k, bias_v,
+          add_zero_attn,
+          dropout_p,
+          out_proj_weight, out_proj_bias,
+          training: true,
+          key_padding_mask: nil,
+          need_weights: true,
+          attn_mask: nil,
+          use_separate_proj_weight: false,
+          q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
+          static_k: nil, static_v: nil
+        )
+
+          tgt_len, bsz, embed_dim = query.shape
+          src_len = key.shape.first
+
+          raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check
+
+          head_dim = if embed_dim.is_a?(Torch::Tensor)
+            embed_dim.div(num_heads, rounding_mode: 'trunc')
+          else
+            head_dim = embed_dim.div num_heads
+          end
+
+          if use_separate_proj_weight
+            raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
+          else
+            raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
+          end
+
+          # compute in-projection
+          q, k, v =
+            if use_separate_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight
+
+              if in_proj_bias
+                b_q, b_k, b_v = in_proj_bias.chunk(3)
+              else
+                b_q = b_k = b_v = nil
+              end
+
+              in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
+            else
+              in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
+            end
+
+          # prep attention mask
+          if attn_mask
+            if attn_mask.dtype == :uint8
+              puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
+              attn_mask = attn_mask.bool
+            else
+              raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
+            end
+
+            if attn_mask.dim == 2
+              correct_2d_size = [tgt_len, src_len]
+              raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size
+
+              attn_mask = attn_mask.unsqueeze(0)
+            elsif attn_mask.dim == 3
+              correct_3d_size = [bsz * num_heads, tgt_len, src_len]
+              raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
+            else
+              raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
+            end
+          end
+
+          # prep key padding mask
+          if key_padding_mask && key_padding_mask.dtype == :uint8
+            puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
+            key_padding_mask = key_padding_mask.bool
+          end
+
+          # add bias along batch dimension (currently second)
+          if bias_k && bias_v
+            raise ArgumentError, "bias cannot be added to static key." if static_k
+            raise ArgumentError, "bias cannot be added to static value." if static_v
+
+            k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
+            v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
+
+            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+          else
+            raise ArgumentError unless bias_k.nil?
+            raise ArgumentError unless bias_v.nil?
+          end
+
+          # reshape q, k, v for multihead attention and make em batch first
+          q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+
+          if static_k.nil?
+            k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+          else
+            raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
+            raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim
+
+            k = static_k
+          end
+
+          if static_v.nil?
+            v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+          else
+            raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
+            raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim
+
+            v = static_v
+          end
+
+          # add zero attention along batch dimension (now first)
+          if add_zero_attn
+            zero_attn_shape = [bsz * num_heads, 1, head_dim]
+            k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
+            v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)
+
+            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+          end
+
+          # update source sequence length after adjustments
+          src_len = k.size(1)
+
+          # merge key padding and attention masks
+          if key_padding_mask
+            raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]
+
+            key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+
+            attn_mask = if attn_mask.nil?
+              key_padding_mask
+            elsif attn_mask.dtype == :bool
+              attn_mask.logical_or(key_padding_mask)
+            else
+              attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
+            end
+          end
+
+          # convert mask to float
+          if attn_mask && attn_mask.dtype == :bool
+            new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
+            attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
+          end
+
+          dropout_p = 0.0 unless training
+
+          # (deep breath) calculate attention and out projection
+          attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
+          attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
+          attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+          if need_weights
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+            [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
+          else
+            [attn_output, nil]
+          end
+        end
+      end
+    end
+  end
+end
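These helpers back the `MultiheadAttention` module added below, but `scaled_dot_product_attention` can also be called directly. A small sketch with `(batch, seq, embed)` tensors, matching the `Torch.bmm` shapes used above:

```ruby
require "torch"

q = Torch.rand(2, 5, 8) # (batch, target_len, embed)
k = Torch.rand(2, 7, 8) # (batch, source_len, embed)
v = Torch.rand(2, 7, 8)

output, attn = Torch::NN::Functional.scaled_dot_product_attention(q, k, v)
p output.shape # => [2, 5, 8]
p attn.shape   # => [2, 5, 7]
```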
data/lib/torch/nn/module.rb
CHANGED
@@ -3,6 +3,8 @@ module Torch
     class Module
       include Utils
 
+      attr_reader :training
+
       def initialize
         @training = true
         @parameters = {}
@@ -278,6 +280,11 @@
         end
       end
 
+      def deep_dup
+        memo = {}
+        dup_value(self, memo)
+      end
+
       def method_missing(method, *args, &block)
         name = method.to_s
         if named_parameters.key?(name)
@@ -386,6 +393,29 @@
           destination[prefix + k] = v
         end
       end
+
+      # keep memo hash like Python deepcopy
+      # https://docs.python.org/3/library/copy.html
+      def dup_value(v, memo)
+        memo[v.object_id] ||= begin
+          case v
+          when Method, UnboundMethod
+            v
+          when Hash
+            v.to_h { |k, v2| [dup_value(k, memo), dup_value(v2, memo)] }
+          when Array
+            v.map { |v2| dup_value(v2, memo) }
+          when Torch::NN::Module
+            copy = v.dup
+            v.instance_variables.each do |var|
+              copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
+            end
+            copy
+          else
+            v.dup
+          end
+        end
+      end
     end
   end
 end
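`deep_dup` produces a recursive copy of a module (submodules, parameters, and buffers), using a memo hash so shared objects are only duplicated once, much like Python's `copy.deepcopy`. A quick sketch:

```ruby
require "torch"

net = Torch::NN::Sequential.new(
  Torch::NN::Linear.new(10, 5),
  Torch::NN::ReLU.new,
  Torch::NN::Linear.new(5, 2)
)

copy = net.deep_dup
# The copy carries its own parameter tensors, so it can be modified
# or trained without touching the original.
p copy.parameters.length == net.parameters.length                    # => true
p copy.parameters.first.object_id != net.parameters.first.object_id  # => true
```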
data/lib/torch/nn/module_list.rb
ADDED
@@ -0,0 +1,49 @@
+module Torch
+  module NN
+    class ModuleList < Module
+      include Enumerable
+
+      def initialize(mods = nil)
+        super()
+
+        self.concat(mods) if mods
+      end
+
+      def length
+        @modules.length
+      end
+      alias_method :count, :length
+      alias_method :size, :length
+
+      def concat(mods)
+        raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)
+
+        mods.each { |m| append m }
+
+        self
+      end
+
+      def each(&block)
+        if block_given?
+          @modules.values.each(&block)
+        else
+          to_enum(:each)
+        end
+      end
+
+      def append(mod)
+        raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
+        add_module(length.to_s, mod)
+        self
+      end
+
+      def [](idx)
+        if idx.is_a?(Range)
+          self.class.new(@modules.values[idx])
+        else
+          @modules[idx.to_s]
+        end
+      end
+    end
+  end
+end
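`ModuleList` registers each entry through `add_module`, so contained layers are tracked like any other submodule, and range indexing returns a new `ModuleList`. A quick sketch:

```ruby
require "torch"

layers = Torch::NN::ModuleList.new(3.times.map { Torch::NN::Linear.new(8, 8) })

p layers.length       # => 3
p layers[0].class     # => Torch::NN::Linear
p layers[0..1].length # => 2

x = Torch.rand(1, 8)
layers.each { |layer| x = layer.call(x) }
p x.shape             # => [1, 8]
```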
data/lib/torch/nn/multihead_attention.rb
ADDED
@@ -0,0 +1,123 @@
+module Torch
+  module NN
+    class MultiheadAttention < Module
+      def initialize(
+        embed_dim, num_heads,
+        dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false,
+        kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil
+      )
+
+        super()
+
+        @embed_dim = embed_dim
+        @kdim = kdim || @embed_dim
+        @vdim = vdim || @embed_dim
+
+        @qkv_same_embed_dim = @kdim == @embed_dim && @vdim == @embed_dim
+
+        @num_heads = num_heads
+        @dropout = dropout
+        @batch_first = batch_first
+
+        @head_dim = @embed_dim.div @num_heads
+
+        raise ArgumentError, "embed_dim must be divisible by num_heads" unless @head_dim * @num_heads == @embed_dim
+
+        if @qkv_same_embed_dim
+          @in_proj_weight = Parameter.new(Torch.empty([3 * @embed_dim, @embed_dim]))
+          %w(q k v).each { |x| register_parameter("#{x}_proj_weight", nil) }
+        else
+          @q_proj_weight = Parameter.new(Torch.empty([@embed_dim, @embed_dim]))
+          @k_proj_weight = Parameter.new(Torch.empty([@embed_dim, @kdim]))
+          @v_proj_weight = Parameter.new(Torch.empty([@embed_dim, @vdim]))
+
+          register_parameter('in_proj_weight', nil)
+        end
+
+        if bias
+          @in_proj_bias = Parameter.new(Torch.empty(3 * @embed_dim))
+        else
+          register_parameter('in_proj_bias', nil)
+        end
+
+        @out_proj = Linear.new(@embed_dim, @embed_dim, bias: bias)
+
+        if add_bias_kv
+          @bias_k = Parameter.new(Torch.empty([1, 1, @embed_dim]))
+          @bias_v = Parameter.new(Torch.empty([1, 1, @embed_dim]))
+        else
+          @bias_k = @bias_v = nil
+        end
+
+        @add_zero_attn = add_zero_attn
+
+        reset_parameters
+      end
+
+      def batch_first?
+        !!@batch_first
+      end
+
+      def reset_parameters
+        if @qkv_same_embed_dim
+          Init.xavier_uniform!(@in_proj_weight)
+        else
+          Init.xavier_uniform!(@q_proj_weight)
+          Init.xavier_uniform!(@k_proj_weight)
+          Init.xavier_uniform!(@v_proj_weight)
+        end
+
+        if @in_proj_bias
+          Init.constant!(@in_proj_bias, 0.0)
+          Init.constant!(@out_proj.bias, 0.0)
+        end
+
+        Init.xavier_uniform!(@bias_k) if @bias_k
+        Init.xavier_uniform!(@bias_v) if @bias_v
+      end
+
+      def forward(
+        query, key, value,
+        key_padding_mask: nil, need_weights: true, attn_mask: nil
+      )
+
+        if batch_first?
+          query, key, value = [query, key, value].map { |t| t.transpose(1, 0) }
+        end
+
+        attn_output, attn_output_weights =
+          if @qkv_same_embed_dim
+            F.multi_head_attention_forward(
+              query, key, value,
+              @embed_dim, @num_heads,
+              @in_proj_weight, @in_proj_bias,
+              @bias_k, @bias_v, @add_zero_attn,
+              @dropout, @out_proj.weight, @out_proj.bias,
+              training: @training,
+              key_padding_mask: key_padding_mask,
+              need_weights: need_weights,
+              attn_mask: attn_mask
+            )
+          else
+            F.multi_head_attention_forward(
+              query, key, value,
+              @embed_dim, @num_heads,
+              @in_proj_weight, @in_proj_bias,
+              @bias_k, @bias_v, @add_zero_attn,
+              @dropout, @out_proj.weight, @out_proj.bias,
+              training: @training,
+              key_padding_mask: key_padding_mask,
+              need_weights: need_weights,
+              attn_mask: attn_mask,
+              use_separate_proj_weight: true,
+              q_proj_weight: @q_proj_weight, k_proj_weight: @k_proj_weight, v_proj_weight: @v_proj_weight
+            )
+          end
+
+        attn_output = attn_output.transpose(1, 0) if batch_first?
+
+        [attn_output, attn_output_weights]
+      end
+    end
+  end
+end
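Putting it together, the new `Torch::NN::MultiheadAttention` module mirrors PyTorch's `nn.MultiheadAttention`. A sketch with the default sequence-first layout, i.e. `(seq_len, batch, embed_dim)` inputs:

```ruby
require "torch"

mha = Torch::NN::MultiheadAttention.new(16, 4) # embed_dim = 16, 4 heads

query = Torch.rand(5, 2, 16) # (target_len, batch, embed_dim)
key   = Torch.rand(7, 2, 16) # (source_len, batch, embed_dim)
value = Torch.rand(7, 2, 16)

attn_output, attn_weights = mha.call(query, key, value)
p attn_output.shape  # => [5, 2, 16]
p attn_weights.shape # => [2, 5, 7]
```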