torch-rb 0.7.0 → 0.8.3

@@ -0,0 +1,17 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "utils.h"
+
+ void init_backends(Rice::Module& m) {
+   auto rb_mBackends = Rice::define_module_under(m, "Backends");
+
+   Rice::define_module_under(rb_mBackends, "OpenMP")
+     .add_handler<torch::Error>(handle_error)
+     .define_singleton_function("available?", &torch::hasOpenMP);
+
+   Rice::define_module_under(rb_mBackends, "MKL")
+     .add_handler<torch::Error>(handle_error)
+     .define_singleton_function("available?", &torch::hasMKL);
+ }
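
The new Backends module simply re-exports LibTorch's build-capability checks. A minimal usage sketch (assuming, as in ext.cpp below, that m is the top-level Torch module):

    require "torch"

    # Both predicates map directly to torch::hasOpenMP / torch::hasMKL above;
    # the result depends on how LibTorch was built.
    Torch::Backends::OpenMP.available?  # => true or false
    Torch::Backends::MKL.available?     # => true or false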
data/ext/torch/ext.cpp CHANGED
@@ -2,10 +2,14 @@

  #include <rice/rice.hpp>

+ void init_fft(Rice::Module& m);
+ void init_linalg(Rice::Module& m);
  void init_nn(Rice::Module& m);
+ void init_special(Rice::Module& m);
  void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
  void init_torch(Rice::Module& m);

+ void init_backends(Rice::Module& m);
  void init_cuda(Rice::Module& m);
  void init_device(Rice::Module& m);
  void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
@@ -27,7 +31,11 @@ void Init_ext()
    init_torch(m);
    init_tensor(m, rb_cTensor, rb_cTensorOptions);
    init_nn(m);
+   init_fft(m);
+   init_linalg(m);
+   init_special(m);

+   init_backends(m);
    init_cuda(m);
    init_device(m);
    init_ivalue(m, rb_cIValue);
data/ext/torch/fft.cpp ADDED
@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "fft_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_fft(Rice::Module& m) {
+   auto rb_mFFT = Rice::define_module_under(m, "FFT");
+   rb_mFFT.add_handler<torch::Error>(handle_error);
+   add_fft_functions(rb_mFFT);
+ }
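
add_fft_functions is generated from the upstream torch.fft namespace, so the Ruby module follows those names. A rough sketch (exact function coverage depends on the bindings generated for this release):

    x = Torch.tensor([0.0, 1.0, 0.0, -1.0])
    freq = Torch::FFT.fft(x)    # complex tensor of Fourier coefficients
    Torch::FFT.ifft(freq)       # round-trips back to x, up to floating-point error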
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_fft_functions(Rice::Module& m);
@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "linalg_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_linalg(Rice::Module& m) {
+   auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
+   rb_mLinalg.add_handler<torch::Error>(handle_error);
+   add_linalg_functions(rb_mLinalg);
+ }
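
Likewise, Torch::Linalg mirrors torch.linalg. A hedged sketch, again assuming the generated functions follow the upstream names:

    a = Torch.tensor([[3.0, 0.0], [0.0, 4.0]])
    Torch::Linalg.norm(a)  # Frobenius norm => 5.0
    Torch::Linalg.inv(a)   # matrix inverse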
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_linalg_functions(Rice::Module& m);
@@ -78,6 +78,7 @@ struct RubyArgs {
    inline OptionalTensor optionalTensor(int i);
    inline at::Scalar scalar(int i);
    // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
+   inline std::vector<at::Scalar> scalarlist(int i);
    inline std::vector<at::Tensor> tensorlist(int i);
    template<int N>
    inline std::array<at::Tensor, N> tensorlist_n(int i);
@@ -134,6 +135,11 @@ inline at::Scalar RubyArgs::scalar(int i) {
    return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
  }

+ inline std::vector<at::Scalar> RubyArgs::scalarlist(int i) {
+   if (NIL_P(args[i])) return std::vector<at::Scalar>();
+   return Rice::detail::From_Ruby<std::vector<at::Scalar>>().convert(args[i]);
+ }
+
  inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
    if (NIL_P(args[i])) return std::vector<at::Tensor>();
    return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
@@ -312,7 +318,7 @@ inline std::string RubyArgs::string(int i) {
  }

  inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
-   if (!args[i]) return c10::nullopt;
+   if (NIL_P(args[i])) return c10::nullopt;
    return Rice::detail::From_Ruby<std::string>().convert(args[i]);
  }

@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "special_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_special(Rice::Module& m) {
+   auto rb_mSpecial = Rice::define_module_under(m, "Special");
+   rb_mSpecial.add_handler<torch::Error>(handle_error);
+   add_special_functions(rb_mSpecial);
+ }
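
And Torch::Special mirrors torch.special. A small sketch, assuming expit and gammaln are among the generated functions (both are part of the upstream namespace):

    x = Torch.tensor([-1.0, 0.0, 1.0])
    Torch::Special.expit(x)                      # elementwise sigmoid
    Torch::Special.gammaln(Torch.tensor([5.0]))  # log(4!) ≈ 3.178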
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_special_functions(Rice::Module& m);
@@ -21,6 +21,7 @@ using torch::IntArrayRef;
  using torch::ArrayRef;
  using torch::TensorList;
  using torch::Storage;
+ using ScalarList = ArrayRef<Scalar>;

  using torch::nn::init::FanModeType;
  using torch::nn::init::NonlinearityType;
@@ -1,6 +1,8 @@
  module Torch
    module NN
      class ConvNd < Module
+       attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_padding, :groups, :padding_mode
+
        def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
          super()
          raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
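
The new readers expose the hyperparameters captured in the constructor. A quick sketch with Conv2d, which subclasses ConvNd (scalar arguments are normalized to pairs internally):

    conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 2)
    conv.in_channels   # => 3
    conv.out_channels  # => 16
    conv.kernel_size   # => [3, 3]
    conv.stride        # => [2, 2]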
@@ -0,0 +1,241 @@
+ module Torch
+   module NN
+     class Functional
+       class << self
+         def in_projection_packed(q, k, v, w, b: nil)
+           e = q.size(-1)
+
+           if k.eql? v
+             if q.eql? k
+               # self-attention
+               return linear(q, w, b).chunk(3, dim: -1)
+             else
+               # encoder-decoder attention
+               w_q, w_kv = w.split_with_sizes([e, e * 2])
+               if b.nil?
+                 b_q = b_kv = nil
+               else
+                 b_q, b_kv = b.split_with_sizes([e, e * 2])
+               end
+
+               return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
+             end
+           else
+             w_q, w_k, w_v = w.chunk(3)
+             if b.nil?
+               b_q = b_k = b_v = nil
+             else
+               b_q, b_k, b_v = b.chunk(3)
+             end
+
+             return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+           end
+         end
+
+         def in_projection(
+           q, k, v,
+           w_q, w_k, w_v,
+           b_q: nil, b_k: nil, b_v: nil
+         )
+
+           e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)
+
+           raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
+           raise ArgumentError, "Expecting key weights shape of #{[e_k, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_k, e_k]
+           raise ArgumentError, "Expecting value weights shape of #{[e_v, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_v, e_v]
+
+           raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
+           raise ArgumentError, "Expecting key bias shape of #{[e_k]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_k]
+           raise ArgumentError, "Expecting value bias shape of #{[e_v]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_v]
+
+           [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+         end
+
+         def scaled_dot_product_attention(
+           q, k, v,
+           attn_mask: nil, dropout_p: 0.0
+         )
+
+           b, nt, e = q.shape
+
+           q = q / Math.sqrt(e)
+
+           attn = Torch.bmm(q, k.transpose(-2, -1))
+           attn += attn_mask if attn_mask
+           attn = softmax(attn, dim: -1)
+           attn = dropout(attn, p: dropout_p) if dropout_p > 0
+
+           output = Torch.bmm(attn, v)
+
+           [output, attn]
+         end
+
+         def multi_head_attention_forward(
+           query, key, value,
+           embed_dim_to_check, num_heads,
+           in_proj_weight, in_proj_bias,
+           bias_k, bias_v,
+           add_zero_attn,
+           dropout_p,
+           out_proj_weight, out_proj_bias,
+           training: true,
+           key_padding_mask: nil,
+           need_weights: true,
+           attn_mask: nil,
+           use_separate_proj_weight: false,
+           q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
+           static_k: nil, static_v: nil
+         )
+
+           tgt_len, bsz, embed_dim = query.shape
+           src_len = key.shape.first
+
+           raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check
+
+           head_dim = if embed_dim.is_a?(Torch::Tensor)
+             embed_dim.div(num_heads, rounding_mode: 'trunc')
+           else
+             embed_dim.div(num_heads)
+           end
+
+           if use_separate_proj_weight
+             raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
+           else
+             raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
+           end
+
+           # compute in-projection
+           q, k, v =
+             if use_separate_proj_weight
+               raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
+               raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
+               raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight
+
+               if in_proj_bias
+                 b_q, b_k, b_v = in_proj_bias.chunk(3)
+               else
+                 b_q = b_k = b_v = nil
+               end
+
+               in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
+             else
+               in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
+             end
+
+           # prep attention mask
+           if attn_mask
+             if attn_mask.dtype == :uint8
+               puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
+               attn_mask = attn_mask.bool
+             else
+               raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
+             end
+
+             if attn_mask.dim == 2
+               correct_2d_size = [tgt_len, src_len]
+               raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size
+
+               attn_mask = attn_mask.unsqueeze(0)
+             elsif attn_mask.dim == 3
+               correct_3d_size = [bsz * num_heads, tgt_len, src_len]
+               raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
+             else
+               raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
+             end
+           end
+
+           # prep key padding mask
+           if key_padding_mask && key_padding_mask.dtype == :uint8
+             puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
+             key_padding_mask = key_padding_mask.bool
+           end
+
+           # add bias along batch dimension (currently second)
+           if bias_k && bias_v
+             raise ArgumentError, "bias cannot be added to static key." if static_k
+             raise ArgumentError, "bias cannot be added to static value." if static_v
+
+             k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
+             v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
+
+             attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+             key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+           else
+             raise ArgumentError unless bias_k.nil?
+             raise ArgumentError unless bias_v.nil?
+           end
+
+           # reshape q, k, v for multihead attention and make em batch first
+           q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+
+           if static_k.nil?
+             k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+           else
+             raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
+             raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim
+
+             k = static_k
+           end
+
+           if static_v.nil?
+             v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+           else
+             raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
+             raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim
+
+             v = static_v
+           end
+
+           # add zero attention along batch dimension (now first)
+           if add_zero_attn
+             zero_attn_shape = [bsz * num_heads, 1, head_dim]
+             k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
+             v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)
+
+             attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+             key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+           end
+
+           # update source sequence length after adjustments
+           src_len = k.size(1)
+
+           # merge key padding and attention masks
+           if key_padding_mask
+             raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]
+
+             key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+
+             attn_mask = if attn_mask.nil?
+               key_padding_mask
+             elsif attn_mask.dtype == :bool
+               attn_mask.logical_or(key_padding_mask)
+             else
+               attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
+             end
+           end
+
+           # convert mask to float
+           if attn_mask && attn_mask.dtype == :bool
+             new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
+             attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
+           end
+
+           dropout_p = 0.0 unless training
+
+           # (deep breath) calculate attention and out projection
+           attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
+           attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
+           attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+           if need_weights
+             # average attention weights over heads
+             attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+             [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
+           else
+             [attn_output, nil]
+           end
+         end
+       end
+     end
+   end
+ end
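
Because these are class methods on Torch::NN::Functional, the core primitive can be exercised on its own. A hedged sketch of scaled_dot_product_attention with random tensors in the (batch * heads, seq_len, head_dim) layout used above:

    q = Torch.randn(2, 5, 8)
    k = Torch.randn(2, 5, 8)
    v = Torch.randn(2, 5, 8)

    out, weights = Torch::NN::Functional.scaled_dot_product_attention(q, k, v)
    out.shape      # => [2, 5, 8]
    weights.shape  # => [2, 5, 5]  (softmax attention; dropout is skipped by default)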
@@ -3,6 +3,8 @@ module Torch
    class Module
      include Utils

+     attr_reader :training
+
      def initialize
        @training = true
        @parameters = {}
@@ -278,6 +280,11 @@ module Torch
        end
      end

+     def deep_dup
+       memo = {}
+       dup_value(self, memo)
+     end
+
      def method_missing(method, *args, &block)
        name = method.to_s
        if named_parameters.key?(name)
@@ -386,6 +393,29 @@ module Torch
          destination[prefix + k] = v
        end
      end
+
+     # keep memo hash like Python deepcopy
+     # https://docs.python.org/3/library/copy.html
+     def dup_value(v, memo)
+       memo[v.object_id] ||= begin
+         case v
+         when Method, UnboundMethod
+           v
+         when Hash
+           v.to_h { |k, v2| [dup_value(k, memo), dup_value(v2, memo)] }
+         when Array
+           v.map { |v2| dup_value(v2, memo) }
+         when Torch::NN::Module
+           copy = v.dup
+           v.instance_variables.each do |var|
+             copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
+           end
+           copy
+         else
+           v.dup
+         end
+       end
+     end
    end
  end
end
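
deep_dup walks a module's instance variables and copies nested modules, hashes, arrays, and (via Parameter#dup below) parameters, using the memo hash so shared objects are only duplicated once. A small sketch:

    model = Torch::NN::Linear.new(4, 2)
    snapshot = model.deep_dup

    # the copy owns its own parameter objects
    snapshot.weight.object_id == model.weight.object_id  # => false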
@@ -0,0 +1,49 @@
+ module Torch
+   module NN
+     class ModuleList < Module
+       include Enumerable
+
+       def initialize(mods = nil)
+         super()
+
+         self.concat(mods) if mods
+       end
+
+       def length
+         @modules.length
+       end
+       alias_method :count, :length
+       alias_method :size, :length
+
+       def concat(mods)
+         raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)
+
+         mods.each { |m| append m }
+
+         self
+       end
+
+       def each(&block)
+         if block_given?
+           @modules.values.each(&block)
+         else
+           to_enum(:each)
+         end
+       end
+
+       def append(mod)
+         raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
+         add_module(length.to_s, mod)
+         self
+       end
+
+       def [](idx)
+         if idx.is_a?(Range)
+           self.class.new(@modules.values[idx])
+         else
+           @modules[idx.to_s]
+         end
+       end
+     end
+   end
+ end
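
ModuleList registers each child with add_module, so child parameters are picked up by the parent, and Enumerable works as expected. A sketch:

    layers = Torch::NN::ModuleList.new(3.times.map { Torch::NN::Linear.new(8, 8) })
    layers.length   # => 3
    layers[0]       # => the first Linear
    layers[1..2]    # => a new ModuleList holding the slice
    layers.each { |l| p l }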
@@ -0,0 +1,123 @@
+ module Torch
+   module NN
+     class MultiheadAttention < Module
+       def initialize(
+         embed_dim, num_heads,
+         dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false,
+         kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil
+       )
+
+         super()
+
+         @embed_dim = embed_dim
+         @kdim = kdim || @embed_dim
+         @vdim = vdim || @embed_dim
+
+         @qkv_same_embed_dim = @kdim == @embed_dim && @vdim == @embed_dim
+
+         @num_heads = num_heads
+         @dropout = dropout
+         @batch_first = batch_first
+
+         @head_dim = @embed_dim.div(@num_heads)
+
+         raise ArgumentError, "embed_dim must be divisible by num_heads" unless @head_dim * @num_heads == @embed_dim
+
+         if @qkv_same_embed_dim
+           @in_proj_weight = Parameter.new(Torch.empty([3 * @embed_dim, @embed_dim]))
+           %w(q k v).each { |x| register_parameter("#{x}_proj_weight", nil) }
+         else
+           @q_proj_weight = Parameter.new(Torch.empty([@embed_dim, @embed_dim]))
+           @k_proj_weight = Parameter.new(Torch.empty([@embed_dim, @kdim]))
+           @v_proj_weight = Parameter.new(Torch.empty([@embed_dim, @vdim]))
+
+           register_parameter('in_proj_weight', nil)
+         end
+
+         if bias
+           @in_proj_bias = Parameter.new(Torch.empty(3 * @embed_dim))
+         else
+           register_parameter('in_proj_bias', nil)
+         end
+
+         @out_proj = Linear.new(@embed_dim, @embed_dim, bias: bias)
+
+         if add_bias_kv
+           @bias_k = Parameter.new(Torch.empty([1, 1, @embed_dim]))
+           @bias_v = Parameter.new(Torch.empty([1, 1, @embed_dim]))
+         else
+           @bias_k = @bias_v = nil
+         end
+
+         @add_zero_attn = add_zero_attn
+
+         reset_parameters
+       end
+
+       def batch_first?
+         !!@batch_first
+       end
+
+       def reset_parameters
+         if @qkv_same_embed_dim
+           Init.xavier_uniform!(@in_proj_weight)
+         else
+           Init.xavier_uniform!(@q_proj_weight)
+           Init.xavier_uniform!(@k_proj_weight)
+           Init.xavier_uniform!(@v_proj_weight)
+         end
+
+         if @in_proj_bias
+           Init.constant!(@in_proj_bias, 0.0)
+           Init.constant!(@out_proj.bias, 0.0)
+         end
+
+         Init.xavier_uniform!(@bias_k) if @bias_k
+         Init.xavier_uniform!(@bias_v) if @bias_v
+       end
+
+       def forward(
+         query, key, value,
+         key_padding_mask: nil, need_weights: true, attn_mask: nil
+       )
+
+         if batch_first?
+           query, key, value = [query, key, value].map { |t| t.transpose(1, 0) }
+         end
+
+         attn_output, attn_output_weights =
+           if @qkv_same_embed_dim
+             F.multi_head_attention_forward(
+               query, key, value,
+               @embed_dim, @num_heads,
+               @in_proj_weight, @in_proj_bias,
+               @bias_k, @bias_v, @add_zero_attn,
+               @dropout, @out_proj.weight, @out_proj.bias,
+               training: @training,
+               key_padding_mask: key_padding_mask,
+               need_weights: need_weights,
+               attn_mask: attn_mask
+             )
+           else
+             F.multi_head_attention_forward(
+               query, key, value,
+               @embed_dim, @num_heads,
+               @in_proj_weight, @in_proj_bias,
+               @bias_k, @bias_v, @add_zero_attn,
+               @dropout, @out_proj.weight, @out_proj.bias,
+               training: @training,
+               key_padding_mask: key_padding_mask,
+               need_weights: need_weights,
+               attn_mask: attn_mask,
+               use_separate_proj_weight: true,
+               q_proj_weight: @q_proj_weight, k_proj_weight: @k_proj_weight, v_proj_weight: @v_proj_weight
+             )
+           end
+
+         attn_output = attn_output.transpose(1, 0) if batch_first?
+
+         [attn_output, attn_output_weights]
+       end
+     end
+   end
+ end
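
The module follows torch.nn.MultiheadAttention and defaults to (seq_len, batch, embed_dim) inputs since batch_first is false. A hedged end-to-end sketch of self-attention:

    mha = Torch::NN::MultiheadAttention.new(16, 4)   # embed_dim 16, 4 heads
    x = Torch.randn(10, 2, 16)                       # seq_len 10, batch 2
    out, weights = mha.call(x, x, x)
    out.shape      # => [10, 2, 16]
    weights.shape  # => [2, 10, 10]  (attention averaged over heads)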
@@ -9,6 +9,12 @@ module Torch
      def inspect
        "Parameter containing:\n#{super}"
      end
+
+     def dup
+       Torch.no_grad do
+         Parameter.new(clone, requires_grad: requires_grad)
+       end
+     end
    end
  end
end
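
Parameter#dup is what deep_dup above relies on: it clones the tensor inside a no_grad block and keeps the requires_grad flag, so the copy is a fresh leaf parameter detached from the original's autograd graph. A small sketch:

    w = Torch::NN::Parameter.new(Torch.rand(2, 2))
    w2 = w.dup
    w2.requires_grad              # => true, same as the original
    w2.object_id == w.object_id   # => false; clone gives it independent storage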