torch-rb 0.7.0 → 0.8.3

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -0,0 +1,17 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "utils.h"
+
+ void init_backends(Rice::Module& m) {
+   auto rb_mBackends = Rice::define_module_under(m, "Backends");
+
+   Rice::define_module_under(rb_mBackends, "OpenMP")
+     .add_handler<torch::Error>(handle_error)
+     .define_singleton_function("available?", &torch::hasOpenMP);
+
+   Rice::define_module_under(rb_mBackends, "MKL")
+     .add_handler<torch::Error>(handle_error)
+     .define_singleton_function("available?", &torch::hasMKL);
+ }
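For reference, the new module can be exercised directly from Ruby; a minimal sketch, assuming the gem is loaded with require "torch":

    require "torch"

    # Each call returns true or false depending on how the bundled libtorch was built.
    Torch::Backends::OpenMP.available?
    Torch::Backends::MKL.available?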
data/ext/torch/ext.cpp CHANGED
@@ -2,10 +2,14 @@
 
  #include <rice/rice.hpp>
 
+ void init_fft(Rice::Module& m);
+ void init_linalg(Rice::Module& m);
  void init_nn(Rice::Module& m);
+ void init_special(Rice::Module& m);
  void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
  void init_torch(Rice::Module& m);
 
+ void init_backends(Rice::Module& m);
  void init_cuda(Rice::Module& m);
  void init_device(Rice::Module& m);
  void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
@@ -27,7 +31,11 @@ void Init_ext()
    init_torch(m);
    init_tensor(m, rb_cTensor, rb_cTensorOptions);
    init_nn(m);
+   init_fft(m);
+   init_linalg(m);
+   init_special(m);
 
+   init_backends(m);
    init_cuda(m);
    init_device(m);
    init_ivalue(m, rb_cIValue);
data/ext/torch/fft.cpp ADDED
@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "fft_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_fft(Rice::Module& m) {
+   auto rb_mFFT = Rice::define_module_under(m, "FFT");
+   rb_mFFT.add_handler<torch::Error>(handle_error);
+   add_fft_functions(rb_mFFT);
+ }
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_fft_functions(Rice::Module& m);
@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "linalg_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_linalg(Rice::Module& m) {
+   auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
+   rb_mLinalg.add_handler<torch::Error>(handle_error);
+   add_linalg_functions(rb_mLinalg);
+ }
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_linalg_functions(Rice::Module& m);
@@ -78,6 +78,7 @@ struct RubyArgs {
    inline OptionalTensor optionalTensor(int i);
    inline at::Scalar scalar(int i);
    // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
+   inline std::vector<at::Scalar> scalarlist(int i);
    inline std::vector<at::Tensor> tensorlist(int i);
    template<int N>
    inline std::array<at::Tensor, N> tensorlist_n(int i);
@@ -134,6 +135,11 @@ inline at::Scalar RubyArgs::scalar(int i) {
    return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
  }
 
+ inline std::vector<at::Scalar> RubyArgs::scalarlist(int i) {
+   if (NIL_P(args[i])) return std::vector<at::Scalar>();
+   return Rice::detail::From_Ruby<std::vector<at::Scalar>>().convert(args[i]);
+ }
+
  inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
    if (NIL_P(args[i])) return std::vector<at::Tensor>();
    return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
@@ -312,7 +318,7 @@ inline std::string RubyArgs::string(int i) {
  }
 
  inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
-   if (!args[i]) return c10::nullopt;
+   if (NIL_P(args[i])) return c10::nullopt;
    return Rice::detail::From_Ruby<std::string>().convert(args[i]);
  }
 
@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "special_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_special(Rice::Module& m) {
+   auto rb_mSpecial = Rice::define_module_under(m, "Special");
+   rb_mSpecial.add_handler<torch::Error>(handle_error);
+   add_special_functions(rb_mSpecial);
+ }
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_special_functions(Rice::Module& m);
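The three init functions above register generated bindings under new Ruby namespaces. A hedged usage sketch; the individual method names (fft, norm, erf) are assumptions based on the corresponding libtorch namespaces and are not shown in this diff:

    x = Torch.randn(8)

    Torch::FFT.fft(x)      # assumed generated FFT binding
    Torch::Linalg.norm(x)  # assumed generated linalg binding
    Torch::Special.erf(x)  # assumed generated special-functions binding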
@@ -21,6 +21,7 @@ using torch::IntArrayRef;
  using torch::ArrayRef;
  using torch::TensorList;
  using torch::Storage;
+ using ScalarList = ArrayRef<Scalar>;
 
  using torch::nn::init::FanModeType;
  using torch::nn::init::NonlinearityType;
@@ -1,6 +1,8 @@
  module Torch
    module NN
      class ConvNd < Module
+       attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_padding, :groups, :padding_mode
+
        def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
          super()
          raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
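A quick sketch of what the new readers expose, using Conv2d (a ConvNd subclass) for illustration; the stored values shown assume Conv2d's usual normalization of scalar arguments to pairs:

    conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 2, padding: 1)
    conv.in_channels  # => 3
    conv.out_channels # => 16
    conv.kernel_size  # => [3, 3]
    conv.stride       # => [2, 2]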
@@ -0,0 +1,241 @@
+ module Torch
+   module NN
+     class Functional
+       class << self
+         def in_projection_packed(q, k, v, w, b: nil)
+           e = q.size(-1)
+
+           if k.eql? v
+             if q.eql? k
+               # self-attention
+               return linear(q, w, b).chunk(3, dim: -1)
+             else
+               # encoder-decoder attention
+               w_q, w_kv = w.split_with_sizes([e, e * 2])
+               if b.nil?
+                 b_q = b_kv = nil
+               else
+                 b_q, b_kv = b.split_with_sizes([e, e * 2])
+               end
+
+               return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
+             end
+           else
+             w_q, w_k, w_v = w.chunk(3)
+             if b.nil?
+               b_q = b_k = b_v = nil
+             else
+               b_q, b_k, b_v = b.chunk(3)
+             end
+
+             return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+           end
+         end
+
+         def in_projection(
+           q, k, v,
+           w_q, w_k, w_v,
+           b_q: nil, b_k: nil, b_v: nil
+         )
+
+           e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)
+
+           raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
+           raise ArgumentError, "Expecting key weights shape of #{[e_k, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_k, e_k]
+           raise ArgumentError, "Expecting value weights shape of #{[e_v, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_v, e_v]
+
+           raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
+           raise ArgumentError, "Expecting key bias shape of #{[e_k]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_k]
+           raise ArgumentError, "Expecting value bias shape of #{[e_v]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_v]
+
+           [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+         end
+
+         def scaled_dot_product_attention(
+           q, k, v,
+           attn_mask: nil, dropout_p: 0.0
+         )
+
+           b, nt, e = q.shape
+
+           q = q / Math.sqrt(e)
+
+           attn = Torch.bmm(q, k.transpose(-2, -1))
+           attn += attn_mask if attn_mask
+           attn = softmax(attn, dim: -1)
+           attn = dropout(attn, p: dropout_p) if dropout_p > 0
+
+           output = Torch.bmm(attn, v)
+
+           [output, attn]
+         end
+
+         def multi_head_attention_forward(
+           query, key, value,
+           embed_dim_to_check, num_heads,
+           in_proj_weight, in_proj_bias,
+           bias_k, bias_v,
+           add_zero_attn,
+           dropout_p,
+           out_proj_weight, out_proj_bias,
+           training: true,
+           key_padding_mask: nil,
+           need_weights: true,
+           attn_mask: nil,
+           use_separate_proj_weight: false,
+           q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
+           static_k: nil, static_v: nil
+         )
+
+           tgt_len, bsz, embed_dim = query.shape
+           src_len = key.shape.first
+
+           raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check
+
+           head_dim = if embed_dim.is_a?(Torch::Tensor)
+             embed_dim.div(num_heads, rounding_mode: 'trunc')
+           else
+             embed_dim.div(num_heads)
+           end
+
+           if use_separate_proj_weight
+             raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
+           else
+             raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
+           end
+
+           # compute in-projection
+           q, k, v =
+             if use_separate_proj_weight
+               raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
+               raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
+               raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight
+
+               if in_proj_bias
+                 b_q, b_k, b_v = in_proj_bias.chunk(3)
+               else
+                 b_q = b_k = b_v = nil
+               end
+
+               in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
+             else
+               in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
+             end
+
+           # prep attention mask
+           if attn_mask
+             if attn_mask.dtype == :uint8
+               puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
+               attn_mask = attn_mask.bool
+             else
+               raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
+             end
+
+             if attn_mask.dim == 2
+               correct_2d_size = [tgt_len, src_len]
+               raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size
+
+               attn_mask = attn_mask.unsqueeze(0)
+             elsif attn_mask.dim == 3
+               correct_3d_size = [bsz * num_heads, tgt_len, src_len]
+               raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
+             else
+               raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
+             end
+           end
+
+           # prep key padding mask
+           if key_padding_mask && key_padding_mask.dtype == :uint8
+             puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
+             key_padding_mask = key_padding_mask.bool
+           end
+
+           # add bias along batch dimension (currently second)
+           if bias_k && bias_v
+             raise ArgumentError, "bias cannot be added to static key." if static_k
+             raise ArgumentError, "bias cannot be added to static value." if static_v
+
+             k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
+             v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
+
+             attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+             key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+           else
+             raise ArgumentError unless bias_k.nil?
+             raise ArgumentError unless bias_v.nil?
+           end
+
+           # reshape q, k, v for multihead attention and make em batch first
+           q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+
+           if static_k.nil?
+             k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+           else
+             raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
+             raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim
+
+             k = static_k
+           end
+
+           if static_v.nil?
+             v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+           else
+             raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
+             raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim
+
+             v = static_v
+           end
+
+           # add zero attention along batch dimension (now first)
+           if add_zero_attn
+             zero_attn_shape = [bsz * num_heads, 1, head_dim]
+             k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
+             v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)
+
+             attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+             key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+           end
+
+           # update source sequence length after adjustments
+           src_len = k.size(1)
+
+           # merge key padding and attention masks
+           if key_padding_mask
+             raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]
+
+             key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+
+             attn_mask = if attn_mask.nil?
+               key_padding_mask
+             elsif attn_mask.dtype == :bool
+               attn_mask.logical_or(key_padding_mask)
+             else
+               attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
+             end
+           end
+
+           # convert mask to float
+           if attn_mask && attn_mask.dtype == :bool
+             new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
+             attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
+           end
+
+           dropout_p = 0.0 unless training
+
+           # (deep breath) calculate attention and out projection
+           attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
+           attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
+           attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+           if need_weights
+             # average attention weights over heads
+             attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+             [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
+           else
+             [attn_output, nil]
+           end
+         end
+       end
+     end
+   end
+ end
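As a standalone example of the scaled_dot_product_attention helper above, using randomly initialized tensors (shapes follow the (batch, target_len, embed) and (batch, source_len, embed) layout it expects):

    q = Torch.randn(2, 5, 8)
    k = Torch.randn(2, 7, 8)
    v = Torch.randn(2, 7, 8)

    out, attn = Torch::NN::Functional.scaled_dot_product_attention(q, k, v)
    out.shape   # => [2, 5, 8]
    attn.shape  # => [2, 5, 7]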
@@ -3,6 +3,8 @@ module Torch
      class Module
        include Utils
 
+       attr_reader :training
+
        def initialize
          @training = true
          @parameters = {}
@@ -278,6 +280,11 @@ module Torch
          end
        end
 
+       def deep_dup
+         memo = {}
+         dup_value(self, memo)
+       end
+
        def method_missing(method, *args, &block)
          name = method.to_s
          if named_parameters.key?(name)
@@ -386,6 +393,29 @@ module Torch
            destination[prefix + k] = v
          end
        end
+
+       # keep memo hash like Python deepcopy
+       # https://docs.python.org/3/library/copy.html
+       def dup_value(v, memo)
+         memo[v.object_id] ||= begin
+           case v
+           when Method, UnboundMethod
+             v
+           when Hash
+             v.to_h { |k, v2| [dup_value(k, memo), dup_value(v2, memo)] }
+           when Array
+             v.map { |v2| dup_value(v2, memo) }
+           when Torch::NN::Module
+             copy = v.dup
+             v.instance_variables.each do |var|
+               copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
+             end
+             copy
+           else
+             v.dup
+           end
+         end
+       end
      end
    end
  end
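deep_dup recursively copies a module and everything it holds, memoizing on object_id so objects shared inside the original stay shared inside the copy. A short sketch:

    model = Torch::NN::Linear.new(4, 2)
    copy = model.deep_dup

    # The copy owns its own parameter tensors, so in-place edits don't leak back to model.
    Torch.no_grad { copy.weight.zero! }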
@@ -0,0 +1,49 @@
+ module Torch
+   module NN
+     class ModuleList < Module
+       include Enumerable
+
+       def initialize(mods = nil)
+         super()
+
+         self.concat(mods) if mods
+       end
+
+       def length
+         @modules.length
+       end
+       alias_method :count, :length
+       alias_method :size, :length
+
+       def concat(mods)
+         raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)
+
+         mods.each { |m| append m }
+
+         self
+       end
+
+       def each(&block)
+         if block_given?
+           @modules.values.each(&block)
+         else
+           to_enum(:each)
+         end
+       end
+
+       def append(mod)
+         raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
+         add_module(length.to_s, mod)
+         self
+       end
+
+       def [](idx)
+         if idx.is_a?(Range)
+           self.class.new(@modules.values[idx])
+         else
+           @modules[idx.to_s]
+         end
+       end
+     end
+   end
+ end
1
+ module Torch
2
+ module NN
3
+ class MultiheadAttention < Module
4
+ def initialize(
5
+ embed_dim, num_heads,
6
+ dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false,
7
+ kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil
8
+ )
9
+
10
+ super()
11
+
12
+ @embed_dim = embed_dim
13
+ @kdim = kdim || @embed_dim
14
+ @vdim = vdim || @embed_dim
15
+
16
+ @qkv_same_embed_dim = @kdim == @embed_dim && @vdim == @embed_dim
17
+
18
+ @num_heads = num_heads
19
+ @dropout = dropout
20
+ @batch_first = batch_first
21
+
22
+ @head_dim = @embed_dim.div @num_heads
23
+
24
+ raise ArgumentError, "embed_dim must be divisible by num_heads" unless @head_dim * @num_heads == @embed_dim
25
+
26
+ if @qkv_same_embed_dim
27
+ @in_proj_weight = Parameter.new(Torch.empty([3 * @embed_dim, @embed_dim]))
28
+ %w(q k v).each { |x| register_parameter("#{x}_proj_weight", nil) }
29
+ else
30
+ @q_proj_weight = Parameter.new(Torch.empty([@embed_dim, @embed_dim]))
31
+ @k_proj_weight = Parameter.new(Torch.empty([@embed_dim, @kdim]))
32
+ @v_proj_weight = Parameter.new(Torch.empty([@embed_dim, @vdim]))
33
+
34
+ register_parameter('in_proj_weight', nil)
35
+ end
36
+
37
+ if bias
38
+ @in_proj_bias = Parameter.new(Torch.empty(3 * @embed_dim))
39
+ else
40
+ register_parameter('in_proj_bias', nil)
41
+ end
42
+
43
+ @out_proj = Linear.new(@embed_dim, @embed_dim, bias: bias)
44
+
45
+ if add_bias_kv
46
+ @bias_k = Parameter.new(Torch.empty([1, 1, @embed_dim]))
47
+ @bias_v = Parameter.new(Torch.empty([1, 1, @embed_dim]))
48
+ else
49
+ @bias_k = @bias_v = nil
50
+ end
51
+
52
+ @add_zero_attn = add_zero_attn
53
+
54
+ reset_parameters
55
+ end
56
+
57
+ def batch_first?
58
+ !!@batch_first
59
+ end
60
+
61
+ def reset_parameters
62
+ if @qkv_same_embed_dim
63
+ Init.xavier_uniform!(@in_proj_weight)
64
+ else
65
+ Init.xavier_uniform!(@q_proj_weight)
66
+ Init.xavier_uniform!(@k_proj_weight)
67
+ Init.xavier_uniform!(@v_proj_weight)
68
+ end
69
+
70
+ if @in_proj_bias
71
+ Init.constant!(@in_proj_bias, 0.0)
72
+ Init.constant!(@out_proj.bias, 0.0)
73
+ end
74
+
75
+ Init.xavier_uniform!(@bias_k) if @bias_k
76
+ Init.xavier_uniform!(@bias_v) if @bias_v
77
+ end
78
+
79
+ def forward(
80
+ query, key, value,
81
+ key_padding_mask: nil, need_weights: true, attn_mask: nil
82
+ )
83
+
84
+ if batch_first?
85
+ query, key, value = [query, key, value].map { |t| t.transpose(1, 0) }
86
+ end
87
+
88
+ attn_output, attn_output_weights =
89
+ if @qkv_same_embed_dim
90
+ F.multi_head_attention_forward(
91
+ query, key, value,
92
+ @embed_dim, @num_heads,
93
+ @in_proj_weight, @in_proj_bias,
94
+ @bias_k, @bias_v, @add_zero_attn,
95
+ @dropout, @out_proj.weight, @out_proj.bias,
96
+ training: @training,
97
+ key_padding_mask: key_padding_mask,
98
+ need_weights: need_weights,
99
+ attn_mask: attn_mask
100
+ )
101
+ else
102
+ F.multi_head_attention_forward(
103
+ query, key, value,
104
+ @embed_dim, @num_heads,
105
+ @in_proj_weight, @in_proj_bias,
106
+ @bias_k, @bias_v, @add_zero_attn,
107
+ @dropout, @out_proj.weight, @out_proj.bias,
108
+ training: @training,
109
+ key_padding_mask: key_padding_mask,
110
+ need_weights: need_weights,
111
+ attn_mask: attn_mask,
112
+ use_separate_proj_weight: true,
113
+ q_proj_weight: @q_proj_weight, k_proj_weight: @k_proj_weight, v_proj_weight: @v_proj_weight
114
+ )
115
+ end
116
+
117
+ attn_output = attn_output.transpose(1, 0) if batch_first?
118
+
119
+ [attn_output, attn_output_weights]
120
+ end
121
+ end
122
+ end
123
+ end
@@ -9,6 +9,12 @@ module Torch
9
9
  def inspect
10
10
  "Parameter containing:\n#{super}"
11
11
  end
12
+
13
+ def dup
14
+ Torch.no_grad do
15
+ Parameter.new(clone, requires_grad: requires_grad)
16
+ end
17
+ end
12
18
  end
13
19
  end
14
20
  end
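Parameter#dup is what Module#dup_value ends up calling for parameter tensors; a small sketch of the behaviour it is meant to guarantee:

    p = Torch::NN::Parameter.new(Torch.randn(3))
    q = p.dup

    q.requires_grad  # => true, carried over from the original
    q.equal?(p)      # => false; the clone is a distinct object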