torch-rb 0.6.0 → 0.8.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +21 -0
- data/README.md +23 -41
- data/codegen/function.rb +2 -0
- data/codegen/generate_functions.rb +43 -6
- data/codegen/native_functions.yaml +2007 -1327
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +22 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +23 -12
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +111 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/convnd.rb +2 -0
- data/lib/torch/nn/functional_attention.rb +241 -0
- data/lib/torch/nn/module.rb +2 -0
- data/lib/torch/nn/module_list.rb +49 -0
- data/lib/torch/nn/multihead_attention.rb +123 -0
- data/lib/torch/nn/transformer.rb +92 -0
- data/lib/torch/nn/transformer_decoder.rb +25 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
- data/lib/torch/nn/transformer_encoder.rb +25 -0
- data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
- data/lib/torch/nn/utils.rb +16 -0
- data/lib/torch/tensor.rb +2 -0
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +11 -0
- metadata +20 -5
data/ext/torch/torch.cpp
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "torch_functions.h"
|
6
6
|
#include "templates.h"
|
@@ -9,69 +9,78 @@
|
|
9
9
|
void init_torch(Rice::Module& m) {
|
10
10
|
m.add_handler<torch::Error>(handle_error);
|
11
11
|
add_torch_functions(m);
|
12
|
-
m.
|
12
|
+
m.define_singleton_function(
|
13
13
|
"grad_enabled?",
|
14
|
-
|
14
|
+
[]() {
|
15
15
|
return torch::GradMode::is_enabled();
|
16
16
|
})
|
17
|
-
.
|
17
|
+
.define_singleton_function(
|
18
18
|
"_set_grad_enabled",
|
19
|
-
|
19
|
+
[](bool enabled) {
|
20
20
|
torch::GradMode::set_enabled(enabled);
|
21
21
|
})
|
22
|
-
.
|
22
|
+
.define_singleton_function(
|
23
23
|
"manual_seed",
|
24
|
-
|
24
|
+
[](uint64_t seed) {
|
25
25
|
return torch::manual_seed(seed);
|
26
26
|
})
|
27
27
|
// config
|
28
|
-
.
|
28
|
+
.define_singleton_function(
|
29
29
|
"show_config",
|
30
|
-
|
30
|
+
[] {
|
31
31
|
return torch::show_config();
|
32
32
|
})
|
33
|
-
.
|
33
|
+
.define_singleton_function(
|
34
34
|
"parallel_info",
|
35
|
-
|
35
|
+
[] {
|
36
36
|
return torch::get_parallel_info();
|
37
37
|
})
|
38
38
|
// begin operations
|
39
|
-
.
|
39
|
+
.define_singleton_function(
|
40
40
|
"_save",
|
41
|
-
|
41
|
+
[](const torch::IValue &value) {
|
42
42
|
auto v = torch::pickle_save(value);
|
43
43
|
std::string str(v.begin(), v.end());
|
44
44
|
return str;
|
45
45
|
})
|
46
|
-
.
|
46
|
+
.define_singleton_function(
|
47
47
|
"_load",
|
48
|
-
|
48
|
+
[](const std::string &s) {
|
49
49
|
std::vector<char> v;
|
50
50
|
std::copy(s.begin(), s.end(), std::back_inserter(v));
|
51
51
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
52
52
|
return torch::pickle_load(v);
|
53
53
|
})
|
54
|
-
.
|
54
|
+
.define_singleton_function(
|
55
55
|
"_from_blob",
|
56
|
-
|
56
|
+
[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
57
57
|
void *data = const_cast<char *>(s.c_str());
|
58
58
|
return torch::from_blob(data, size, options);
|
59
59
|
})
|
60
|
-
.
|
60
|
+
.define_singleton_function(
|
61
61
|
"_tensor",
|
62
|
-
|
62
|
+
[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
63
63
|
auto dtype = options.dtype();
|
64
64
|
torch::Tensor t;
|
65
65
|
if (dtype == torch::kBool) {
|
66
66
|
std::vector<uint8_t> vec;
|
67
67
|
for (long i = 0; i < a.size(); i++) {
|
68
|
-
vec.push_back(
|
68
|
+
vec.push_back(Rice::detail::From_Ruby<bool>().convert(a[i].value()));
|
69
|
+
}
|
70
|
+
t = torch::tensor(vec, options);
|
71
|
+
} else if (dtype == torch::kComplexFloat || dtype == torch::kComplexDouble) {
|
72
|
+
// TODO use template
|
73
|
+
std::vector<c10::complex<double>> vec;
|
74
|
+
Object obj;
|
75
|
+
for (long i = 0; i < a.size(); i++) {
|
76
|
+
obj = a[i];
|
77
|
+
vec.push_back(c10::complex<double>(Rice::detail::From_Ruby<double>().convert(obj.call("real").value()), Rice::detail::From_Ruby<double>().convert(obj.call("imag").value())));
|
69
78
|
}
|
70
79
|
t = torch::tensor(vec, options);
|
71
80
|
} else {
|
72
81
|
std::vector<float> vec;
|
73
82
|
for (long i = 0; i < a.size(); i++) {
|
74
|
-
vec.push_back(
|
83
|
+
vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
|
75
84
|
}
|
76
85
|
// hack for requires_grad error
|
77
86
|
if (options.requires_grad()) {
|
data/ext/torch/utils.h
CHANGED
@@ -1,11 +1,10 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
#include <rice/stl.hpp>
|
5
5
|
|
6
6
|
// TODO find better place
|
7
|
-
inline void handle_error(torch::Error const & ex)
|
8
|
-
{
|
7
|
+
inline void handle_error(torch::Error const & ex) {
|
9
8
|
throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
10
9
|
}
|
11
10
|
|
data/ext/torch/wrap_outputs.h
CHANGED
@@ -1,99 +1,106 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
3
|
#include <torch/torch.h>
|
4
|
-
#include <rice/
|
4
|
+
#include <rice/rice.hpp>
|
5
5
|
|
6
|
-
inline
|
7
|
-
return
|
6
|
+
inline VALUE wrap(bool x) {
|
7
|
+
return Rice::detail::To_Ruby<bool>().convert(x);
|
8
8
|
}
|
9
9
|
|
10
|
-
inline
|
11
|
-
return
|
10
|
+
inline VALUE wrap(int64_t x) {
|
11
|
+
return Rice::detail::To_Ruby<int64_t>().convert(x);
|
12
12
|
}
|
13
13
|
|
14
|
-
inline
|
15
|
-
return
|
14
|
+
inline VALUE wrap(double x) {
|
15
|
+
return Rice::detail::To_Ruby<double>().convert(x);
|
16
16
|
}
|
17
17
|
|
18
|
-
inline
|
19
|
-
return
|
18
|
+
inline VALUE wrap(torch::Tensor x) {
|
19
|
+
return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
|
20
20
|
}
|
21
21
|
|
22
|
-
inline
|
23
|
-
return
|
22
|
+
inline VALUE wrap(torch::Scalar x) {
|
23
|
+
return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
|
24
24
|
}
|
25
25
|
|
26
|
-
inline
|
27
|
-
return
|
26
|
+
inline VALUE wrap(torch::ScalarType x) {
|
27
|
+
return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
|
28
28
|
}
|
29
29
|
|
30
|
-
inline
|
31
|
-
return
|
30
|
+
inline VALUE wrap(torch::QScheme x) {
|
31
|
+
return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
|
32
32
|
}
|
33
33
|
|
34
|
-
inline
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
34
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
35
|
+
return rb_ary_new3(
|
36
|
+
2,
|
37
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
38
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
|
39
|
+
);
|
39
40
|
}
|
40
41
|
|
41
|
-
inline
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
42
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
43
|
+
return rb_ary_new3(
|
44
|
+
3,
|
45
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
46
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
47
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x))
|
48
|
+
);
|
47
49
|
}
|
48
50
|
|
49
|
-
inline
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
51
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
52
|
+
return rb_ary_new3(
|
53
|
+
4,
|
54
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
55
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
56
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
|
57
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x))
|
58
|
+
);
|
56
59
|
}
|
57
60
|
|
58
|
-
inline
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
61
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
62
|
+
return rb_ary_new3(
|
63
|
+
5,
|
64
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
65
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
66
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
|
67
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<3>(x)),
|
68
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<4>(x))
|
69
|
+
);
|
66
70
|
}
|
67
71
|
|
68
|
-
inline
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
72
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
73
|
+
return rb_ary_new3(
|
74
|
+
4,
|
75
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
76
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
77
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<2>(x)),
|
78
|
+
Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
|
79
|
+
);
|
75
80
|
}
|
76
81
|
|
77
|
-
inline
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
82
|
+
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
83
|
+
return rb_ary_new3(
|
84
|
+
4,
|
85
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
86
|
+
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
87
|
+
Rice::detail::To_Ruby<double>().convert(std::get<2>(x)),
|
88
|
+
Rice::detail::To_Ruby<int64_t>().convert(std::get<3>(x))
|
89
|
+
);
|
84
90
|
}
|
85
91
|
|
86
|
-
inline
|
87
|
-
|
88
|
-
for (auto
|
89
|
-
a
|
92
|
+
inline VALUE wrap(torch::TensorList x) {
|
93
|
+
auto a = rb_ary_new2(x.size());
|
94
|
+
for (auto t : x) {
|
95
|
+
rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
|
90
96
|
}
|
91
|
-
return
|
97
|
+
return a;
|
92
98
|
}
|
93
99
|
|
94
|
-
inline
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
100
|
+
inline VALUE wrap(std::tuple<double, double> x) {
|
101
|
+
return rb_ary_new3(
|
102
|
+
2,
|
103
|
+
Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
|
104
|
+
Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
|
105
|
+
);
|
99
106
|
}
|
data/lib/torch/inspector.rb
CHANGED
@@ -96,8 +96,11 @@ module Torch
|
|
96
96
|
ret = "%.#{PRINT_OPTS[:precision]}f" % value
|
97
97
|
end
|
98
98
|
elsif @complex_dtype
|
99
|
-
|
100
|
-
|
99
|
+
# TODO use float formatter for each part
|
100
|
+
precision = PRINT_OPTS[:precision]
|
101
|
+
imag = value.imag
|
102
|
+
sign = imag >= 0 ? "+" : "-"
|
103
|
+
ret = "%.#{precision}f#{sign}%.#{precision}fi" % [value.real, value.imag.abs]
|
101
104
|
else
|
102
105
|
ret = value.to_s
|
103
106
|
end
|
data/lib/torch/nn/convnd.rb
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class ConvNd < Module
|
4
|
+
attr_reader :in_channels, :out_channels, :kernel_size, :stride, :padding, :dilation, :transposed, :output_paddding, :groups, :padding_mode
|
5
|
+
|
4
6
|
def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
|
5
7
|
super()
|
6
8
|
raise ArgumentError, "in_channels must be divisible by groups" if in_channels % groups != 0
|
@@ -0,0 +1,241 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class Functional
|
4
|
+
class << self
|
5
|
+
def in_projection_packed(q, k, v, w, b: nil)
|
6
|
+
e = q.size(-1)
|
7
|
+
|
8
|
+
if k.eql? v
|
9
|
+
if q.eql? k
|
10
|
+
# self-attention
|
11
|
+
return linear(q, w, b).chunk(3, dim: -1)
|
12
|
+
else
|
13
|
+
# encoder-decoder attention
|
14
|
+
w_q, w_kv = w.split_with_sizes([e, e * 2])
|
15
|
+
if b.nil?
|
16
|
+
b_q = b_kv = nil
|
17
|
+
else
|
18
|
+
b_q, b_kv = b.split_with_sizes([e, e * 2])
|
19
|
+
end
|
20
|
+
|
21
|
+
return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
|
22
|
+
end
|
23
|
+
else
|
24
|
+
w_q, w_k, w_v = w.chunk(3)
|
25
|
+
if b.nil?
|
26
|
+
b_q = b_k = b_v = nil
|
27
|
+
else
|
28
|
+
b_q, b_k, b_v = b.chunk(3)
|
29
|
+
end
|
30
|
+
|
31
|
+
return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
|
32
|
+
end
|
33
|
+
end
|
34
|
+
|
35
|
+
def in_projection(
|
36
|
+
q, k, v,
|
37
|
+
w_q, w_k, w_v,
|
38
|
+
b_q: nil, b_k: nil, b_v: nil
|
39
|
+
)
|
40
|
+
|
41
|
+
e_q, e_k, e_v = q.size(-1), k.size(-1), v.size(-1)
|
42
|
+
|
43
|
+
raise ArgumentError, "Expecting query weights shape of #{[e_q, e_q]}, but got #{w_q.shape}" unless w_q.shape == [e_q, e_q]
|
44
|
+
raise ArgumentError, "Expecting key weights shape of #{[e_k, e_k]}, but got #{w_k.shape}" unless w_k.shape == [e_k, e_k]
|
45
|
+
raise ArgumentError, "Expecting value weights shape of #{[e_v, e_v]}, but got #{w_v.shape}" unless w_v.shape == [e_v, e_v]
|
46
|
+
|
47
|
+
raise ArgumentError, "Expecting query bias shape of #{[e_q]}, but got #{b_q.shape}" if b_q && b_q.shape != [e_q]
|
48
|
+
raise ArgumentError, "Expecting key bias shape of #{[e_k]}, but got #{b_k.shape}" if b_k && b_k.shape != [e_k]
|
49
|
+
raise ArgumentError, "Expecting value bias shape of #{[e_v]}, but got #{b_v.shape}" if b_v && b_v.shape != [e_v]
|
50
|
+
|
51
|
+
[linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
|
52
|
+
end
|
53
|
+
|
54
|
+
def scaled_dot_product_attention(
|
55
|
+
q, k, v,
|
56
|
+
attn_mask: nil, dropout_p: 0.0
|
57
|
+
)
|
58
|
+
|
59
|
+
b, nt, e = q.shape
|
60
|
+
|
61
|
+
q = q / Math.sqrt(e)
|
62
|
+
|
63
|
+
attn = Torch.bmm(q, k.transpose(-2, -1))
|
64
|
+
attn += attn_mask if attn_mask
|
65
|
+
attn = softmax(attn, dim: -1)
|
66
|
+
attn = dropout(attn, p: dropout_p) if dropout_p > 0
|
67
|
+
|
68
|
+
output = Torch.bmm(attn, v)
|
69
|
+
|
70
|
+
[output, attn]
|
71
|
+
end
|
72
|
+
|
73
|
+
def multi_head_attention_forward(
|
74
|
+
query, key, value,
|
75
|
+
embed_dim_to_check, num_heads,
|
76
|
+
in_proj_weight, in_proj_bias,
|
77
|
+
bias_k, bias_v,
|
78
|
+
add_zero_attn,
|
79
|
+
dropout_p,
|
80
|
+
out_proj_weight, out_proj_bias,
|
81
|
+
training: true,
|
82
|
+
key_padding_mask: nil,
|
83
|
+
need_weights: true,
|
84
|
+
attn_mask: nil,
|
85
|
+
use_separate_proj_weight: false,
|
86
|
+
q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil,
|
87
|
+
static_k: nil, static_v: nil
|
88
|
+
)
|
89
|
+
|
90
|
+
tgt_len, bsz, embed_dim = query.shape
|
91
|
+
src_len = key.shape.first
|
92
|
+
|
93
|
+
raise ArgumentError, "Was expecting embedding dimension of #{embed_dim_to_check}, but got #{embed_dim}" unless embed_dim == embed_dim_to_check
|
94
|
+
|
95
|
+
head_dim = if embed_dim.is_a?(Torch::Tensor)
|
96
|
+
embed_dim.div(num_heads, rounding_mode: 'trunc')
|
97
|
+
else
|
98
|
+
head_dim = embed_dim.div num_heads
|
99
|
+
end
|
100
|
+
|
101
|
+
if use_separate_proj_weight
|
102
|
+
raise ArgumentError, "Key's sequence and batch dims #{key.shape[0...2]} do not match value's #{value.shape[0...2]}" unless key.shape[0...2] == value.shape[0...2]
|
103
|
+
else
|
104
|
+
raise ArgumentError, "Key shape #{key.shape} does not match value shape #{value.shape}" unless key.shape == value.shape
|
105
|
+
end
|
106
|
+
|
107
|
+
# compute in-projection
|
108
|
+
q, k, v =
|
109
|
+
if use_separate_proj_weight
|
110
|
+
raise ArgumentError, "use_separate_proj_weight is true but q_proj_weight is nil" unless q_proj_weight
|
111
|
+
raise ArgumentError, "use_separate_proj_weight is true but k_proj_weight is nil" unless k_proj_weight
|
112
|
+
raise ArgumentError, "use_separate_proj_weight is true but v_proj_weight is nil" unless v_proj_weight
|
113
|
+
|
114
|
+
if in_proj_bias
|
115
|
+
b_q, b_k, b_v = in_proj_bias.chunk(3)
|
116
|
+
else
|
117
|
+
b_q = b_k = b_v = nil
|
118
|
+
end
|
119
|
+
|
120
|
+
in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q: b_q, b_k: b_k, b_v: b_v)
|
121
|
+
else
|
122
|
+
in_projection_packed(query, key, value, in_proj_weight, b: in_proj_bias)
|
123
|
+
end
|
124
|
+
|
125
|
+
# prep attention mask
|
126
|
+
if attn_mask
|
127
|
+
if attn_mask.dtype == :uint8
|
128
|
+
puts "[WARN] Byte tensor for attn_mask in Multihead Attention is deprecated. Use bool tensor instead."
|
129
|
+
attn_mask = attn_mask.bool
|
130
|
+
else
|
131
|
+
raise ArgumentError, "Only float, byte, and bool types are supported for attn_mask, not #{attn_mask.dtype}" unless attn_mask.floating_point? || attn_mask.dtype == :bool
|
132
|
+
end
|
133
|
+
|
134
|
+
if attn_mask.dim == 2
|
135
|
+
correct_2d_size = [tgt_len, src_len]
|
136
|
+
raise ArgumentError, "The shape of the 2D attn_mask is #{attn_mask.shape}, but should be #{correct_2d_size}." unless attn_mask.shape == correct_2d_size
|
137
|
+
|
138
|
+
attn_mask = attn_mask.unsqueeze(0)
|
139
|
+
elsif attn_mask.dim == 3
|
140
|
+
correct_3d_size = [bsz * num_heads, tgt_len, src_len]
|
141
|
+
raise ArgumentError, "The shape of the 3D attn_mask is #{attn_mask.shape}, but should be #{correct_3d_size}." unless attn_mask.shape == correct_3d_size
|
142
|
+
else
|
143
|
+
raise ArgumentError, "attn_mask's dimension #{attn_mask.dim} is not supported"
|
144
|
+
end
|
145
|
+
end
|
146
|
+
|
147
|
+
# prep key padding mask
|
148
|
+
if key_padding_mask && key_padding_mask.dtype == :uint8
|
149
|
+
puts "[WARN] Byte tensor for key_padding_mask in Multihead Attention is deprecated. Use bool tensor instead."
|
150
|
+
key_padding_mask = key_padding_mask.bool
|
151
|
+
end
|
152
|
+
|
153
|
+
# add bias along batch dimension (currently second)
|
154
|
+
if bias_k && bias_v
|
155
|
+
raise ArgumentError, "bias cannot be added to static key." if static_k
|
156
|
+
raise ArgumentError, "bias cannot be added to static value." if static_v
|
157
|
+
|
158
|
+
k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
159
|
+
v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
160
|
+
|
161
|
+
attn_mask = pad(attn_mask, [0, 1]) if attn_mask
|
162
|
+
key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
|
163
|
+
else
|
164
|
+
raise ArgumentError unless bias_k.nil?
|
165
|
+
raise ArgumentError unless bias_v.nil?
|
166
|
+
end
|
167
|
+
|
168
|
+
# reshape q, k, v for multihead attention and make em batch first
|
169
|
+
q = q.contiguous.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
|
170
|
+
|
171
|
+
if static_k.nil?
|
172
|
+
k = k.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
173
|
+
else
|
174
|
+
raise ArgumentError, "Expecting static_k.size(0) of #{bsz * num_heads}, but got #{static_k.size(0)}" unless static_k.size(0) == bsz * num_heads
|
175
|
+
raise ArgumentError, "Expecting static_k.size(2) of #{head_dim}, but got #{static_k.size(2)}" unless static_k.size(2) == head_dim
|
176
|
+
|
177
|
+
k = static_k
|
178
|
+
end
|
179
|
+
|
180
|
+
if static_v.nil?
|
181
|
+
v = v.contiguous.view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
182
|
+
else
|
183
|
+
raise ArgumentError, "Expecting static_v.size(0) of #{bsz * num_heads}, but got #{static_v.size(0)}" unless static_v.size(0) == bsz * num_heads
|
184
|
+
raise ArgumentError, "Expecting static_v.size(2) of #{head_dim}, but got #{static_v.size(2)}" unless static_v.size(2) == head_dim
|
185
|
+
|
186
|
+
v = static_v
|
187
|
+
end
|
188
|
+
|
189
|
+
# add zero attention along batch dimension (now first)
|
190
|
+
if add_zero_attn
|
191
|
+
zero_attn_shape = [bsz * num_heads, 1, head_dim]
|
192
|
+
k = Torch.cat([k, Torch.zeros(zero_attn_shape, dtype: k.dtype, device: k.device)], dim: 1)
|
193
|
+
v = Torch.cat([v, Torch.zeros(zero_attn_shape, dtype: v.dtype, device: v.device)], dim: 1)
|
194
|
+
|
195
|
+
attn_mask = pad(attn_mask, [0, 1]) if attn_mask
|
196
|
+
key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
|
197
|
+
end
|
198
|
+
|
199
|
+
# update source sequence length after adjustments
|
200
|
+
src_len = k.size(1)
|
201
|
+
|
202
|
+
# merge key padding and attention masks
|
203
|
+
if key_padding_mask
|
204
|
+
raise ArgumentError, "Expecting key_padding_mask shape of #{[bsz, src_len]}, but got #{key_padding_mask.shape}" unless key_padding_mask.shape == [bsz, src_len]
|
205
|
+
|
206
|
+
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
|
207
|
+
|
208
|
+
attn_mask = if attn_mask.nil?
|
209
|
+
key_padding_mask
|
210
|
+
elsif attn_mask.dtype == :bool
|
211
|
+
attn_mask.logical_or(key_padding_mask)
|
212
|
+
else
|
213
|
+
attn_mask.masked_fill(key_padding_mask, -Float::INFINITY)
|
214
|
+
end
|
215
|
+
end
|
216
|
+
|
217
|
+
# convert mask to float
|
218
|
+
if attn_mask && attn_mask.dtype == :bool
|
219
|
+
new_attn_mask = Torch.zeros_like(attn_mask, dtype: :float32)
|
220
|
+
attn_mask = new_attn_mask.masked_fill(attn_mask, -Float::INFINITY)
|
221
|
+
end
|
222
|
+
|
223
|
+
dropout_p = 0.0 unless training
|
224
|
+
|
225
|
+
# (deep breath) calculate attention and out projection
|
226
|
+
attn_output, attn_output_weights = scaled_dot_product_attention(q, k, v, attn_mask: attn_mask, dropout_p: dropout_p)
|
227
|
+
attn_output = attn_output.transpose(0, 1).contiguous.view(tgt_len, bsz, embed_dim)
|
228
|
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
229
|
+
|
230
|
+
if need_weights
|
231
|
+
# average attention weights over heads
|
232
|
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
233
|
+
[attn_output, attn_output_weights.sum(dim: 1) / num_heads]
|
234
|
+
else
|
235
|
+
[attn_output, nil]
|
236
|
+
end
|
237
|
+
end
|
238
|
+
end
|
239
|
+
end
|
240
|
+
end
|
241
|
+
end
|
data/lib/torch/nn/module.rb
CHANGED
@@ -0,0 +1,49 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class ModuleList < Module
|
4
|
+
include Enumerable
|
5
|
+
|
6
|
+
def initialize(mods = nil)
|
7
|
+
super()
|
8
|
+
|
9
|
+
self.concat(mods) if mods
|
10
|
+
end
|
11
|
+
|
12
|
+
def length
|
13
|
+
@modules.length
|
14
|
+
end
|
15
|
+
alias_method :count, :length
|
16
|
+
alias_method :size, :length
|
17
|
+
|
18
|
+
def concat(mods)
|
19
|
+
raise ArgumentError, "Modules should respond to #each" unless mods.respond_to?(:each)
|
20
|
+
|
21
|
+
mods.each { |m| append m }
|
22
|
+
|
23
|
+
self
|
24
|
+
end
|
25
|
+
|
26
|
+
def each(&block)
|
27
|
+
if block_given?
|
28
|
+
@modules.values.each(&block)
|
29
|
+
else
|
30
|
+
to_enum(:each)
|
31
|
+
end
|
32
|
+
end
|
33
|
+
|
34
|
+
def append(mod)
|
35
|
+
raise ArgumentError, "Provided element is not a module" unless mod.is_a?(Module)
|
36
|
+
add_module(length.to_s, mod)
|
37
|
+
self
|
38
|
+
end
|
39
|
+
|
40
|
+
def [](idx)
|
41
|
+
if idx.is_a?(Range)
|
42
|
+
self.class.new(@modules.values[idx])
|
43
|
+
else
|
44
|
+
@modules[idx.to_s]
|
45
|
+
end
|
46
|
+
end
|
47
|
+
end
|
48
|
+
end
|
49
|
+
end
|