torch-rb 0.6.0 → 0.8.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +21 -0
  3. data/README.md +23 -41
  4. data/codegen/function.rb +2 -0
  5. data/codegen/generate_functions.rb +43 -6
  6. data/codegen/native_functions.yaml +2007 -1327
  7. data/ext/torch/backends.cpp +17 -0
  8. data/ext/torch/cuda.cpp +5 -5
  9. data/ext/torch/device.cpp +13 -6
  10. data/ext/torch/ext.cpp +22 -5
  11. data/ext/torch/extconf.rb +1 -3
  12. data/ext/torch/fft.cpp +13 -0
  13. data/ext/torch/fft_functions.h +6 -0
  14. data/ext/torch/ivalue.cpp +31 -33
  15. data/ext/torch/linalg.cpp +13 -0
  16. data/ext/torch/linalg_functions.h +6 -0
  17. data/ext/torch/nn.cpp +34 -34
  18. data/ext/torch/random.cpp +5 -5
  19. data/ext/torch/ruby_arg_parser.cpp +2 -2
  20. data/ext/torch/ruby_arg_parser.h +23 -12
  21. data/ext/torch/special.cpp +13 -0
  22. data/ext/torch/special_functions.h +6 -0
  23. data/ext/torch/templates.h +111 -133
  24. data/ext/torch/tensor.cpp +80 -67
  25. data/ext/torch/torch.cpp +30 -21
  26. data/ext/torch/utils.h +3 -4
  27. data/ext/torch/wrap_outputs.h +72 -65
  28. data/lib/torch/inspector.rb +5 -2
  29. data/lib/torch/nn/convnd.rb +2 -0
  30. data/lib/torch/nn/functional_attention.rb +241 -0
  31. data/lib/torch/nn/module.rb +2 -0
  32. data/lib/torch/nn/module_list.rb +49 -0
  33. data/lib/torch/nn/multihead_attention.rb +123 -0
  34. data/lib/torch/nn/transformer.rb +92 -0
  35. data/lib/torch/nn/transformer_decoder.rb +25 -0
  36. data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
  37. data/lib/torch/nn/transformer_encoder.rb +25 -0
  38. data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
  39. data/lib/torch/nn/utils.rb +16 -0
  40. data/lib/torch/tensor.rb +2 -0
  41. data/lib/torch/utils/data/data_loader.rb +2 -0
  42. data/lib/torch/version.rb +1 -1
  43. data/lib/torch.rb +11 -0
  44. metadata +20 -5
data/ext/torch/torch.cpp CHANGED
@@ -1,6 +1,6 @@
 #include <torch/torch.h>
 
-#include <rice/Module.hpp>
+#include <rice/rice.hpp>
 
 #include "torch_functions.h"
 #include "templates.h"
@@ -9,69 +9,78 @@
 void init_torch(Rice::Module& m) {
   m.add_handler<torch::Error>(handle_error);
   add_torch_functions(m);
-  m.define_singleton_method(
+  m.define_singleton_function(
     "grad_enabled?",
-    *[]() {
+    []() {
       return torch::GradMode::is_enabled();
     })
-    .define_singleton_method(
+    .define_singleton_function(
      "_set_grad_enabled",
-      *[](bool enabled) {
+      [](bool enabled) {
        torch::GradMode::set_enabled(enabled);
      })
-    .define_singleton_method(
+    .define_singleton_function(
      "manual_seed",
-      *[](uint64_t seed) {
+      [](uint64_t seed) {
        return torch::manual_seed(seed);
      })
     // config
-    .define_singleton_method(
+    .define_singleton_function(
      "show_config",
-      *[] {
+      [] {
        return torch::show_config();
      })
-    .define_singleton_method(
+    .define_singleton_function(
      "parallel_info",
-      *[] {
+      [] {
        return torch::get_parallel_info();
      })
     // begin operations
-    .define_singleton_method(
+    .define_singleton_function(
      "_save",
-      *[](const torch::IValue &value) {
+      [](const torch::IValue &value) {
        auto v = torch::pickle_save(value);
        std::string str(v.begin(), v.end());
        return str;
      })
-    .define_singleton_method(
+    .define_singleton_function(
      "_load",
-      *[](const std::string &s) {
+      [](const std::string &s) {
        std::vector<char> v;
        std::copy(s.begin(), s.end(), std::back_inserter(v));
        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
        return torch::pickle_load(v);
      })
-    .define_singleton_method(
+    .define_singleton_function(
      "_from_blob",
-      *[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
+      [](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
        void *data = const_cast<char *>(s.c_str());
        return torch::from_blob(data, size, options);
      })
-    .define_singleton_method(
+    .define_singleton_function(
      "_tensor",
-      *[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+      [](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
        auto dtype = options.dtype();
        torch::Tensor t;
        if (dtype == torch::kBool) {
          std::vector<uint8_t> vec;
          for (long i = 0; i < a.size(); i++) {
-           vec.push_back(from_ruby<bool>(a[i]));
+           vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
+         }
+         t = torch::tensor(vec, options);
+       } else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
+         // TODO use template
+         std::vector<c10::complex<double>> vec;
+         Object obj;
+         for (long i = 0; i < a.size(); i++) {
+           obj = a[i];
+           vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
          }
          t = torch::tensor(vec, options);
        } else {
          std::vector<float> vec;
          for (long i = 0; i < a.size(); i++) {
-           vec.push_back(from_ruby<float>(a[i]));
+           vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
          }
          // hack for requires_grad error
          if (options.requires_grad()) {
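The new kComplexFloat/kComplexDouble branch means _tensor can now build complex tensors from Ruby Complex objects, reading each element's real and imag parts via obj.call. A minimal sketch of what this enables from Ruby (hedged: it assumes the gem maps the :complex64 dtype symbol to torch::kComplexFloat, consistent with the branch above):

    require "torch"

    # each Complex element is split into real/imag on the C++ side
    x = Torch.tensor([Complex(1, 2), Complex(3, -4)], dtype: :complex64)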
data/ext/torch/utils.h CHANGED
@@ -1,11 +1,10 @@
 #pragma once
 
-#include <rice/Exception.hpp>
-#include <rice/Symbol.hpp>
+#include <rice/rice.hpp>
+#include <rice/stl.hpp>
 
 // TODO find better place
-inline void handle_error(torch::Error const & ex)
-{
+inline void handle_error(torch::Error const & ex) {
   throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
 }
 
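handle_error translates any torch::Error raised inside libtorch into a plain Ruby RuntimeError (minus the C++ backtrace), so native failures can be rescued normally. A small illustration (a sketch; the exact message text comes from libtorch):

    require "torch"

    begin
      Torch.ones(2, 3) + Torch.ones(4, 5) # shape mismatch raised inside libtorch
    rescue RuntimeError => e
      puts e.message # libtorch's error message, without the C++ backtrace
    end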
data/ext/torch/wrap_outputs.h CHANGED
@@ -1,99 +1,106 @@
 #pragma once
 
 #include <torch/torch.h>
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>
 
-inline Object wrap(bool x) {
-  return to_ruby<bool>(x);
+inline VALUE wrap(bool x) {
+  return Rice::detail::To_Ruby<bool>().convert(x);
 }
 
-inline Object wrap(int64_t x) {
-  return to_ruby<int64_t>(x);
+inline VALUE wrap(int64_t x) {
+  return Rice::detail::To_Ruby<int64_t>().convert(x);
 }
 
-inline Object wrap(double x) {
-  return to_ruby<double>(x);
+inline VALUE wrap(double x) {
+  return Rice::detail::To_Ruby<double>().convert(x);
 }
 
-inline Object wrap(torch::Tensor x) {
-  return to_ruby<torch::Tensor>(x);
+inline VALUE wrap(torch::Tensor x) {
+  return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
 }
 
-inline Object wrap(torch::Scalar x) {
-  return to_ruby<torch::Scalar>(x);
+inline VALUE wrap(torch::Scalar x) {
+  return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
 }
 
-inline Object wrap(torch::ScalarType x) {
-  return to_ruby<torch::ScalarType>(x);
+inline VALUE wrap(torch::ScalarType x) {
+  return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
 }
 
-inline Object wrap(torch::QScheme x) {
-  return to_ruby<torch::QScheme>(x);
+inline VALUE wrap(torch::QScheme x) {
+  return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    3,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<3>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<4>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+  return rb_ary_new3(
+    5,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<2>(x)));
-  a.push(to_ruby<int64_t>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
-  Array a;
-  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
-  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
-  a.push(to_ruby<double>(std::get<2>(x)));
-  a.push(to_ruby<int64_t>(std::get<3>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+  return rb_ary_new3(
+    4,
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
+    Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
+  );
 }
 
-inline Object wrap(torch::TensorList x) {
-  Array a;
-  for (auto& t : x) {
-    a.push(to_ruby<torch::Tensor>(t));
+inline VALUE wrap(torch::TensorList x) {
+  auto a = rb_ary_new2(x.size());
+  for (auto t : x) {
+    rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
   }
-  return Object(a);
+  return a;
 }
 
-inline Object wrap(std::tuple<double, double> x) {
-  Array a;
-  a.push(to_ruby<double>(std::get<0>(x)));
-  a.push(to_ruby<double>(std::get<1>(x)));
-  return Object(a);
+inline VALUE wrap(std::tuple<double, double> x) {
+  return rb_ary_new3(
+    2,
+    Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
+    Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
+  );
 }
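Every wrap overload now returns a raw VALUE built with rb_ary_new3/rb_ary_push instead of a Rice::Object, but the Ruby-visible behavior is unchanged: native ops returning tuples still surface as plain arrays that destructure cleanly. A sketch (assuming the generated binding for topk, which returns a std::tuple<Tensor, Tensor>):

    require "torch"

    x = Torch.tensor([1.0, 5.0, 3.0])
    values, indices = x.topk(2) # the tuple arrives as a two-element Ruby Array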
data/lib/torch/inspector.rb CHANGED
@@ -96,8 +96,11 @@ module Torch
         ret = "%.#{PRINT_OPTS[:precision]}f" % value
       end
     elsif @complex_dtype
-      p = PRINT_OPTS[:precision]
-      raise NotImplementedYet
+      # TODO use float formatter for each part
+      precision = PRINT_OPTS[:precision]
+      imag = value.imag
+      sign = imag >= 0 ? "+" : "-"
+      ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
     else
       ret = value.to_s
     end
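With this change, inspecting a complex tensor no longer raises NotImplementedYet; each element is rendered as real±imagi using the configured precision. Roughly (output is approximate and assumes the default precision of 4):

    require "torch"

    x = Torch.tensor([Complex(1, -2)], dtype: :complex64)
    puts x.inspect # => something like: tensor([1.0000-2.0000i])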
data/lib/torch/nn/convnd.rb CHANGED
@@ -1,6 +1,8 @@
 module Torch
   module NN
     class ConvNd < Module
+      attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_paddding, :groups, :padding_mode
+
       def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
         super()
         raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
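The new readers expose constructor arguments that were previously only stored in instance variables (note that the output-padding reader is spelled :output_paddding in this release). A quick sketch against a Conv2d subclass:

    require "torch"

    conv = Torch::NN::Conv2d.new(3, 16, 3, stride: 2)
    conv.in_channels  # => 3
    conv.out_channels # => 16
    conv.stride       # => the stride stored as a pair, e.g. [2, 2]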
data/lib/torch/nn/functional_attention.rb ADDED
@@ -0,0 +1,241 @@
+module Torch
+  module NN
+    class Functional
+      class << self
+        def in_projection_packed(q, k, v, w, b: nil)
+          e = q.size(-1)
+
+          if k.eql? v
+            if q.eql? k
+              # self-attention
+              return linear(q, w, b).chunk(3, dim: -1)
+            else
+              # encoder-decoder attention
+              w_q, w_kv = w.split_with_sizes([e, e * 2])
+              if b.nil?
+                b_q = b_kv = nil
+              else
+                b_q, b_kv = b.split_with_sizes([e, e * 2])
+              end
+
+              return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
+            end
+          else
+            w_q, w_k, w_v = w.chunk(3)
+            if b.nil?
+              b_q = b_k = b_v = nil
+            else
+              b_q, b_k, b_v = b.chunk(3)
+            end
+
+            return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+          end
+        end
+
+        def in_projection(
+          q, k, v,
+          w_q, w_k, w_v,
+          b_q: nil, b_k: nil, b_v: nil
+        )
+
+          e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)
+
+          raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
+          raise ArgumentError, "Expecting key weights shape of #{[e_k, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_k, e_k]
+          raise ArgumentError, "Expecting value weights shape of #{[e_v, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_v, e_v]
+
+          raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
+          raise ArgumentError, "Expecting key bias shape of #{[e_k]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_k]
+          raise ArgumentError, "Expecting value bias shape of #{[e_v]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_v]
+
+          [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+        end
+
+        def scaled_dot_product_attention(
+          q, k, v,
+          attn_mask: nil, dropout_p: 0.0
+        )
+
+          b, nt, e = q.shape
+
+          q = q / Math.sqrt(e)
+
+          attn = Torch.bmm(q, k.transpose(-2, -1))
+          attn += attn_mask if attn_mask
+          attn = softmax(attn, dim: -1)
+          attn = dropout(attn, p: dropout_p) if dropout_p > 0
+
+          output = Torch.bmm(attn, v)
+
+          [output, attn]
+        end
+
+        def multi_head_attention_forward(
+          query, key, value,
+          embed_dim_to_check, num_heads,
+          in_proj_weight, in_proj_bias,
+          bias_k, bias_v,
+          add_zero_attn,
+          dropout_p,
+          out_proj_weight, out_proj_bias,
+          training: true,
+          key_padding_mask: nil,
+          need_weights: true,
+          attn_mask: nil,
+          use_separate_proj_weight: false,
+          q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
+          static_k: nil, static_v: nil
+        )
+
+          tgt_len, bsz, embed_dim = query.shape
+          src_len = key.shape.first
+
+          raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check
+
+          head_dim = if embed_dim.is_a?(Torch::Tensor)
+            embed_dim.div(num_heads, rounding_mode: 'trunc')
+          else
+            head_dim = embed_dim.div num_heads
+          end
+
+          if use_separate_proj_weight
+            raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
+          else
+            raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
+          end
+
+          # compute in-projection
+          q, k, v =
+            if use_separate_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
+              raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight
+
+              if in_proj_bias
+                b_q, b_k, b_v = in_proj_bias.chunk(3)
+              else
+                b_q = b_k = b_v = nil
+              end
+
+              in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
+            else
+              in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
+            end
+
+          # prep attention mask
+          if attn_mask
+            if attn_mask.dtype == :uint8
+              puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
+              attn_mask = attn_mask.bool
+            else
+              raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
+            end
+
+            if attn_mask.dim == 2
+              correct_2d_size = [tgt_len, src_len]
+              raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size
+
+              attn_mask = attn_mask.unsqueeze(0)
+            elsif attn_mask.dim == 3
+              correct_3d_size = [bsz * num_heads, tgt_len, src_len]
+              raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
+            else
+              raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
+            end
+          end
+
+          # prep key padding mask
+          if key_padding_mask && key_padding_mask.dtype == :uint8
+            puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
+            key_padding_mask = key_padding_mask.bool
+          end
+
+          # add bias along batch dimension (currently second)
+          if bias_k && bias_v
+            raise ArgumentError, "bias cannot be added to static key." if static_k
+            raise ArgumentError, "bias cannot be added to static value." if static_v
+
+            k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
+            v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
+
+            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+          else
+            raise ArgumentError unless bias_k.nil?
+            raise ArgumentError unless bias_v.nil?
+          end
+
+          # reshape q, k, v for multihead attention and make em batch first
+          q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
+
+          if static_k.nil?
+            k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+          else
+            raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
+            raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim
+
+            k = static_k
+          end
+
+          if static_v.nil?
+            v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
+          else
+            raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
+            raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim
+
+            v = static_v
+          end
+
+          # add zero attention along batch dimension (now first)
+          if add_zero_attn
+            zero_attn_shape = [bsz * num_heads, 1, head_dim]
+            k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
+            v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)
+
+            attn_mask = pad(attn_mask, [0, 1]) if attn_mask
+            key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
+          end
+
+          # update source sequence length after adjustments
+          src_len = k.size(1)
+
+          # merge key padding and attention masks
+          if key_padding_mask
+            raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]
+
+            key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
+
+            attn_mask = if attn_mask.nil?
+              key_padding_mask
+            elsif attn_mask.dtype == :bool
+              attn_mask.logical_or(key_padding_mask)
+            else
+              attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
+            end
+          end
+
+          # convert mask to float
+          if attn_mask && attn_mask.dtype == :bool
+            new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
+            attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
+          end
+
+          dropout_p = 0.0 unless training
+
+          # (deep breath) calculate attention and out projection
+          attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
+          attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
+          attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
+
+          if need_weights
+            # average attention weights over heads
+            attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
+            [attn_output, attn_output_weights.sum(dim: 1) / num_heads]
+          else
+            [attn_output, nil]
+          end
+        end
+      end
+    end
+  end
+end
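The core of this new file is scaled_dot_product_attention, which computes softmax(q·kᵀ/√e)·v over batched 3-D inputs and returns both the output and the attention weights. A minimal sketch of calling it directly (assuming [batch, sequence, embed]-shaped inputs, as the b, nt, e destructuring above implies):

    require "torch"

    q = Torch.randn(2, 4, 8) # [batch, target length, embed dim]
    k = Torch.randn(2, 6, 8) # [batch, source length, embed dim]
    v = Torch.randn(2, 6, 8)

    output, attn = Torch::NN::Functional.scaled_dot_product_attention(q, k, v)
    output.shape # => [2, 4, 8]
    attn.shape   # => [2, 4, 6]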
data/lib/torch/nn/module.rb CHANGED
@@ -3,6 +3,8 @@ module Torch
     class Module
       include Utils
 
+      attr_reader :training
+
       def initialize
         @training = true
         @parameters = {}
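attr_reader :training makes the train/eval state queryable, matching PyTorch's module.training flag:

    require "torch"

    model = Torch::NN::Linear.new(10, 2)
    model.training # => true
    model.eval
    model.training # => false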
data/lib/torch/nn/module_list.rb ADDED
@@ -0,0 +1,49 @@
+module Torch
+  module NN
+    class ModuleList < Module
+      include Enumerable
+
+      def initialize(mods = nil)
+        super()
+
+        self.concat(mods) if mods
+      end
+
+      def length
+        @modules.length
+      end
+      alias_method :count, :length
+      alias_method :size, :length
+
+      def concat(mods)
+        raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)
+
+        mods.each { |m| append m }
+
+        self
+      end
+
+      def each(&block)
+        if block_given?
+          @modules.values.each(&block)
+        else
+          to_enum(:each)
+        end
+      end
+
+      def append(mod)
+        raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
+        add_module(length.to_s, mod)
+        self
+      end
+
+      def [](idx)
+        if idx.is_a?(Range)
+          self.class.new(@modules.values[idx])
+        else
+          @modules[idx.to_s]
+        end
+      end
+    end
+  end
+end
+ end