torch-rb 0.8.1 → 0.9.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -7,11 +7,11 @@
7
7
  void init_backends(Rice::Module& m) {
8
8
  auto rb_mBackends = Rice::define_module_under(m, "Backends");
9
9
 
10
- Rice::define_module_under(rb_mBackends, "OpenMP")
10
+ Rice::define_module_under(rb_mBackends, "OpenMP")
11
11
  .add_handler<torch::Error>(handle_error)
12
12
  .define_singleton_function("available?", &torch::hasOpenMP);
13
13
 
14
- Rice::define_module_under(rb_mBackends, "MKL")
14
+ Rice::define_module_under(rb_mBackends, "MKL")
15
15
  .add_handler<torch::Error>(handle_error)
16
16
  .define_singleton_function("available?", &torch::hasMKL);
17
17
  }
@@ -472,12 +472,12 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
472
472
  auto param_idx = find_param(signature, key);
473
473
  if (param_idx < 0) {
474
474
  rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
475
- signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
475
+ signature.name.c_str(), rb_id2name(rb_to_id(key)));
476
476
  }
477
477
 
478
478
  if (param_idx < num_pos_args) {
479
479
  rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
480
- signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
480
+ signature.name.c_str(), rb_id2name(rb_to_id(key)));
481
481
  }
482
482
  }
483
483
 
@@ -75,7 +75,7 @@ struct RubyArgs {
75
75
  int idx;
76
76
 
77
77
  inline at::Tensor tensor(int i);
78
- inline OptionalTensor optionalTensor(int i);
78
+ inline c10::optional<at::Tensor> optionalTensor(int i);
79
79
  inline at::Scalar scalar(int i);
80
80
  // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
81
81
  inline std::vector<at::Scalar> scalarlist(int i);
@@ -109,6 +109,9 @@ struct RubyArgs {
109
109
  // inline at::QScheme toQScheme(int i);
110
110
  inline std::string string(int i);
111
111
  inline c10::optional<std::string> stringOptional(int i);
112
+ inline c10::string_view stringView(int i);
113
+ // inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
114
+ inline c10::optional<c10::string_view> stringViewOptional(int i);
112
115
  // inline PyObject* pyobject(int i);
113
116
  inline int64_t toInt64(int i);
114
117
  // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -125,8 +128,8 @@ inline at::Tensor RubyArgs::tensor(int i) {
125
128
  return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
126
129
  }
127
130
 
128
- inline OptionalTensor RubyArgs::optionalTensor(int i) {
129
- if (NIL_P(args[i])) return OptionalTensor(Nil);
131
+ inline c10::optional<at::Tensor> RubyArgs::optionalTensor(int i) {
132
+ if (NIL_P(args[i])) return c10::nullopt;
130
133
  return tensor(i);
131
134
  }
132
135
 
@@ -232,7 +235,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
232
235
 
233
236
  auto it = dtype_map.find(args[i]);
234
237
  if (it == dtype_map.end()) {
235
- rb_raise(rb_eArgError, "invalid dtype: %s", THPUtils_unpackSymbol(args[i]).c_str());
238
+ rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
236
239
  }
237
240
  return it->second;
238
241
  }
@@ -290,7 +293,7 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
290
293
 
291
294
  auto it = layout_map.find(args[i]);
292
295
  if (it == layout_map.end()) {
293
- rb_raise(rb_eArgError, "invalid layout: %s", THPUtils_unpackSymbol(args[i]).c_str());
296
+ rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
294
297
  }
295
298
  return it->second;
296
299
  }
@@ -322,6 +325,17 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
322
325
  return Rice::detail::From_Ruby<std::string>().convert(args[i]);
323
326
  }
324
327
 
328
+ // string_view does not own data
329
+ inline c10::string_view RubyArgs::stringView(int i) {
330
+ return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
331
+ }
332
+
333
+ // string_view does not own data
334
+ inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
335
+ if (NIL_P(args[i])) return c10::nullopt;
336
+ return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
337
+ }
338
+
325
339
  inline int64_t RubyArgs::toInt64(int i) {
326
340
  if (NIL_P(args[i])) return signature.params[i].default_int;
327
341
  return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
@@ -41,24 +41,6 @@ using torch::nn::init::NonlinearityType;
41
41
  #define RETURN_NIL \
42
42
  return Qnil;
43
43
 
44
- class OptionalTensor {
45
- torch::Tensor value;
46
- public:
47
- OptionalTensor(Object o) {
48
- if (o.is_nil()) {
49
- value = {};
50
- } else {
51
- value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
52
- }
53
- }
54
- OptionalTensor(torch::Tensor o) {
55
- value = o;
56
- }
57
- operator torch::Tensor() const {
58
- return value;
59
- }
60
- };
61
-
62
44
  namespace Rice::detail
63
45
  {
64
46
  template<>
@@ -131,25 +113,6 @@ namespace Rice::detail
131
113
  }
132
114
  };
133
115
 
134
- template<>
135
- struct Type<OptionalTensor>
136
- {
137
- static bool verify()
138
- {
139
- return true;
140
- }
141
- };
142
-
143
- template<>
144
- class From_Ruby<OptionalTensor>
145
- {
146
- public:
147
- OptionalTensor convert(VALUE x)
148
- {
149
- return OptionalTensor(x);
150
- }
151
- };
152
-
153
116
  template<>
154
117
  struct Type<Scalar>
155
118
  {
data/ext/torch/tensor.cpp CHANGED
@@ -107,7 +107,7 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
107
107
  ParsedArgs<4> parsed_args;
108
108
  auto _r = parser.parse(self_, argc, argv, parsed_args);
109
109
  // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
110
- auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
110
+ auto dispatch__backward = [](const Tensor & self, TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
111
111
  // in future, release GVL
112
112
  self._backward(inputs, gradient, retain_graph, create_graph);
113
113
  };
@@ -125,13 +125,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
125
125
  rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
126
126
 
127
127
  rb_cTensor
128
- .define_method("cuda?", &torch::Tensor::is_cuda)
129
- .define_method("sparse?", &torch::Tensor::is_sparse)
130
- .define_method("quantized?", &torch::Tensor::is_quantized)
131
- .define_method("dim", &torch::Tensor::dim)
132
- .define_method("numel", &torch::Tensor::numel)
133
- .define_method("element_size", &torch::Tensor::element_size)
134
- .define_method("requires_grad", &torch::Tensor::requires_grad)
128
+ .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
129
+ .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
130
+ .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
131
+ .define_method("dim", [](Tensor& self) { return self.dim(); })
132
+ .define_method("numel", [](Tensor& self) { return self.numel(); })
133
+ .define_method("element_size", [](Tensor& self) { return self.element_size(); })
134
+ .define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
135
135
  .define_method(
136
136
  "_size",
137
137
  [](Tensor& self, int64_t dim) {
data/ext/torch/utils.h CHANGED
@@ -16,12 +16,6 @@ inline VALUE THPUtils_internSymbol(const std::string& str) {
16
16
  return Rice::Symbol(str);
17
17
  }
18
18
 
19
- inline std::string THPUtils_unpackSymbol(VALUE obj) {
20
- Check_Type(obj, T_SYMBOL);
21
- obj = rb_funcall(obj, rb_intern("to_s"), 0);
22
- return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
23
- }
24
-
25
19
  inline std::string THPUtils_unpackString(VALUE obj) {
26
20
  Check_Type(obj, T_STRING);
27
21
  return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
@@ -247,7 +247,7 @@ module Torch
247
247
  # length includes spaces and comma between elements
248
248
  element_length = formatter.width + 2
249
249
  elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
250
- char_per_line = element_length * elements_per_line
250
+ _char_per_line = element_length * elements_per_line
251
251
 
252
252
  if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
253
253
  data = (
@@ -1,6 +1,8 @@
1
1
  module Torch
2
2
  module NN
3
3
  class ConvNd < Module
4
# Read-only accessors for the convolution hyper-parameters.
# NOTE(review): assumes initialize stores its output_padding argument in
# @output_padding (the assignment is outside this view) — confirm.
attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_padding, :groups, :padding_mode
# The reader was originally published with a misspelled name; keep it as an
# alias so existing callers do not break.
alias_method :output_paddding, :output_padding
5
+
4
6
  def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
5
7
  super()
6
8
  raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
@@ -571,7 +571,7 @@ module Torch
571
571
  end
572
572
 
573
573
  def _interp_output_size(closed_over_args)
574
- input, size, scale_factor, recompute_scale_factor = closed_over_args
574
+ input, size, scale_factor, _recompute_scale_factor = closed_over_args
575
575
  dim = input.dim - 2
576
576
  if size.nil? && scale_factor.nil?
577
577
  raise ArgumentError, "either size or scale_factor should be defined"
@@ -0,0 +1,241 @@
1
module Torch
  module NN
    # Multi-head attention helpers for Torch::NN::Functional.
    #
    # The methods below call linear, softmax, dropout and pad, which are not
    # defined in this file; presumably they are provided by the rest of
    # Functional.
    class Functional
      class << self
        # Projects q, k and v using one packed weight matrix +w+ (and optional
        # packed bias +b+).
        #
        # * k and v are eql? and q is eql? to k (self-attention): a single
        #   linear call whose result is chunked into three along the last dim.
        # * k and v are eql? but q differs (encoder-decoder attention): +w+ and
        #   +b+ are split into a query part of size e and a key/value part of
        #   size 2e, where e = q.size(-1).
        # * otherwise: +w+ and +b+ are chunked into three and applied
        #   separately.
        #
        # Returns an array of three projected tensors [q, k, v].
        def in_projection_packed(q, k, v, w, b: nil)
          e = q.size(-1)

          if k.eql? v
            if q.eql? k
              # self-attention
              return linear(q, w, b).chunk(3, dim: -1)
            else
              # encoder-decoder attention
              w_q, w_kv = w.split_with_sizes([e, e * 2])
              if b.nil?
                b_q = b_kv = nil
              else
                b_q, b_kv = b.split_with_sizes([e, e * 2])
              end

              return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
            end
          else
            w_q, w_k, w_v = w.chunk(3)
            if b.nil?
              b_q = b_k = b_v = nil
            else
              b_q, b_k, b_v = b.chunk(3)
            end

            return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
          end
        end

        # Projects q, k and v with separate weights (and optional biases).
        #
        # With e_q, e_k, e_v the last-dimension sizes of q, k and v, the
        # weights must have shapes [e_q, e_q], [e_q, e_k] and [e_q, e_v], and
        # every bias that is given must have shape [e_q]: all three inputs are
        # projected into the query embedding size. These are the shapes
        # MultiheadAttention creates for its q/k/v projection weights
        # ([embed_dim, embed_dim], [embed_dim, kdim], [embed_dim, vdim]) and
        # for the three chunks of its in_proj_bias.
        #
        # Raises ArgumentError on any shape mismatch.
        # Returns an array of three projected tensors [q, k, v].
        def in_projection(
          q, k, v,
          w_q, w_k, w_v,
          b_q: nil, b_k: nil, b_v: nil
        )
          e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)

          raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
          raise ArgumentError, "Expecting key weights shape of #{[e_q, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_q, e_k]
          raise ArgumentError, "Expecting value weights shape of #{[e_q, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_q, e_v]

          raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
          raise ArgumentError, "Expecting key bias shape of #{[e_q]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_q]
          raise ArgumentError, "Expecting value bias shape of #{[e_q]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_q]

          [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
        end

        # Computes softmax(q / sqrt(e) * k^T + attn_mask) * v, where e is the
        # last entry of q.shape (q is unpacked as a 3-D shape).
        #
        # attn_mask, when given, is added to the raw scores before the softmax.
        # Dropout is applied to the attention weights only when dropout_p > 0.
        #
        # Returns [output, attention_weights].
        def scaled_dot_product_attention(
          q, k, v,
          attn_mask: nil, dropout_p: 0.0
        )
          _b, _nt, e = q.shape

          q = q / Math.sqrt(e)

          attn = Torch.bmm(q, k.transpose(-2, -1))
          attn += attn_mask if attn_mask
          attn = softmax(attn, dim: -1)
          attn = dropout(attn, p: dropout_p) if dropout_p > 0

          output = Torch.bmm(attn, v)

          [output, attn]
        end

        # Full multi-head attention forward pass.
        #
        # query.shape is unpacked as [tgt_len, bsz, embed_dim] and key.shape's
        # first entry is taken as src_len, i.e. inputs are sequence-first.
        #
        # Steps: validate shapes, project q/k/v (packed or separate weights),
        # normalise attn_mask / key_padding_mask, optionally append bias_k /
        # bias_v and a zero-attention slot, reshape to
        # [bsz * num_heads, len, head_dim], merge the masks, run scaled
        # dot-product attention and apply the output projection.
        #
        # Raises ArgumentError on inconsistent shapes, unsupported mask
        # dtypes/dimensions, or when only one of bias_k / bias_v is given.
        #
        # Returns [attn_output, weights], where weights are the attention
        # weights averaged over heads, or nil when need_weights is false.
        def multi_head_attention_forward(
          query, key, value,
          embed_dim_to_check, num_heads,
          in_proj_weight, in_proj_bias,
          bias_k, bias_v,
          add_zero_attn,
          dropout_p,
          out_proj_weight, out_proj_bias,
          training: true,
          key_padding_mask: nil,
          need_weights: true,
          attn_mask: nil,
          use_separate_proj_weight: false,
          q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
          static_k: nil, static_v: nil
        )
          tgt_len, bsz, embed_dim = query.shape
          src_len = key.shape.first

          raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check

          head_dim =
            if embed_dim.is_a?(Torch::Tensor)
              embed_dim.div(num_heads, rounding_mode: 'trunc')
            else
              embed_dim.div num_heads
            end

          if use_separate_proj_weight
            raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
          else
            raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
          end

          # compute in-projection
          q, k, v =
            if use_separate_proj_weight
              raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
              raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
              raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight

              if in_proj_bias
                b_q, b_k, b_v = in_proj_bias.chunk(3)
              else
                b_q = b_k = b_v = nil
              end

              in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
            else
              in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
            end

          # prep attention mask
          if attn_mask
            if attn_mask.dtype == :uint8
              # warn writes to stderr, so deprecation notices stay out of stdout
              warn "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
              attn_mask = attn_mask.bool
            else
              raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
            end

            if attn_mask.dim == 2
              correct_2d_size = [tgt_len, src_len]
              raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size

              attn_mask = attn_mask.unsqueeze(0)
            elsif attn_mask.dim == 3
              correct_3d_size = [bsz * num_heads, tgt_len, src_len]
              raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
            else
              raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
            end
          end

          # prep key padding mask
          if key_padding_mask && key_padding_mask.dtype == :uint8
            warn "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
            key_padding_mask = key_padding_mask.bool
          end

          # add bias along batch dimension (currently second)
          if bias_k && bias_v
            raise ArgumentError, "bias cannot be added to static key." if static_k
            raise ArgumentError, "bias cannot be added to static value." if static_v

            k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])

            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
          else
            # reaching here with exactly one bias set is a caller error
            raise ArgumentError, "bias_k and bias_v must either both be given or both be nil" unless bias_k.nil? && bias_v.nil?
          end

          # reshape q, k, v for multihead attention and make em batch first
          q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)

          if static_k.nil?
            k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
          else
            raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
            raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim

            k = static_k
          end

          if static_v.nil?
            v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
          else
            raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
            raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim

            v = static_v
          end

          # add zero attention along batch dimension (now first)
          if add_zero_attn
            zero_attn_shape = [bsz * num_heads, 1, head_dim]
            k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
            v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)

            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
          end

          # update source sequence length after adjustments
          src_len = k.size(1)

          # merge key padding and attention masks
          if key_padding_mask
            raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]

            key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)

            attn_mask =
              if attn_mask.nil?
                key_padding_mask
              elsif attn_mask.dtype == :bool
                attn_mask.logical_or(key_padding_mask)
              else
                attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
              end
          end

          # convert mask to float: masked positions become -Infinity, the rest 0
          if attn_mask && attn_mask.dtype == :bool
            new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
            attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
          end

          # dropout is only active in training mode
          dropout_p = 0.0 unless training

          # (deep breath) calculate attention and out projection
          attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
          attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
          attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

          if need_weights
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
            [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
          else
            [attn_output, nil]
          end
        end
      end
    end
  end
end
@@ -3,6 +3,8 @@ module Torch
3
3
  class Module
4
4
  include Utils
5
5
 
6
+ attr_reader :training
7
+
6
8
  def initialize
7
9
  @training = true
8
10
  @parameters = {}
@@ -278,6 +280,11 @@ module Torch
278
280
  end
279
281
  end
280
282
 
283
# Returns a deep copy of this module, produced by dup_value starting from a
# fresh, empty memo hash.
def deep_dup
  dup_value(self, {})
end
287
+
281
288
  def method_missing(method, *args, &block)
282
289
  name = method.to_s
283
290
  if named_parameters.key?(name)
@@ -386,6 +393,29 @@ module Torch
386
393
  destination[prefix + k] = v
387
394
  end
388
395
  end
396
+
397
# Recursively copies +v+, keeping a memo hash like Python deepcopy
# https://docs.python.org/3/library/copy.html
#
# +memo+ maps the object_id of every object already visited to its copy, so an
# object reachable through several paths is copied once and the copies keep
# sharing it.
#
# * Method / UnboundMethod objects are returned as-is.
# * Hashes and Arrays are rebuilt with every key/value/element copied.
# * Torch::NN::Module instances are dup'ed and each instance variable is
#   replaced by a copy of its value.
# * Anything else is copied with #dup.
#
# Containers and modules are registered in +memo+ *before* their contents are
# visited, so a structure that refers back to itself resolves to the copy
# being built instead of recursing without end.
def dup_value(v, memo)
  key = v.object_id
  return memo[key] if memo.key?(key)

  case v
  when Method, UnboundMethod
    memo[key] = v
  when Hash
    copy = memo[key] = {}
    v.each { |k, v2| copy[dup_value(k, memo)] = dup_value(v2, memo) }
    copy
  when Array
    copy = memo[key] = []
    v.each { |v2| copy << dup_value(v2, memo) }
    copy
  when Torch::NN::Module
    copy = memo[key] = v.dup
    v.instance_variables.each do |var|
      copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
    end
    copy
  else
    memo[key] = v.dup
  end
end
389
419
  end
390
420
  end
391
421
  end
@@ -0,0 +1,49 @@
1
module Torch
  module NN
    # An ordered, enumerable list of modules. Each element is registered
    # through add_module under its position rendered as a string ("0", "1",
    # ...) and read back from @modules.
    class ModuleList < Module
      include Enumerable

      # mods - optional collection (anything responding to #each) of modules
      #        to append on construction.
      def initialize(mods = nil)
        super()

        concat(mods) if mods
      end

      # Number of modules held.
      def length
        @modules.length
      end
      alias_method :size, :length

      # Without arguments or a block this is the same as length. Otherwise it
      # defers to Enumerable#count, so count(item) and count { ... } behave as
      # they do on any other Enumerable.
      def count(*args, &block)
        return length if args.empty? && block.nil?

        super
      end

      # Appends every element yielded by mods.each.
      # Raises ArgumentError when mods does not respond to #each.
      # Returns self.
      def concat(mods)
        raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)

        mods.each { |m| append m }

        self
      end

      # Yields each module in @modules order; returns an Enumerator when no
      # block is given.
      def each(&block)
        if block_given?
          @modules.values.each(&block)
        else
          to_enum(:each)
        end
      end

      # Registers mod under the next index.
      # Raises ArgumentError when mod is not a Module.
      # Returns self.
      def append(mod)
        raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
        add_module(length.to_s, mod)
        self
      end

      # idx - a Range (returns a new list of the same class holding that
      #       slice), an Integer (negative values count from the end, as with
      #       Array#[]), or any other key, which is looked up by its to_s.
      # Returns nil when no module is registered under the resolved key.
      def [](idx)
        case idx
        when Range
          self.class.new(@modules.values[idx])
        when Integer
          idx += length if idx < 0
          @modules[idx.to_s]
        else
          @modules[idx.to_s]
        end
      end
    end
  end
end
@@ -0,0 +1,123 @@
1
module Torch
  module NN
    # Multi-head attention layer. It owns the projection parameters and
    # delegates the computation to F.multi_head_attention_forward.
    class MultiheadAttention < Module
      # embed_dim     - embedding size of the query (and of the output).
      # num_heads     - number of heads; must divide embed_dim evenly.
      # dropout       - dropout probability handed to the functional call.
      # bias          - whether to create in_proj_bias and an out_proj bias.
      # add_bias_kv   - whether to create the bias_k / bias_v parameters.
      # add_zero_attn - forwarded to the functional call.
      # kdim, vdim    - key / value feature sizes; default to embed_dim.
      # batch_first   - when truthy, forward transposes dims 0 and 1 of its
      #                 inputs and of the attention output.
      #
      # NOTE(review): device and dtype are accepted but never used in this
      # constructor.
      #
      # Raises ArgumentError when embed_dim is not divisible by num_heads.
      def initialize(
        embed_dim, num_heads,
        dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false,
        kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil
      )
        super()

        @embed_dim = embed_dim
        @kdim = kdim || @embed_dim
        @vdim = vdim || @embed_dim

        # a single packed projection is only possible when all sizes agree
        @qkv_same_embed_dim = [@kdim, @vdim].all? { |dim| dim == @embed_dim }

        @num_heads = num_heads
        @dropout = dropout
        @batch_first = batch_first

        @head_dim = @embed_dim.div(@num_heads)

        if @head_dim * @num_heads != @embed_dim
          raise ArgumentError, "embed_dim must be divisible by num_heads"
        end

        if @qkv_same_embed_dim
          @in_proj_weight = Parameter.new(Torch.empty([3 * @embed_dim, @embed_dim]))
          ["q_proj_weight", "k_proj_weight", "v_proj_weight"].each do |name|
            register_parameter(name, nil)
          end
        else
          @q_proj_weight = Parameter.new(Torch.empty([@embed_dim, @embed_dim]))
          @k_proj_weight = Parameter.new(Torch.empty([@embed_dim, @kdim]))
          @v_proj_weight = Parameter.new(Torch.empty([@embed_dim, @vdim]))

          register_parameter('in_proj_weight', nil)
        end

        if bias
          @in_proj_bias = Parameter.new(Torch.empty(3 * @embed_dim))
        else
          register_parameter('in_proj_bias', nil)
        end

        @out_proj = Linear.new(@embed_dim, @embed_dim, bias: bias)

        if add_bias_kv
          @bias_k = Parameter.new(Torch.empty([1, 1, @embed_dim]))
          @bias_v = Parameter.new(Torch.empty([1, 1, @embed_dim]))
        else
          @bias_k = @bias_v = nil
        end

        @add_zero_attn = add_zero_attn

        reset_parameters
      end

      # True when the layer was built with a truthy batch_first.
      def batch_first?
        !!@batch_first
      end

      # (Re)initialises the parameters: Xavier-uniform for the projection
      # weights and for bias_k / bias_v when present; zeros for in_proj_bias
      # and the output projection bias when in_proj_bias exists.
      def reset_parameters
        projection_weights =
          if @qkv_same_embed_dim
            [@in_proj_weight]
          else
            [@q_proj_weight, @k_proj_weight, @v_proj_weight]
          end
        projection_weights.each { |weight| Init.xavier_uniform!(weight) }

        if @in_proj_bias
          [@in_proj_bias, @out_proj.bias].each { |b| Init.constant!(b, 0.0) }
        end

        [@bias_k, @bias_v].each { |b| Init.xavier_uniform!(b) if b }
      end

      # Runs attention over query, key and value.
      #
      # key_padding_mask, need_weights and attn_mask are passed through to
      # F.multi_head_attention_forward. When the key/value sizes differ from
      # embed_dim the separate q/k/v projection weights are supplied as well.
      #
      # Returns [attn_output, attn_output_weights].
      def forward(
        query, key, value,
        key_padding_mask: nil, need_weights: true, attn_mask: nil
      )
        swap_batch = batch_first?

        if swap_batch
          query, key, value = [query, key, value].map { |t| t.transpose(1, 0) }
        end

        options = {
          training: @training,
          key_padding_mask: key_padding_mask,
          need_weights: need_weights,
          attn_mask: attn_mask
        }

        unless @qkv_same_embed_dim
          options[:use_separate_proj_weight] = true
          options[:q_proj_weight] = @q_proj_weight
          options[:k_proj_weight] = @k_proj_weight
          options[:v_proj_weight] = @v_proj_weight
        end

        attn_output, attn_output_weights = F.multi_head_attention_forward(
          query, key, value,
          @embed_dim, @num_heads,
          @in_proj_weight, @in_proj_bias,
          @bias_k, @bias_v, @add_zero_attn,
          @dropout, @out_proj.weight, @out_proj.bias,
          **options
        )

        attn_output = attn_output.transpose(1, 0) if swap_batch

        [attn_output, attn_output_weights]
      end
    end
  end
end